[MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (#169099)

Some use cases and shapes of 2D block ops with sub-byte element types can be
emulated with 2D block operations on non-sub-byte types. Add the sub-byte
type i4 as a valid XeGPU type, and add lowering of certain 2D block
operations by emulating them with larger element types.
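
As a concrete illustration (a minimal sketch based on the new tests; value names are illustrative and cache hints are omitted), the Matrix A path packs four i4 elements into each i16, so an 8x64 i4 tile is loaded as an 8x16 tile of 16-bit elements, with the byte width, pitch, and W offset scaled down by wScaleFactor = 4:

// Matrix A use case: 8x64 tile of i4, %src is a memref<16x128xi4>.
%tdesc = xegpu.create_nd_tdesc %src : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
%val = xegpu.load_nd %tdesc[8, 64] : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
// Lowers (roughly) to a 16-bit block load of an 8x16 tile:
//   xevm.blockload2d ... <{elem_size_in_bits = 16 : i32, tile_height = 8 : i32,
//                          tile_width = 16 : i32, ...}> : (...) -> vector<8xi16>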
Author: Sang Ik Lee
Date: 2025-12-08 09:23:07 -08:00
Parent: f88d060c41
Commit: c3579f0199
4 changed files with 184 additions and 14 deletions


@@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
include "mlir/IR/BuiltinTypes.td"
def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16,
SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType


@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}
//
// Note:
// Block operations on tiles with sub-byte element types are handled by
// emulating them with larger element types.
// Tensor descriptors are kept intact; only the ops consuming them are
// emulated.
//
class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
if (elemBitSize % 8 != 0)
bool isSubByte = elemBitSize < 8;
uint64_t wScaleFactor = 1;
if (!isSubByte && (elemBitSize % 8 != 0))
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
auto tileW = tdescTy.getDimSize(tileRank - 1);
// For sub-byte types, only 4-bit elements are currently supported.
if (isSubByte) {
if (elemBitSize != 4)
return rewriter.notifyMatchFailure(
op, "Only sub byte types of 4bits are supported.");
if (tileRank != 2)
return rewriter.notifyMatchFailure(
op, "Sub byte types are only supported for 2D tensor descriptors.");
auto subByteFactor = 8 / elemBitSize;
auto tileH = tdescTy.getDimSize(0);
// Handle special case for packed load.
if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
if (op.getPacked().value_or(false)) {
// Packed load is implemented as a packed load of 8-bit elements.
if (tileH == systolicDepth * 4 &&
tileW == executionSize * subByteFactor) {
// Use case: loading as the Matrix B operand with a pack request.
// The source is assumed to be pre-packed into 8-bit elements.
// Emulate with 8-bit loads with a pack request.
// scaled_tileW = executionSize
elemType = rewriter.getIntegerType(8);
tileW = executionSize;
wScaleFactor = subByteFactor;
}
}
}
// If not handled by the packed-load case above, handle the remaining cases.
if (wScaleFactor == 1) {
auto sub16BitFactor = subByteFactor * 2;
if (tileW == executionSize * sub16BitFactor) {
// Use case: loading as the Matrix A operand.
// Emulate with 16-bit loads/stores.
// scaled_tileW = executionSize
elemType = rewriter.getIntegerType(16);
tileW = executionSize;
wScaleFactor = sub16BitFactor;
} else {
return rewriter.notifyMatchFailure(
op, "Unsupported tile shape for sub byte types.");
}
}
// Recompute the element bit size for emulation.
elemBitSize = elemType.getIntOrFloatBitWidth();
}
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// FIXME: width or pitch is not the same as baseShapeW; it should be the
// stride of the second-to-last dimension in a row-major layout.
// Compute width in bytes.
Value baseWidthByte =
Value baseShapeWInBytes =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
// Compute pitch in bytes.
Value basePitchByte =
Value basePitchBytes =
arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// Get tile width from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(tileRank - 1);
if (wScaleFactor > 1) {
// Scale offsetW, baseShapeWInBytes, and basePitchBytes for sub-byte emulation.
// Note: tileW was already scaled above.
Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
baseShapeWInBytes = arith::ShRSIOp::create(
rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
wScaleFactorValLog2);
offsetW =
arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
}
// Get tile height from the tensor descriptor type.
auto tileH = tdescTy.getDimSize(0);
// Get vblocks from the tensor descriptor type.
@@ -330,8 +398,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
@@ -339,8 +407,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
: rewriter.getIntegerType(elemBitSize));
Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
vblocks, transpose, vnni,
rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
tileH, vblocks, transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,

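The new tests below exercise both emulation paths. As a rough sketch of the packed Matrix B path (shapes as in the test below; value names illustrative, cache hints omitted): with i4 the sub-byte factor is 8 / 4 = 2, so a 32-wide i4 row is treated as a 16-wide i8 row, and the byte width, pitch, and W offset are shifted right by log2(wScaleFactor) = 1:

// Matrix B use case with pack request: 32x32 tile of i4, %src is a memref<64x128xi4>.
%tdesc = xegpu.create_nd_tdesc %src : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
%val = xegpu.load_nd %tdesc[32, 32] <{packed}> : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
// Lowers (roughly) to a packed 8-bit block load of a 32x16 tile:
//   xevm.blockload2d ... <{elem_size_in_bits = 8 : i32, pack_register = true,
//                          tile_height = 32 : i32, tile_width = 16 : i32, ...}>
//     : (...) -> vector<8xi32>
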

@@ -0,0 +1,80 @@
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
gpu.module @load_store_check {
// CHECK-LABEL: gpu.func @load_store_matrix_a
// CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
// CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
// CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
// CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
// CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
// CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
// CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
// CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
%srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
// CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
// CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
// CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
%dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>
// CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
// CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
// CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
// CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
// CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
// CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
// CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
// CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
// CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
%loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
// CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
// CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
// CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
// CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
// CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
// CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
// CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
// CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
gpu.return
}
// CHECK-LABEL: gpu.func @load_matrix_b_request_pack
gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
// CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
// CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
%srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
%dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
// CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
// CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
// CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
// CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
%loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
gpu.return
}
}


@@ -0,0 +1,21 @@
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
gpu.module @prefetch_check {
// CHECK-LABEL: gpu.func @prefetch_matrix_a
gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
// CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
// CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
// CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
%srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
// CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
// CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
// CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x64xi4>
gpu.return
}
}