mirror of https://github.com/intel/intel-graphics-compiler.git, synced 2025-11-04 08:21:06 +08:00
Add 1x64x32 Joint Matrix support, refactoring
Add 1x64x32 Joint Matrix support. Refactor load/store/mad built-ins for Joint Matrix.
@@ -57,6 +57,8 @@ SPDX-License-Identifier: MIT
#define IND_VNNI_TX(slid, stride, skip_factor, i, sg_cols) (i + (slid * stride))

// no int7, int6, int5 types
#define VEC_TO_VEC16(type, vec) \
(type##16)(vec.s0, vec.s1, vec.s2, vec.s3, vec.s4, vec.s5, vec.s6, vec.s7, vec.s8, vec.s9, vec.sA, vec.sB, vec.sC, vec.sD, vec.sE, vec.sF)
#define VEC_TO_VEC8(type, vec) \
(type##8)(vec.s0, vec.s1, vec.s2, vec.s3, vec.s4, vec.s5, vec.s6, vec.s7)
#define VEC_TO_VEC7(type, vec) \
@@ -71,6 +73,7 @@ SPDX-License-Identifier: MIT
#define VEC_TO_VEC1(type, vec) (type)(vec)

// in case of store, we can not use uint3 with intel_sub_group_block_write4
#define VEC_TO_VEC_STORE16(type, vec) VEC_TO_VEC16(type, vec)
#define VEC_TO_VEC_STORE8(type, vec) VEC_TO_VEC8(type, vec)
#define VEC_TO_VEC_STORE7(type, vec) VEC_TO_VEC7(type, vec)
#define VEC_TO_VEC_STORE6(type, vec) VEC_TO_VEC6(type, vec)
@@ -134,6 +137,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
#define OUT_STORE_VEC1(type) type

// Math division macros
#define MATH_128_DIV_4 32
#define MATH_64_DIV_64 1
#define MATH_64_DIV_32 2
#define MATH_64_DIV_16 4
@@ -209,6 +213,19 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
#define SHAPE_Accumulator_ColumnMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
#define SHAPE(layout, M, K, element_type, contrib_type) SHAPE_##layout(M, K, BITWIDTH(element_type), BITWIDTH(contrib_type))

// Get number of 2d block stores needed for a given number of rows.
// R parameter is number of rows.
#define GET_NUM_STORES_1 1
#define GET_NUM_STORES_2 1
#define GET_NUM_STORES_3 1
#define GET_NUM_STORES_4 1
#define GET_NUM_STORES_5 1
#define GET_NUM_STORES_6 1
#define GET_NUM_STORES_7 1
#define GET_NUM_STORES_8 1
#define GET_NUM_STORES_16 2
#define GET_NUM_STORES(R) GET_NUM_STORES_##R
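
/* Illustration (not part of the commit): GET_NUM_STORES dispatches purely by token
   pasting, so GET_NUM_STORES(16) resolves to GET_NUM_STORES_16 == 2, i.e. a 16-row
   matrix is written with two 2d block stores, while 1..8 rows need a single store.
   A minimal standalone C sketch of the same pattern, with made-up names: */
#include <stdio.h>
#define NUM_STORES_8  1
#define NUM_STORES_16 2
#define NUM_STORES(R) NUM_STORES_##R
int main(void) {
    printf("%d %d\n", NUM_STORES(8), NUM_STORES(16)); /* prints "1 2" */
    return 0;
}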

// layout can be PackedA_RowMajor, PackedB_ColumnMajor, PackedB_PackedB, etc.
// sg is empty for XMX8 and _SG16 for PVC
// elem_bitwidth is 8, 16 or 32
@@ -274,7 +291,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
/* not supported, fallthrough */
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
/* not supported, fallthrough */
#define IMPLEMENT_BLOCK2D_STORE(element_type, contrib_type, contrib_bitwidth, M, K) \
#define IMPLEMENT_BLOCK2D_STORE_1(element_type, contrib_type, contrib_bitwidth, M, K) \
/* not supported, fallthrough */

// contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
@@ -288,7 +305,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = M - 1; /* row count */ \
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@@ -303,7 +320,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = contrib_K - 1; /* column count */ \
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@@ -319,7 +336,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = K - 1; /* row count */ \
long x = (offset - baseoffset) / (sizeof (element_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
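
/* Worked example of the block descriptor math above (illustrative values, taking an
   i32 accumulator where element_type == contrib_type == int and stride == 16 columns):
     offset     = 0x1048                            (address of the matrix start)
     baseoffset = 0x1048 & ~0x3f = 0x1040           (64-byte aligned base of the 2d block)
     width      = sizeof(int) * 16 - 1 = 63         (row width in bytes, inclusive)
     pitch      = width = 63                        (rows are contiguous, no padding)
     x          = (0x1048 - 0x1040) / sizeof(int) = 2   (x coordinate, in elements)
     coords     = (int2)(2, 0) */
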
@@ -337,11 +354,12 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
IMPLEMENT_BLOCK2D_LOAD__(sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
M, K, WI_rows)

#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, contrib_type, contrib_bitwidth, M, K) \
// _1 suffix in the name indicates that the function is using 1 2d block store
#define IMPLEMENT_BLOCK2D_STORE_SG16_1(element_type, contrib_type, contrib_bitwidth, M, K) \
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = M - 1; /* row count */ \
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
int2 coords = (int2)(x, 0); \
@@ -350,6 +368,18 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, M, M, K)(baseoffset, width, height, pitch, coords, val, cacheOpt); \
return;

// _2 suffix in the name indicates that the function is using 2 2d block stores
#define IMPLEMENT_BLOCK2D_STORE_SG16_2(element_type, contrib_type, contrib_bitwidth, M, K) \
__private char *c0 = src + 0 * 8 * (sizeof (int)); \
__private char *c1 = src + 1 * 8 * (sizeof (int)); \
\
char *mem0 = mem; \
char *mem1 = mem + 8 * (sizeof (int)) * stride; \
\
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem0, c0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem1, c1, stride, cacheOpt); \
return;

// layout can be PackedA_RowMajor, PackedB_ColumnMajor, PackedB_PackedB, etc.
// sg is empty for XMX8 and _SG16 for PVC
// element_type is char for i8, short for i16 and int for i32
@@ -381,7 +411,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
&& stride == K && (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
&& (address_space == AS_GLOBAL || address_space == AS_LOCAL) \
) { \
OUT_STORE_VEC##M(u##contrib_type) OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(read, us)(ATTRIBUTE_##address_space u##contrib_type *); \
OUT_STORE_VEC##M(u##contrib_type) OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(read, us)(const ATTRIBUTE_##address_space u##contrib_type *); \
OUT_STORE_VEC##M(u##contrib_type) res = DEFINE_BLOCK_RW_NAME##M(read, us)((ATTRIBUTE_##address_space u##contrib_type *)mem); \
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
return; \
@@ -599,19 +629,21 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
#define VEC_IND1(var, ind) var

// set block_opt to false to disable block non-continous optimization per one built-in as a workaround
#define DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt, address_space) \
// num_stores - how many block 2d store operations are needed to store the whole Joint Matrix of this shape
#define DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt, address_space, num_stores) \
INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, WI_rows) (char *mem, __private char *src, long stride, int cacheOpt) { \
int sg_size = get_sub_group_size(); \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8 || M == 16) \
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \
) { \
IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K) \
IMPLEMENT_BLOCK2D_STORE##sg##_##num_stores(element_type, contrib_type, contrib_bitwidth, M, K) \
} \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL && stride == K \
&& (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
&& (address_space == AS_GLOBAL || address_space == AS_LOCAL) \
) { \
OUT_VEC##M(contrib_type) vec = *(__private OUT_VEC##M(contrib_type) *)src; \
void OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(write, us)(ATTRIBUTE_##address_space u##contrib_type *, OUT_STORE_VEC##M(u##contrib_type)); \
DEFINE_BLOCK_RW_NAME##M(write, us)((ATTRIBUTE_##address_space u##contrib_type *)mem, VEC_TO_VEC_STORE##M(u##contrib_type , vec)); \
return; \
} \
@@ -642,15 +674,15 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
} \
}

#define DEFINE_STORE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GENERIC) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_LOCAL) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GLOBAL)
#define DEFINE_STORE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt, num_stores) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GENERIC, num_stores) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_LOCAL, num_stores) \
DEFINE_STORE_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, block_opt, AS_GLOBAL, num_stores)

#define DEFINE_STORE(layout, sg, element_type, contrib_type, M, K, order, us, WI_rows, block_opt) \
DEFINE_STORE__(layout, sg, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type),\
M, K, SHAPE(layout, M, K, element_type, contrib_type), \
order, us, WI_rows, block_opt)
order, us, WI_rows, block_opt, GET_NUM_STORES(M))
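
/* Hand expansion (a sketch, not generated output) showing how num_stores is threaded
   through the refactored macros:
     DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16, false)
       -> DEFINE_STORE__(..., GET_NUM_STORES(16)) with GET_NUM_STORES(16) == 2
       -> DEFINE_STORE_IMPL(..., num_stores = 2)
       -> the BLOCK2D path expands IMPLEMENT_BLOCK2D_STORE_SG16_2, which issues two
          8x16 i32 block stores (the second half at mem + 8 * sizeof(int) * stride),
   while M == 2/4/8 resolves GET_NUM_STORES(M) == 1 and keeps the single-store path
   IMPLEMENT_BLOCK2D_STORE_SG16_1. */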

// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization

@@ -901,52 +933,54 @@ DEFINE_GET_COORD(Accumulator, , 32, 32, 8, 8, 1)

/* experimental large slice support: */

INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) {
short16 a = *(short16 *)a_ptr;
int8 b = *(int8 *)b_ptr;
int16 raw_c = *(int16 *)raw_c_ptr;

short8 a0 = (short8)(a.s0, a.s1, a.s2, a.s3, a.s4, a.s5, a.s6, a.s7);
short8 a1 = (short8)(a.s8, a.s9, a.sa, a.sb, a.sc, a.sd, a.se, a.sf);

float16 c = *(float16 *)&raw_c;

float8 c0 = (float8)(c.s0, c.s1, c.s2, c.s3, c.s4, c.s5, c.s6, c.s7);
float8 c1 = (float8)(c.s8, c.s9, c.sa, c.sb, c.sc, c.sd, c.se, c.sf);

float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c0, a0, b);
float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c1, a1, b);

int8 res0 = *(int8 *)&fres0;
int8 res1 = *(int8 *)&fres1;

__private int16 *dst = (__private int16 *)result;
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7,
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7);
#define DEFINE_MAD_16x16x16_IMPL(a_type, b_type, a_suffix, b_suffix) \
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) { \
short16 a = *(short16 *)a_ptr; \
int8 b = *(int8 *)b_ptr; \
int16 raw_c = *(int16 *)raw_c_ptr; \
\
short8 a0 = (short8)(a.s0, a.s1, a.s2, a.s3, a.s4, a.s5, a.s6, a.s7); \
short8 a1 = (short8)(a.s8, a.s9, a.sa, a.sb, a.sc, a.sd, a.se, a.sf); \
\
float16 c = *(float16 *)&raw_c; \
\
float8 c0 = (float8)(c.s0, c.s1, c.s2, c.s3, c.s4, c.s5, c.s6, c.s7); \
float8 c1 = (float8)(c.s8, c.s9, c.sa, c.sb, c.sc, c.sd, c.se, c.sf); \
\
float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(c0, a0, b); \
float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(c1, a1, b); \
\
int8 res0 = *(int8 *)&fres0; \
int8 res1 = *(int8 *)&fres1; \
\
__private int16 *dst = (__private int16 *)result; \
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7, \
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7); \
}

INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_fp16_fp16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) {
short16 a = *(short16 *)a_ptr;
int8 b = *(int8 *)b_ptr;
int16 raw_c = *(int16 *)raw_c_ptr;
DEFINE_MAD_16x16x16_IMPL(bf16, bf16, bf, bf)
DEFINE_MAD_16x16x16_IMPL(fp16, fp16, hf, hf)

short8 a0 = (short8)(a.s0, a.s1, a.s2, a.s3, a.s4, a.s5, a.s6, a.s7);
short8 a1 = (short8)(a.s8, a.s9, a.sa, a.sb, a.sc, a.sd, a.se, a.sf);
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
short a = *(short *) a_ptr;

float16 c = *(float16 *)&raw_c;
int8 b0 = *(int8 *) b_ptr;
int8 b1 = *(int8 *)(b_ptr + 1 * 16 * (sizeof (short)));
int8 b2 = *(int8 *)(b_ptr + 2 * 16 * (sizeof (short)));
int8 b3 = *(int8 *)(b_ptr + 3 * 16 * (sizeof (short)));

float8 c0 = (float8)(c.s0, c.s1, c.s2, c.s3, c.s4, c.s5, c.s6, c.s7);
float8 c1 = (float8)(c.s8, c.s9, c.sa, c.sb, c.sc, c.sd, c.se, c.sf);
float c0 = *(float *) c_ptr;
float c1 = *(float *) (c_ptr + 1 * (sizeof (int)));
float c2 = *(float *) (c_ptr + 2 * (sizeof (int)));
float c3 = *(float *) (c_ptr + 3 * (sizeof (int)));

float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8(c0, a0, b);
float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8(c1, a1, b);
float d0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c0, a, b0);
float d1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c1, a, b1);
float d2 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c2, a, b2);
float d3 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c3, a, b3);

int8 res0 = *(int8 *)&fres0;
int8 res1 = *(int8 *)&fres1;

__private int16 *dst = (__private int16 *)result;
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7,
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7);
__private int4 *dst = (__private int4 *)d_ptr;
*dst = (int4)(as_int(d0), as_int(d1), as_int(d2), as_int(d3));
}
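
/* Side note (not from the commit): the per-work-item slice sizes behind the new
   mangled names follow directly from distributing the shape over a SIMD16 sub-group.
   A hypothetical C helper (illustrative only) makes the arithmetic explicit: */
static inline int acc_wi_elems(int rows, int cols, int sg_size) {
    return rows * cols / sg_size;   /* i32 elements owned by each work-item */
}
/* acc_wi_elems(1, 64, 16)  == 4   -> the 1x64 accumulator is an int4 per WI ("_i32_4_")
   acc_wi_elems(32, 64, 16) == 128 -> the 32x64 accumulator is 128 ints per WI ("_i32_128_") */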

INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
@@ -987,101 +1021,77 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__priv
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b3, c7, d7);
}

DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 16, 16, ROW_MAJOR, , 16)
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 32, 16, ROW_MAJOR, , 32)

DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16)
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, int, 32, 16, ROW_MAJOR, , 32)

DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 16, 16, ROW_MAJOR, , 16)

#define DEFINE_ACC_ROW_MAJOR_32x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
__private char *c0 = dst + 0 * 32 * (sizeof (int)); \
__private char *c1 = dst + 1 * 32 * (sizeof (int)); \
__private char *c2 = dst + 2 * 32 * (sizeof (int)); \
__private char *c3 = dst + 3 * 32 * (sizeof (int)); \
\
char *mem0 = mem + 0 * 16 * (sizeof (int)); \
char *mem1 = mem + 1 * 16 * (sizeof (int)); \
char *mem2 = mem + 2 * 16 * (sizeof (int)); \
char *mem3 = mem + 3 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c0, mem0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c1, mem1, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c2, mem2, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x16_i32_32_##address_space##_v8i8_pi32_i32(c3, mem3, stride, cacheOpt); \
}

DEFINE_ACC_ROW_MAJOR_32x64(generic)
DEFINE_ACC_ROW_MAJOR_32x64(global)
DEFINE_ACC_ROW_MAJOR_32x64(local)

DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 32, 16, ROW_MAJOR, , 32)

#define DEFINE_B_B_16x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
__private char *b0 = dst; \
__private char *b1 = dst + 1 * 16 * (sizeof (short)); \
__private char *b2 = dst + 2 * 16 * (sizeof (short)); \
__private char *b3 = dst + 3 * 16 * (sizeof (short)); \
\
char *mem0 = mem + 0 * 16 * (sizeof (int)); \
char *mem1 = mem + 1 * 16 * (sizeof (int)); \
char *mem2 = mem + 2 * 16 * (sizeof (int)); \
char *mem3 = mem + 3 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b0, mem0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b1, mem1, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b2, mem2, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b3, mem3, stride, cacheOpt); \
}

DEFINE_B_B_16x64(generic)
DEFINE_B_B_16x64(global)
DEFINE_B_B_16x64(local)

#define DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, WI_rows) (char *mem, __private char *src, long stride, int cacheOpt) { \
int sg_size = get_sub_group_size(); \
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && M == 16 \
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8) { \
__private char *c0 = src + 0 * 8 * (sizeof (int)); \
__private char *c1 = src + 1 * 8 * (sizeof (int)); \
\
char *mem0 = mem; \
char *mem1 = mem + 8 * (sizeof (int)) * stride; \
\
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem0, c0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem1, c1, stride, cacheOpt); \
return; \
} \
contrib_type *ptr = (contrib_type *)mem; \
int slid = get_sub_group_local_id(); \
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
stride = stride / pack_factor; \
int sg_cols = K / pack_factor; \
int skip_factor = sg_size / sg_cols; \
__private contrib_type *slice = (__private contrib_type *)src; \
for (int i = 0; i < WI_rows; i++) { \
if ( (i*skip_factor + slid/sg_cols) < M ) \
ptr[IND##order(slid, stride, skip_factor, i, sg_cols)] = slice[i]; \
else \
continue; /*last even row for matrix with odd number of rows doesn't exist*/ \
} \
}

#define DEFINE_STORE_LARGE__(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_GENERIC) \
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_LOCAL) \
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_GLOBAL)

#define DEFINE_STORE_LARGE(layout, sg, element_type, contrib_type, M, K, order, us, WI_rows) \
DEFINE_STORE_LARGE__(layout, sg, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
M, K, SHAPE(layout, M, K, element_type, contrib_type), \
order, us, WI_rows)

// sub group size 16
DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16)
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 16, false)
// sub group size 32
DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 8)
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 8, false)

// special case for 1x64 C load: Joint Matrices are expected to be contiguous in memory, without padding at the end of a row
// hence, we can load 1x64 shape using single 2d block load of shape 4x16 instead of 4 1x16 loads
#define DEFINE_LOAD_LARGE_IMPL_(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_1x64_i32_4_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = sizeof(int) * 16 - 1; /* load 1x64 as 4x16, hence, width is 16 int in bytes */ \
int height = 4 - 1; /* row count */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
long x = (offset - baseoffset) / sizeof(int); /* in elements */ \
int2 coords = (int2)(x, 0); \
uint4 __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(long, int, int, int, int2, int); \
uint4 res = __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(baseoffset, width, height, pitch, coords, cacheOpt); \
*(__private uint4 *)dst = res; \
}
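
/* Why the 4x16 view used above is legal (illustrative arithmetic, not from the commit):
   a 1x64 i32 row is 64 * 4 = 256 contiguous bytes, which is exactly a 4x16 i32 block
   with pitch = 16 * sizeof(int) = 64 bytes; hence width == pitch == 63 (inclusive),
   height == 4 - 1 == 3, and the m4k16 read hands each of the 16 work-items a uint4
   (16 WIs * 4 ints == 64 elements). */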

// _4 in the name is for 4 2d block loads
#define DEFINE_LOAD_LARGE_IMPL_4(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, address_space) \
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x##C##_i##elem_bitwidth##_##WI_rows##_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride, int cacheOpt) { \
__private char *dst0 = dst; \
__private char *dst1 = dst + 1 * R * (sizeof (elem_type)); \
__private char *dst2 = dst + 2 * R * (sizeof (elem_type)); \
__private char *dst3 = dst + 3 * R * (sizeof (elem_type)); \
\
char *mem0 = mem + 0 * 16 * (sizeof (int)); \
char *mem1 = mem + 1 * 16 * (sizeof (int)); \
char *mem2 = mem + 2 * 16 * (sizeof (int)); \
char *mem3 = mem + 3 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst0, mem0, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst1, mem1, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst2, mem2, stride, cacheOpt); \
__builtin_spriv_OpJointMatrixLoadINTEL_##layout##_SG16_##R##x16_i##elem_bitwidth##_##WI_rows_per_load##_##address_space##_v8i8_pi32_i32(dst3, mem3, stride, cacheOpt); \
}

#define DEFINE_LOAD_LARGE__(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, num_loads) \
DEFINE_LOAD_LARGE_IMPL_##num_loads(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, generic) \
DEFINE_LOAD_LARGE_IMPL_##num_loads(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, global ) \
DEFINE_LOAD_LARGE_IMPL_##num_loads(layout, elem_type, elem_bitwidth, R, C, WI_rows, WI_rows_per_load, local )

#define DEFINE_LOAD_LARGE(layout, elem_type, R, C, WI_rows, num_loads) \
DEFINE_LOAD_LARGE__(layout, elem_type, BITWIDTH(elem_type), R, C, WI_rows, MATH_DIV(WI_rows, num_loads), num_loads)

DEFINE_LOAD_LARGE(PackedB_PackedB, short, 16, 64, 32, 4)
DEFINE_LOAD_LARGE(Accumulator_RowMajor, , 1, 64, , )
DEFINE_LOAD_LARGE(Accumulator_RowMajor, int, 32, 64, 128, 4)
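
/* Hand expansion of the generic path (a sketch; assumes MATH_DIV(32, 4) resolves to 8
   via the MATH_*_DIV_* table, which is not shown in this hunk):
     DEFINE_LOAD_LARGE(PackedB_PackedB, short, 16, 64, 32, 4)
       -> WI_rows_per_load = MATH_DIV(32, 4) = 8
       -> DEFINE_LOAD_LARGE_IMPL_4 emits four calls to
          __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_*_v8i8_pi32_i32,
   i.e. the same four-tile decomposition as the hand-written DEFINE_B_B_16x64 above. */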

#define DEFINE_STORE_ACC_ROW_MAJOR_1x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_1x64_i32_4_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride, int cacheOpt) { \
long offset = as_long(mem); \
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
int width = sizeof(int) * 16 - 1; /* in bytes, load 1x64 as 4x16 to use one load instead of 4 */ \
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
int height = 4 - 1; /* row count */ \
long x = (offset - baseoffset) / sizeof(int); /* in elements */ \
int2 coords = (int2)(x, 0); \
uint4 val = *(uint4 *)src; \
void __builtin_IB_subgroup_block_write_flat_u32_wi4_m4k16v1(long, int, int, int, int2, uint4, int); \
__builtin_IB_subgroup_block_write_flat_u32_wi4_m4k16v1(baseoffset, width, height, pitch, coords, val, cacheOpt); \
}

#define DEFINE_STORE_ACC_ROW_MAJOR_32x64(address_space) \
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride, int cacheOpt) { \
@@ -1113,6 +1123,10 @@ DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, int, 16, 16, ROW_MAJOR, , 8
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem7, c7, stride, cacheOpt); \
}

DEFINE_STORE_ACC_ROW_MAJOR_32x64(generic)
DEFINE_STORE_ACC_ROW_MAJOR_32x64(global)
DEFINE_STORE_ACC_ROW_MAJOR_32x64(local)
#define DEFINE_STORE_ACC_ROW_MAJOR_LARGE(R, C) \
DEFINE_STORE_ACC_ROW_MAJOR_##R##x##C(generic) \
DEFINE_STORE_ACC_ROW_MAJOR_##R##x##C(global) \
DEFINE_STORE_ACC_ROW_MAJOR_##R##x##C(local)

DEFINE_STORE_ACC_ROW_MAJOR_LARGE( 1, 64)
DEFINE_STORE_ACC_ROW_MAJOR_LARGE(32, 64)

@@ -444,6 +444,8 @@ static bool isSupprtedLargeSlice(const JointMatrixTypeDescription *desc, bool us
if (desc->layout == LayoutRowMajor) {
if (desc->rows == 16 && desc->columns == 16 && desc->bitWidth == 32)
return true;
if (desc->rows == 1 && desc->columns == 64 && desc->bitWidth == 32)
return true;
if (desc->rows == 32 && desc->columns == 64 && desc->bitWidth == 32)
return true;
}
@@ -1089,6 +1091,8 @@ static bool isMADSupportedAsBuiltin(unsigned M, unsigned N, unsigned K) {
return true;
if (M == 32 && N == 64 && K == 16)
return true;
if (M == 1 && N == 64 && K == 16)
return true;
return false;
}