[MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (#169099)
Some usage cases or shapes for 2D block ops with sub-byte types can be emulated with 2D block operations for non-sub-byte types. Add the sub-byte type i4 as a valid XeGPU type, and add lowering of certain 2D block operations by emulating them with larger element types.
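To make the mapping concrete, here is a minimal standalone sketch of the shape rule the lowering applies to i4 tiles. The helper and its names are illustrative only, not part of the patch; executionSize = 16 and systolicDepth = 8 are assumptions that match the shapes the new tests exercise.

#include <cstdint>
#include <optional>

// Result of emulating an i4 2D block op with a wider element type.
struct EmulatedShape {
  unsigned elemBits;   // element width used for the emulated block op
  int64_t scaledTileW; // tile width after scaling
  int64_t wScale;      // divisor applied to width/pitch in bytes and the W offset
};

// Hypothetical helper mirroring the checks in LoadStorePrefetchNdToXeVMPattern.
std::optional<EmulatedShape> emulateI4Tile(int64_t tileH, int64_t tileW,
                                           bool packedLoad) {
  constexpr int64_t executionSize = 16;    // assumed execution (subgroup) size
  constexpr int64_t systolicDepth = 8;     // assumed DPAS systolic depth
  constexpr int64_t subByteFactor = 8 / 4; // two i4 elements per byte
  if (packedLoad && tileH == systolicDepth * 4 &&
      tileW == executionSize * subByteFactor)
    // Matrix B with a pack request: reuse 8-bit packed block loads.
    return EmulatedShape{8, executionSize, subByteFactor};
  constexpr int64_t sub16BitFactor = subByteFactor * 2; // four i4 per i16
  if (tileW == executionSize * sub16BitFactor)
    // Matrix A loads, stores, and prefetches: reuse 16-bit block ops.
    return EmulatedShape{16, executionSize, sub16BitFactor};
  return std::nullopt; // any other sub-byte shape is rejected
}

Under these assumptions, an 8x64xi4 descriptor is read as an 8x16 block of i16 (wScale = 4), and a packed 32x32xi4 Matrix B load as a 32x16 packed block of i8 (wScale = 2); these are exactly the shapes covered by the tests added below.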
@@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
 include "mlir/IR/BuiltinTypes.td"
 
-def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
-def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16,
+                               SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
+def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
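The note above is borne out by the tests added below: xegpu.create_nd_tdesc on a memref<16x128xi4> still lowers to a descriptor payload carrying the original i4 geometry (the 128 and 16 inserted into the payload vector), and only the load_nd/store_nd/prefetch_nd lowering substitutes the emulated element width and the scaled tile shape.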
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    if (elemBitSize % 8 != 0)
+    bool isSubByte = elemBitSize < 8;
+    uint64_t wScaleFactor = 1;
+
+    if (!isSubByte && (elemBitSize % 8 != 0))
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub-byte types, only 4 bits are currently supported.
+    if (isSubByte) {
+      if (elemBitSize != 4)
+        return rewriter.notifyMatchFailure(
+            op, "Only sub byte types of 4bits are supported.");
+      if (tileRank != 2)
+        return rewriter.notifyMatchFailure(
+            op, "Sub byte types are only supported for 2D tensor descriptors.");
+      auto subByteFactor = 8 / elemBitSize;
+      auto tileH = tdescTy.getDimSize(0);
+      // Handle the special case of a packed load.
+      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        if (op.getPacked().value_or(false)) {
+          // Packed load is implemented as packed loads of 8-bit elements.
+          if (tileH == systolicDepth * 4 &&
+              tileW == executionSize * subByteFactor) {
+            // Usage case for loading as Matrix B with pack request.
+            // Source is assumed to be pre-packed into 8-bit elements.
+            // Emulate with 8-bit loads with pack request.
+            // scaled_tileW = executionSize
+            elemType = rewriter.getIntegerType(8);
+            tileW = executionSize;
+            wScaleFactor = subByteFactor;
+          }
+        }
+      }
+      // If not handled by the packed load case above, handle other cases.
+      if (wScaleFactor == 1) {
+        auto sub16BitFactor = subByteFactor * 2;
+        if (tileW == executionSize * sub16BitFactor) {
+          // Usage case for loading as Matrix A operand.
+          // Emulate with 16-bit loads/stores.
+          // scaled_tileW = executionSize
+          elemType = rewriter.getIntegerType(16);
+          tileW = executionSize;
+          wScaleFactor = sub16BitFactor;
+        } else {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub byte types.");
+        }
+      }
+      // Recompute element bit size for emulation.
+      elemBitSize = elemType.getIntOrFloatBitWidth();
+    }
 
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
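Worked through for the shapes in the new tests (and assuming executionSize = 16 and systolicDepth = 8, which is what these shapes imply): an 8x64xi4 descriptor has tileW = 64 = 16 * 4 = executionSize * sub16BitFactor, so it is emulated as an 8x16 tile of i16 with wScaleFactor = 4; a packed 32x32xi4 load has tileH = 32 = systolicDepth * 4 and tileW = 32 = 16 * 2 = executionSize * subByteFactor, so it is emulated as a 32x16 tile of i8 with wScaleFactor = 2. Any other sub-byte shape is rejected with a match failure.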
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    // FIXME: width or pitch is not the same as baseShapeW it should be the
+    // stride of the second to last dimension in row major layout.
     // Compute width in bytes.
-    Value baseWidthByte =
+    Value baseShapeWInBytes =
         arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
     // Compute pitch in bytes.
-    Value basePitchByte =
+    Value basePitchBytes =
         arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
 
-    // Get tile width from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    if (wScaleFactor > 1) {
+      // Scale offsetW, baseShapeWInBytes for sub byte emulation.
+      // Note: tileW is already scaled above.
+      Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+      baseShapeWInBytes = arith::ShRSIOp::create(
+          rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+      basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+                                              wScaleFactorValLog2);
+      offsetW =
+          arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
+    }
     // Get tile height from the tensor descriptor type.
     auto tileH = tdescTy.getDimSize(0);
     // Get vblocks from the tensor descriptor type.
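For the first test below this scaling works out as follows (a worked example, not additional behavior): the source is a memref<16x128xi4>, so baseShapeW = 128 and, after switching to i16, elemByteSize = 2, giving baseShapeWInBytes = 256; shifting right by log2(4) yields the 64-byte width and pitch checked as C64_I32. Likewise the W offset of 64 i4 elements becomes 64 >> 2 = 16, matching the C16_I32 operand of xevm.blockload2d.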
@@ -330,8 +398,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       auto storeCacheControl =
           translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-          basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+          rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+          basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
           xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
       rewriter.eraseOp(op);
     } else {
@@ -339,8 +407,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
         xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
             vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         rewriter.eraseOp(op);
       } else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                    : rewriter.getIntegerType(elemBitSize));
 
         Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
-            vblocks, transpose, vnni,
+            rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+            baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+            tileH, vblocks, transpose, vnni,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         resultFlatVec = vector::BitCastOp::create(
             rewriter, loc,
mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir (new file, 80 lines)
@@ -0,0 +1,80 @@
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s

gpu.module @load_store_check {
  // CHECK-LABEL: gpu.func @load_store_matrix_a
  // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
  gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
    // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
    // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
    // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
    %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
    // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
    // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
    %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>

    // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
    // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
    // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
    // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
    // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>

    // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
    // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
    // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
    // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
    // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
    // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
    // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
    %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
      : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>

    // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
    // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
    // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
    // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
    // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>

    // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
    // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
    // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
    // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
    // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
    // CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
    xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
      : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
    gpu.return
  }

  // CHECK-LABEL: gpu.func @load_matrix_b_request_pack
  gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
    // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
    %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
    %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>

    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>

    // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
    // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
    // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
    // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
    %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
      : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>

    %c32 = arith.constant 32 : index
    %c0 = arith.constant 0 : index
    vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
    gpu.return
  }
}
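The packed Matrix B test above follows the same arithmetic: with i8 emulation, the 128-element i4 row of the memref<64x128xi4> becomes 128 bytes shifted right by log2(2), i.e. the 64-byte width and pitch (C64_I32); the [32, 32] offsets become y = 32 and x = 32 >> 1 = 16; and the 32x16 i8 tile comes back as vector<8xi32> per lane (assuming the 16-lane subgroup implied by executionSize = 16) before being bitcast to the declared vector<64xi4> result.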
mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir (new file, 21 lines)
@@ -0,0 +1,21 @@
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s

gpu.module @prefetch_check {
  // CHECK-LABEL: gpu.func @prefetch_matrix_a
  gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
    // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
    %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>

    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>

    // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
    // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
    xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
      : !xegpu.tensor_desc<8x64xi4>

    gpu.return
  }
}