From df14096e00b80ea44d0838b3abcfdd692489acda Mon Sep 17 00:00:00 2001 From: Shilei Tian Date: Mon, 15 Dec 2025 09:16:26 -0500 Subject: [PATCH] [NFC][AMDGPU] Refactor the multiclass for WMMA_F8F6F4 instructions (#172245) --- llvm/lib/Target/AMDGPU/VOP3PInstructions.td | 47 +++++++++++++++------ 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 2dfa905848a3..410e56d83331 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -1814,21 +1814,42 @@ def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 0, 0, 0, 1>; def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 0, 0, 0, 1>; -multiclass WMMA_F8F6F4_Profiles { - def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; - def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; +// Helper class to compute the destination vector type of WMMA_F8F6F4 instructions based on element type and dimensions. +class getWMMAF8F6F4DstVTy { + // Size in bits = (M * N / 32) * element_size_in_bits + defvar Size = !mul(!div(!mul(M, N), 32), DstEltTy.Size); + ValueType ret = !cond(!eq(Size, 256) : v8f32, + !eq(Size, 1024) : v64f16); } -defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles<0, 0, 0>; -defm F32_16X16X128_F8F6F4_SCALE : WMMA_F8F6F4_Profiles<1, 0, 1>; -defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<1, 1, 1>; +// Helper class to compute the type of matrix A and B of WMMA_F8F6F4 instructions based on format and dimensions. +class getWMMAF8F6F4ABVTy { + defvar FmtBits = !cond(!eq(Fmt, "f8") : 8, + !eq(Fmt, "f6") : 6, + !eq(Fmt, "f4") : 4); + // TypeSize in bits = (D1 * D2 / 32) * format_bits + defvar TypeSize = !mul(!div(!mul(D1, D2), 32), FmtBits); + ValueType ret = !cond(!eq(TypeSize, 256) : v8i32, + !eq(TypeSize, 384) : v12i32, + !eq(TypeSize, 512) : v16i32, + !eq(TypeSize, 1024) : v32i32); +} + +multiclass WMMA_F8F6F4_Profiles { + defvar DstTy = getWMMAF8F6F4DstVTy.ret; + foreach ATy = ["f8", "f6", "f4"] in { + foreach BTy = ["f8", "f6", "f4"] in { + def _#ATy#_#BTy#_w32 : VOP3PWMMA_Profile< + [DstTy, getWMMAF8F6F4ABVTy.ret, getWMMAF8F6F4ABVTy.ret, DstTy], + 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>; + } + } +} + +defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles; +defm F32_16X16X128_F8F6F4_SCALE : WMMA_F8F6F4_Profiles; +defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles; class VOP_WMMA_LD_SCALE : VOP3P_Profile> { let HasMatrixScale = 1;