Add 1x64x32 Joint Matrix support, refactoring

Add 1x64x32 Joint Matrix support.
Refactor load/store/mad built-ins for Joint Matrix.
Yury Plyakhin
2024-04-22 20:11:05 +00:00
committed by igcbot
parent 0ed8916f61
commit 55c27784c7
2 changed files with 165 additions and 147 deletions

View File

@ -57,6 +57,8 @@ SPDX-License-Identifier: MIT
#define IND_VNNI_TX(slid, stride, skip_factor, i, sg_cols) (i + (slid * stride))
// no int7, int6, int5 types
#define VEC_TO_VEC16(type, vec) \
(type##16)(vec.s0, vec.s1, vec.s2, vec.s3, vec.s4, vec.s5, vec.s6, vec.s7, vec.s8, vec.s9, vec.sA, vec.sB, vec.sC, vec.sD, vec.sE, vec.sF)
#define VEC_TO_VEC8(type, vec) \
(type##8)(vec.s0, vec.s1, vec.s2, vec.s3, vec.s4, vec.s5, vec.s6, vec.s7)
#define VEC_TO_VEC7(type, vec) \
@ -71,6 +73,7 @@ SPDX-License-Identifier: MIT
#define VEC_TO_VEC1(type, vec) (type)(vec)
// in case of store, we can not use uint3 with intel_sub_group_block_write4
#define VEC_TO_VEC_STORE16(type, vec) VEC_TO_VEC16(type, vec)
#define VEC_TO_VEC_STORE8(type, vec) VEC_TO_VEC8(type, vec)
#define VEC_TO_VEC_STORE7(type, vec) VEC_TO_VEC7(type, vec)
#define VEC_TO_VEC_STORE6(type, vec) VEC_TO_VEC6(type, vec)
@ -134,6 +137,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
#define OUT_STORE_VEC1(type) type
// Math division macros
#define MATH_128_DIV_4 32
#define MATH_64_DIV_64 1
#define MATH_64_DIV_32 2
#define MATH_64_DIV_16 4
@ -209,6 +213,19 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
#define SHAPE_Accumulator_ColumnMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
#define SHAPE(layout, M, K, element_type, contrib_type) SHAPE_##layout(M, K, BITWIDTH(element_type), BITWIDTH(contrib_type))
// Get number of 2d block stores needed for a given number of rows.
// The R parameter is the number of rows.
#define GET_NUM_STORES_1 1
#define GET_NUM_STORES_2 1
#define GET_NUM_STORES_3 1
#define GET_NUM_STORES_4 1
#define GET_NUM_STORES_5 1
#define GET_NUM_STORES_6 1
#define GET_NUM_STORES_7 1
#define GET_NUM_STORES_8 1
#define GET_NUM_STORES_16 2
#define GET_NUM_STORES(R) GET_NUM_STORES_##R
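For illustration only (nothing beyond the macros above): the store count is resolved by token pasting at preprocessing time, so

// GET_NUM_STORES(8)  -> GET_NUM_STORES_8  -> 1   (up to 8 rows are written with a single 2d block store)
// GET_NUM_STORES(16) -> GET_NUM_STORES_16 -> 2   (16 rows are written as two 8-row 2d block stores)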
// layout can be PackedA_RowMajor, PackedB_ColumnMajor, PackedB_PackedB, etc.
// sg is empty for XMX8 and _SG16 for PVC
// elem_bitwidth is 8, 16 or 32
@ -274,7 +291,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
/* not supported, fallthrough */
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
/* not supported, fallthrough */
#define IMPLEMENT_BLOCK2D_STORE(element_type, contrib_type, contrib_bitwidth, M, K) \
#define IMPLEMENT_BLOCK2D_STORE_1(element_type, contrib_type, contrib_bitwidth, M, K) \
/* not supported, fallthrough */
// contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
@ -288,7 +305,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = M - 1; /* row count */ \
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@ -303,7 +320,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = contrib_K - 1; /* column count */ \
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@ -319,7 +336,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = K - 1; /* row count */ \
long x = (offset - baseoffset) / (sizeof (element_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@ -337,11 +354,12 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
IMPLEMENT_BLOCK2D_LOAD__(sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
M, K, WI_rows)
#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, contrib_type, contrib_bitwidth, M, K) \
// _1 suffix in the name indicates that the function uses a single 2d block store
#define IMPLEMENT_BLOCK2D_STORE_SG16_1(element_type, contrib_type, contrib_bitwidth, M, K) \
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = M - 1; /* row count */ \
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@ -350,6 +368,18 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, M, M, K)(baseoffset, width, height, pitch, coords, val, cacheOpt); \
return;
// _2 suffix in the name indicates that the function uses two 2d block stores
#define IMPLEMENT_BLOCK2D_STORE_SG16_2(element_type, contrib_type, contrib_bitwidth, M, K) \
__private char *c0 = src + 0 * 8 * (sizeof (int)); \
__private char *c1 = src + 1 * 8 * (sizeof (int)); \
\
char *mem0 = mem; \
char *mem1 = mem + 8 * (sizeof (int)) * stride; \
\
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem0, c0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem1, c1, stride, cacheOpt); \
return;
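As a sketch of what the _2 variant does for the 16x16 i32 accumulator case (a hypothetical hand-expansion, not code from this commit; mem, src, stride and cacheOpt are the parameters of the enclosing store built-in):

// hypothetical hand-expansion for M == 16, i32 contrib type (per work-item slice of 16 ints):
//   c0 = src                    -> rows 0..7  of the slice, stored at mem (row 0)
//   c1 = src + 8 * sizeof(int)  -> rows 8..15 of the slice, stored at mem + 8 * sizeof(int) * stride (row 8)
// each half is handed to the existing Accumulator_RowMajor 8x16 i32 store built-in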
// layout can be PackedA_RowMajor, PackedB_ColumnMajor, PackedB_PackedB, etc.
// sg is empty for XMX8 and _SG16 for PVC
// element_type is char for i8, short for i16 and int for i32
@ -381,7 +411,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
&& stride == K && (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
&& (address_space == AS_GLOBAL || address_space == AS_LOCAL) \
) { \
OUT_STORE_VEC##M(u##contrib_type) OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(read, us)(ATTRIBUTE_##address_space u##contrib_type *); \
OUT_STORE_VEC##M(u##contrib_type) OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(read, us)(const ATTRIBUTE_##address_space u##contrib_type *); \
OUT_STORE_VEC##M(u##contrib_type) res = DEFINE_BLOCK_RW_NAME##M(read, us)((ATTRIBUTE_##address_space u##contrib_type *)mem); \
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
return; \
@ -599,19 +629,21 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
#define VEC_IND1(var, ind) var
// set block_opt to false to disable the block non-continuous optimization for a single built-in as a workaround
#define DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt, address_space) \
// num_stores - how many block 2d store operations are needed to store the whole Joint Matrix of this shape
#define DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt, address_space, num_stores) \
INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, WI_rows) (char *mem, __private char *src, long stride, int cacheOpt) { \
int sg_size = get_sub_group_size(); \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8 || M == 16) \
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \
) { \
IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K) \
IMPLEMENT_BLOCK2D_STORE##sg##_##num_stores(element_type, contrib_type, contrib_bitwidth, M, K) \
} \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL && stride == K \
&& (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
&& (address_space == AS_GLOBAL || address_space == AS_LOCAL) \
) { \
OUT_VEC##M(contrib_type) vec = *(__private OUT_VEC##M(contrib_type) *)src; \
void OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(write, us)(ATTRIBUTE_##address_space u##contrib_type *, OUT_STORE_VEC##M(u##contrib_type)); \
DEFINE_BLOCK_RW_NAME##M(write, us)((ATTRIBUTE_##address_space u##contrib_type *)mem, VEC_TO_VEC_STORE##M(u##contrib_type , vec)); \
return; \
} \
@ -642,15 +674,15 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
} \
}
#define DEFINE_STORE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GENERIC) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_LOCAL) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GLOBAL)
#define DEFINE_STORE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt, num_stores) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GENERIC, num_stores) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_LOCAL, num_stores) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GLOBAL, num_stores)
#define DEFINE_STORE(layout, sg, element_type, contrib_type, M, K, order, us, WI_rows, block_opt) \
DEFINE_STORE__(layout, sg, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type),\
M, K, SHAPE(layout, M, K, element_type, contrib_type), \
order, us, WI_rows, block_opt)
order, us, WI_rows, block_opt, GET_NUM_STORES(M))
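For example, assuming the 16x16 i32 accumulator instantiation that appears later in this diff, the new num_stores argument threads through as follows (a sketch of the expansion chain, not additional code):

// DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16, false)
//   -> DEFINE_STORE__(..., GET_NUM_STORES(16))                  // num_stores = 2
//     -> DEFINE_STORE_IMPL(..., AS_GLOBAL, 2)                   // plus AS_GENERIC and AS_LOCAL variants
//       -> IMPLEMENT_BLOCK2D_STORE_SG16_2(int, int, 32, 16, 16) // two 8x16 i32 2d block stores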
// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
@ -901,52 +933,54 @@ DEFINE_GET_COORD(Accumulator, , 32, 32, 8, 8, 1)
/* experimental large slice support: */
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) {
short16 a = *(short16 *)a_ptr;
int8 b = *(int8 *)b_ptr;
int16 raw_c = *(int16 *)raw_c_ptr;
short8 a0 = (short8)(a.s0, a.s1, a.s2, a.s3, a.s4, a.s5, a.s6, a.s7);
short8 a1 = (short8)(a.s8, a.s9, a.sa, a.sb, a.sc, a.sd, a.se, a.sf);
float16 c = *(float16 *)&raw_c;
float8 c0 = (float8)(c.s0, c.s1, c.s2, c.s3, c.s4, c.s5, c.s6, c.s7);
float8 c1 = (float8)(c.s8, c.s9, c.sa, c.sb, c.sc, c.sd, c.se, c.sf);
float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c0, a0, b);
float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c1, a1, b);
int8 res0 = *(int8 *)&fres0;
int8 res1 = *(int8 *)&fres1;
__private int16 *dst = (__private int16 *)result;
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7,
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7);
#define DEFINE_MAD_16x16x16_IMPL(a_type, b_type, a_suffix, b_suffix) \
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) { \
short16 a = *(short16 *)a_ptr; \
int8 b = *(int8 *)b_ptr; \
int16 raw_c = *(int16 *)raw_c_ptr; \
\
short8 a0 = (short8)(a.s0, a.s1, a.s2, a.s3, a.s4, a.s5, a.s6, a.s7); \
short8 a1 = (short8)(a.s8, a.s9, a.sa, a.sb, a.sc, a.sd, a.se, a.sf); \
\
float16 c = *(float16 *)&raw_c; \
\
float8 c0 = (float8)(c.s0, c.s1, c.s2, c.s3, c.s4, c.s5, c.s6, c.s7); \
float8 c1 = (float8)(c.s8, c.s9, c.sa, c.sb, c.sc, c.sd, c.se, c.sf); \
\
float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(c0, a0, b); \
float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(c1, a1, b); \
\
int8 res0 = *(int8 *)&fres0; \
int8 res1 = *(int8 *)&fres1; \
\
__private int16 *dst = (__private int16 *)result; \
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7, \
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7); \
}
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_fp16_fp16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) {
short16 a = *(short16 *)a_ptr;
int8 b = *(int8 *)b_ptr;
int16 raw_c = *(int16 *)raw_c_ptr;
short8 a0 = (short8)(a.s0, a.s1, a.s2, a.s3, a.s4, a.s5, a.s6, a.s7);
short8 a1 = (short8)(a.s8, a.s9, a.sa, a.sb, a.sc, a.sd, a.se, a.sf);
float16 c = *(float16 *)&raw_c;
float8 c0 = (float8)(c.s0, c.s1, c.s2, c.s3, c.s4, c.s5, c.s6, c.s7);
float8 c1 = (float8)(c.s8, c.s9, c.sa, c.sb, c.sc, c.sd, c.se, c.sf);
float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8(c0, a0, b);
float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8(c1, a1, b);
int8 res0 = *(int8 *)&fres0;
int8 res1 = *(int8 *)&fres1;
__private int16 *dst = (__private int16 *)result;
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7,
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7);
}
DEFINE_MAD_16x16x16_IMPL(bf16, bf16, bf, bf)
DEFINE_MAD_16x16x16_IMPL(fp16, fp16, hf, hf)
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
short a = *(short *) a_ptr;
int8 b0 = *(int8 *) b_ptr;
int8 b1 = *(int8 *)(b_ptr + 1 * 16 * (sizeof (short)));
int8 b2 = *(int8 *)(b_ptr + 2 * 16 * (sizeof (short)));
int8 b3 = *(int8 *)(b_ptr + 3 * 16 * (sizeof (short)));
float c0 = *(float *) c_ptr;
float c1 = *(float *) (c_ptr + 1 * (sizeof (int)));
float c2 = *(float *) (c_ptr + 2 * (sizeof (int)));
float c3 = *(float *) (c_ptr + 3 * (sizeof (int)));
float d0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c0, a, b0);
float d1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c1, a, b1);
float d2 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c2, a, b2);
float d3 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c3, a, b3);
__private int4 *dst = (__private int4 *)d_ptr;
*dst = (int4)(as_int(d0), as_int(d1), as_int(d2), as_int(d3));
}
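To make the 1x64x16 shape concrete (sub-group size 16; the per-work-item sizes below are read off the code above, not extra API guarantees):

// A (1x16 bf16):   16 elements / 16 work-items  -> 1 short per work-item           (short a)
// B (16x64 bf16):  packed as four 16x16 blocks  -> 8 ints per work-item per block  (int8 b0..b3)
// C, D (1x64 f32): 64 elements / 16 work-items  -> 4 floats per work-item          (c0..c3, d0..d3)
// each ..._fdpas_f_f_bf_bf_8_1 call yields one float per work-item, i.e. one 1x16 tile of D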
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
@ -987,101 +1021,77 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__priv
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b3, c7, d7);
}
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 16, 16, ROW_MAJOR, , 16)
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 32, 16, ROW_MAJOR, , 32)
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16)
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, int, 32, 16, ROW_MAJOR, , 32)
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 16, 16, ROW_MAJOR, , 16)
#define DEFINE_ACC_ROW_MAJOR_32x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
__private char *c0 = dst + 0 * 32 * (sizeof (int)); \
__private char *c1 = dst + 1 * 32 * (sizeof (int)); \
__private char *c2 = dst + 2 * 32 * (sizeof (int)); \
__private char *c3 = dst + 3 * 32 * (sizeof (int)); \
\
char *mem0 = mem + 0 * 16 * (sizeof (int)); \
char *mem1 = mem + 1 * 16 * (sizeof (int)); \
char *mem2 = mem + 2 * 16 * (sizeof (int)); \
char *mem3 = mem + 3 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c0, mem0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c1, mem1, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c2, mem2, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c3, mem3, stride, cacheOpt); \
}
DEFINE_ACC_ROW_MAJOR_32x64(generic)
DEFINE_ACC_ROW_MAJOR_32x64(global)
DEFINE_ACC_ROW_MAJOR_32x64(local)
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 32, 16, ROW_MAJOR, , 32)
#define DEFINE_B_B_16x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
__private char *b0 = dst; \
__private char *b1 = dst + 1 * 16 * (sizeof (short)); \
__private char *b2 = dst + 2 * 16 * (sizeof (short)); \
__private char *b3 = dst + 3 * 16 * (sizeof (short)); \
\
char *mem0 = mem + 0 * 16 * (sizeof (int)); \
char *mem1 = mem + 1 * 16 * (sizeof (int)); \
char *mem2 = mem + 2 * 16 * (sizeof (int)); \
char *mem3 = mem + 3 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b0, mem0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b1, mem1, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b2, mem2, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b3, mem3, stride, cacheOpt); \
}
DEFINE_B_B_16x64(generic)
DEFINE_B_B_16x64(global)
DEFINE_B_B_16x64(local)
#define DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, WI_rows) (char *mem, __private char *src, long stride, int cacheOpt) { \
int sg_size = get_sub_group_size(); \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && M == 16 \
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8) { \
__private char *c0 = src + 0 * 8 * (sizeof (int)); \
__private char *c1 = src + 1 * 8 * (sizeof (int)); \
\
char *mem0 = mem; \
char *mem1 = mem + 8 * (sizeof (int)) * stride; \
\
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem0, c0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem1, c1, stride, cacheOpt); \
return; \
} \
contrib_type *ptr = (contrib_type *)mem; \
int slid = get_sub_group_local_id(); \
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
stride = stride / pack_factor; \
int sg_cols = K / pack_factor; \
int skip_factor = sg_size / sg_cols; \
__private contrib_type *slice = (__private contrib_type *)src; \
for (int i = 0; i < WI_rows; i++) { \
if ( (i*skip_factor + slid/sg_cols) < M ) \
ptr[IND##order(slid, stride, skip_factor, i, sg_cols)] = slice[i]; \
else \
continue; /*last even row for matrix with odd number of rows doesn't exist*/ \
} \
}
#define DEFINE_STORE_LARGE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_GENERIC) \
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_LOCAL) \
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_GLOBAL)
#define DEFINE_STORE_LARGE(layout, sg, element_type, contrib_type, M, K, order, us, WI_rows) \
DEFINE_STORE_LARGE__(layout, sg, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
M, K, SHAPE(layout, M, K, element_type, contrib_type), \
order, us, WI_rows)
// sub group size 16
DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16)
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16, false)
// sub group size 32
DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 8)
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 8, false)
// special case for 1x64 C load: Joint Matrices are expected to be contiguous in memory, without padding at the end of a row,
// hence we can load the 1x64 shape using a single 2d block load of shape 4x16 instead of four 1x16 loads
#define DEFINE_LOAD_LARGE_IMPL_(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_1x64_i32_4_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = sizeof(int) * 16 - 1; /* load 1x64 as 4x16, hence width is 16 ints, in bytes */ \
int height = 4 - 1; /* row count */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
long x = (offset - baseoffset) / sizeof(int); /* in elements */ \
int2 coords = (int2)(x, 0); \
uint4 __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(long, int, int, int, int2, int); \
uint4 res = __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(baseoffset, width, height, pitch, coords, cacheOpt); \
*(__private uint4 *)dst = res; \
}
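A quick check of the arithmetic behind the 4x16 trick used above (illustrative only; the block-2d descriptor fields store width/height/pitch minus one, as in the rest of the file):

// 1x64 row of i32:  64 * sizeof(int) = 256 contiguous bytes
// 4x16 block view:  width = 16 * sizeof(int) = 64 bytes, pitch = width, height = 4 rows
//                   4 rows * 64 bytes = 256 bytes -> same memory, one block read, 4 ints per work-item (uint4)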
// _4 in the name indicates that the function uses four 2d block loads
#define DEFINE_LOAD_LARGE_IMPL_4(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x##C##_i##elem_bitwidth##_##WI_rows##_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
__private char *dst0 = dst; \
__private char *dst1 = dst + 1 * R * (sizeof (elem_type)); \
__private char *dst2 = dst + 2 * R * (sizeof (elem_type)); \
__private char *dst3 = dst + 3 * R * (sizeof (elem_type)); \
\
char *mem0 = mem + 0 * 16 * (sizeof (int)); \
char *mem1 = mem + 1 * 16 * (sizeof (int)); \
char *mem2 = mem + 2 * 16 * (sizeof (int)); \
char *mem3 = mem + 3 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst0, mem0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst1, mem1, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst2, mem2, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst3, mem3, stride, cacheOpt); \
}
#define DEFINE_LOAD_LARGE__(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, num_loads) \
DEFINE_LOAD_LARGE_IMPL_##num_loads(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, generic) \
DEFINE_LOAD_LARGE_IMPL_##num_loads(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, global ) \
DEFINE_LOAD_LARGE_IMPL_##num_loads(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, local )
#define DEFINE_LOAD_LARGE(layout, elem_type, R, C, WI_rows, num_loads) \
DEFINE_LOAD_LARGE__(layout, elem_type, BITWIDTH(elem_type), R, C, WI_rows, MATH_DIV(WI_rows, num_loads), num_loads)
DEFINE_LOAD_LARGE(PackedB_PackedB, short, 16, 64, 32, 4)
DEFINE_LOAD_LARGE(Accumulator_RowMajor, , 1, 64, , )
DEFINE_LOAD_LARGE(Accumulator_RowMajor, int, 32, 64, 128, 4)
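For example, the 32x64 i32 instantiation above resolves its per-load row count through the division macros defined near the top of the file (MATH_128_DIV_4 is 32); a sketch of the expansion:

// DEFINE_LOAD_LARGE(Accumulator_RowMajor, int, 32, 64, 128, 4)
//   -> WI_rows_per_load = MATH_DIV(128, 4) = 32
//   -> DEFINE_LOAD_LARGE_IMPL_4 for generic/global/local, each calling the
//      Accumulator_RowMajor_SG16_32x16_i32_32 load built-in four times with a 16 * sizeof(int) column offset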
#define DEFINE_STORE_ACC_ROW_MAJOR_1x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_1x64_i32_4_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride, int cacheOpt) { \
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = sizeof(int) * 16 - 1; /* in bytes, load 1x64 as 4x16 to use one load instead of 4 */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = 4 - 1; /* row count */ \
long x = (offset - baseoffset) / sizeof(int); /* in elements */ \
int2 coords = (int2)(x, 0); \
uint4 val = *(uint4 *)src; \
void __builtin_IB_subgroup_block_write_flat_u32_wi4_m4k16v1(long, int, int, int, int2, uint4, int); \
__builtin_IB_subgroup_block_write_flat_u32_wi4_m4k16v1(baseoffset, width, height, pitch, coords, val, cacheOpt); \
}
#define DEFINE_STORE_ACC_ROW_MAJOR_32x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride, int cacheOpt) { \
@ -1113,6 +1123,10 @@ DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 8
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem7, c7, stride, cacheOpt); \
}
DEFINE_STORE_ACC_ROW_MAJOR_32x64(generic)
DEFINE_STORE_ACC_ROW_MAJOR_32x64(global)
DEFINE_STORE_ACC_ROW_MAJOR_32x64(local)
#define DEFINE_STORE_ACC_ROW_MAJOR_LARGE(R, C) \
DEFINE_STORE_ACC_ROW_MAJOR_##R##x##C(generic) \
DEFINE_STORE_ACC_ROW_MAJOR_##R##x##C(global) \
DEFINE_STORE_ACC_ROW_MAJOR_##R##x##C(local)
DEFINE_STORE_ACC_ROW_MAJOR_LARGE( 1, 64)
DEFINE_STORE_ACC_ROW_MAJOR_LARGE(32, 64)

View File

@ -444,6 +444,8 @@ static bool isSupprtedLargeSlice(const JointMatrixTypeDescription *desc, bool us
if (desc->layout == LayoutRowMajor) {
if (desc->rows == 16 && desc->columns == 16 && desc->bitWidth == 32)
return true;
if (desc->rows == 1 && desc->columns == 64 && desc->bitWidth == 32)
return true;
if (desc->rows == 32 && desc->columns == 64 && desc->bitWidth == 32)
return true;
}
@ -1089,6 +1091,8 @@ static bool isMADSupportedAsBuiltin(unsigned M, unsigned N, unsigned K) {
return true;
if (M == 32 && N == 64 && K == 16)
return true;
if (M == 1 && N == 64 && K == 16)
return true;
return false;
}
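Read together with the first file, the effect of the new check is as follows (values follow directly from the code above; the lowering target is presumed to be the 1x64x16 built-in defined earlier in this commit):

// isMADSupportedAsBuiltin(1, 64, 16)  -> true   (new in this commit)
// isMADSupportedAsBuiltin(32, 64, 16) -> true   (pre-existing)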