//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
// TODO: The values below are uArch-dependent; they should not be hardcoded.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};
// Offsets to the individual fields of the 8xi32 nd tensor descriptor payload.
enum class NdTdescOffset : uint32_t {
BasePtr = 0, // Base pointer (i64)
BaseShapeW = 2, // Base shape width (i32)
BaseShapeH = 3, // Base shape height (i32)
BasePitch = 4, // Base pitch (i32)
};
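// A sketch of the payload layout (lanes of the vector<8xi32>; lanes 0-1 hold
// the base pointer and are written through a 4xi64 bitcast view; lanes 5-7
// are currently unused):
//   lane:    0..1      2        3        4      5..7
//         [ BasePtr | ShapeW | ShapeH | Pitch | ---- ]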
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
switch (xeGpuMemspace) {
case xegpu::MemorySpace::Global:
return static_cast<int>(xevm::AddrSpace::GLOBAL);
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
llvm_unreachable("Unknown XeGPU memory space");
}
// Get a flat vector type with the same total bit width but a new element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
Type toElemType) {
auto elemType = currentVecType.getElementType();
auto currentBitWidth = elemType.getIntOrFloatBitWidth();
auto newBitWidth = toElemType.getIntOrFloatBitWidth();
const int size =
currentVecType.getNumElements() * currentBitWidth / newBitWidth;
return VectorType::get(size, toElemType);
}
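// For example, re-encoding vector<8x16xf16> to i32 yields vector<64xi32>:
// the total bit width (2048 bits) is preserved and the shape is flattened.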
static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
std::optional<xegpu::CachePolicy> L3hint) {
auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
switch (L1hintVal) {
case xegpu::CachePolicy::CACHED:
if (L3hintVal == xegpu::CachePolicy::CACHED)
return xevm::LoadCacheControl::L1C_L2UC_L3C;
else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::LoadCacheControl::L1C_L2UC_L3UC;
else
llvm_unreachable("Unsupported cache control.");
case xegpu::CachePolicy::UNCACHED:
if (L3hintVal == xegpu::CachePolicy::CACHED)
return xevm::LoadCacheControl::L1UC_L2UC_L3C;
else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
else
llvm_unreachable("Unsupported cache control.");
case xegpu::CachePolicy::STREAMING:
if (L3hintVal == xegpu::CachePolicy::CACHED)
return xevm::LoadCacheControl::L1S_L2UC_L3C;
else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::LoadCacheControl::L1S_L2UC_L3UC;
else
llvm_unreachable("Unsupported cache control.");
case xegpu::CachePolicy::READ_INVALIDATE:
return xevm::LoadCacheControl::INVALIDATE_READ;
default:
llvm_unreachable("Unsupported cache control.");
}
}
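// For example, L1hint = CACHED with L3hint = UNCACHED maps to L1C_L2UC_L3UC;
// absent hints default to UNCACHED, and L2 is always uncached in these
// encodings.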
static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
std::optional<xegpu::CachePolicy> L3hint) {
auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
switch (L1hintVal) {
case xegpu::CachePolicy::UNCACHED:
if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
else
llvm_unreachable("Unsupported cache control.");
case xegpu::CachePolicy::STREAMING:
if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::StoreCacheControl::L1S_L2UC_L3UC;
else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
return xevm::StoreCacheControl::L1S_L2UC_L3WB;
else
llvm_unreachable("Unsupported cache control.");
case xegpu::CachePolicy::WRITE_BACK:
if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
else
llvm_unreachable("Unsupported cache control.");
case xegpu::CachePolicy::WRITE_THROUGH:
if (L3hintVal == xegpu::CachePolicy::UNCACHED)
return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
else
llvm_unreachable("Unsupported cache control.");
default:
llvm_unreachable("Unsupported cache control.");
}
}
//
// Note:
// Block operations on tiles with sub-byte element types are handled by
// emulating them with larger element types.
// Tensor descriptors are kept intact; only the ops consuming them are
// emulated.
//
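// Worked i4 example (derived from the emulation logic below): with
// executionSize = 16 and systolicDepth = 8, an 8x64xi4 Matrix A tile has
// tileW == executionSize * (8/4) * 2 == 64, so it is emulated as an 8x16
// tile of i16 (wScaleFactor == 4); width, pitch, and W-offset are divided by
// 4 accordingly. A packed 32x32xi4 Matrix B load (tileH == 8 * 4, tileW ==
// 16 * 2) is instead emulated with packed 8-bit loads of a 32x16 tile
// (wScaleFactor == 2).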
class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op,
xegpu::CreateNdDescOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
if (mixedOffsets.size() != 0)
return rewriter.notifyMatchFailure(op, "Offsets not supported.");
auto loc = op.getLoc();
auto source = op.getSource();
// The op is lowered to a code sequence that populates the payload.
// The payload is an 8xi32 vector. Offsets to the individual fields are
// defined in the NdTdescOffset enum.
Type payloadElemTy = rewriter.getI32Type();
VectorType payloadTy = VectorType::get(8, payloadElemTy);
Type i64Ty = rewriter.getI64Type();
// 4xi64 view is used for inserting the base pointer.
VectorType payloadI64Ty = VectorType::get(4, i64Ty);
// Initialize payload to zero.
Value payload = arith::ConstantOp::create(
rewriter, loc,
DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If the source is a memref, we need to extract the aligned pointer as an
// index. The pointer type is passed as i32 or i64 by the type converter.
if (sourceMemrefTy) {
if (!sourceMemrefTy.hasRank()) {
return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
}
// Access the adaptor after the failure check to avoid rolling back code
// generated for the materialization cast.
baseAddr = adaptor.getSource();
} else {
baseAddr = adaptor.getSource();
if (baseAddr.getType() != i64Ty) {
// Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
}
}
// 1D tensor descriptor is just the base address.
if (rank == 1) {
rewriter.replaceOp(op, baseAddr);
return success();
}
// Utility for creating offset values from an OpFoldResult.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
unsigned idx) -> Value {
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
// Get pitch value from op fold results.
Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
payLoadAsI64 =
vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
static_cast<int>(NdTdescOffset::BasePtr));
payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
payload =
vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
static_cast<int>(NdTdescOffset::BaseShapeW));
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
payload =
vector::InsertOp::create(rewriter, loc, basePitch, payload,
static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
};
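// Sketch of the above lowering (illustrative, abbreviated syntax):
//   %td = xegpu.create_nd_tdesc %src : memref<64x128xf16>
//           -> !xegpu.tensor_desc<8x16xf16>
// becomes a payload-building sequence: the aligned base pointer is inserted
// into lanes 0-1 of a vector<8xi32> (through a 4xi64 bitcast view), followed
// by vector.insert of the shape W/H and the pitch as i32 lanes.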
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
auto tileRank = tdescTy.getRank();
if (opOffsetsSize != tileRank)
return rewriter.notifyMatchFailure(
op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
bool isSubByte = elemBitSize < 8;
uint64_t wScaleFactor = 1;
if (!isSubByte && (elemBitSize % 8 != 0))
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
auto tileW = tdescTy.getDimSize(tileRank - 1);
// For sub-byte types, only 4-bit types are currently supported.
if (isSubByte) {
if (elemBitSize != 4)
return rewriter.notifyMatchFailure(
op, "Only sub byte types of 4bits are supported.");
if (tileRank != 2)
return rewriter.notifyMatchFailure(
op, "Sub byte types are only supported for 2D tensor descriptors.");
auto subByteFactor = 8 / elemBitSize;
auto tileH = tdescTy.getDimSize(0);
// Handle special case for packed load.
if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
if (op.getPacked().value_or(false)) {
// Packed load is emulated with packed loads of 8-bit elements.
if (tileH == systolicDepth * 4 &&
tileW == executionSize * subByteFactor) {
// Use case: loading as a Matrix B operand with a pack request.
// The source is assumed to be pre-packed into 8-bit elements.
// Emulate with 8-bit loads with the pack request;
// scaled tileW = executionSize.
elemType = rewriter.getIntegerType(8);
tileW = executionSize;
wScaleFactor = subByteFactor;
}
}
}
// If not handled by the packed load case above, handle the remaining cases.
if (wScaleFactor == 1) {
auto sub16BitFactor = subByteFactor * 2;
if (tileW == executionSize * sub16BitFactor) {
// Use case: loading as a Matrix A operand.
// Emulate with 16-bit loads/stores;
// scaled tileW = executionSize.
elemType = rewriter.getIntegerType(16);
tileW = executionSize;
wScaleFactor = sub16BitFactor;
} else {
return rewriter.notifyMatchFailure(
op, "Unsupported tile shape for sub byte types.");
}
}
// Recompute the element bit size for the emulation.
elemBitSize = elemType.getIntOrFloatBitWidth();
}
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
if (tileRank == 2) {
// Compute element byte size.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
Value basePtr =
vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
static_cast<int>(NdTdescOffset::BasePtr));
Value baseShapeW = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
Value basePitch = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
// Offsets are provided by the op; convert them to i32.
Value offsetW =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetW);
Value offsetH =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetH);
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// FIXME: Width (pitch) is not the same as baseShapeW; it should be the
// stride of the second-to-last dimension in row-major layout.
// Compute width in bytes.
Value baseShapeWInBytes =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
// Compute pitch in bytes.
Value basePitchBytes =
arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
if (wScaleFactor > 1) {
// Scale offsetW and baseShapeWInBytes for sub-byte emulation.
// Note: tileW was already scaled above.
Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
baseShapeWInBytes = arith::ShRSIOp::create(
rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
wScaleFactorValLog2);
offsetW =
arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
}
// Get tile height from the tensor descriptor type.
auto tileH = tdescTy.getDimSize(0);
// Get vblocks from the tensor descriptor type.
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If the store value is a scalar, get the value from the op instead of
// the adaptor; the adaptor may have optimized away a single-element
// vector.
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
const bool vnni = op.getPacked().value_or(false);
auto transposeValue = op.getTranspose();
bool transpose =
transposeValue.has_value() && transposeValue.value()[0] == 1;
VectorType loadedTy = encodeVectorTypeTo(
dstVecTy, vnni ? rewriter.getI32Type()
: rewriter.getIntegerType(elemBitSize));
Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
tileH, vblocks, transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
resultFlatVec);
rewriter.replaceOp(op, resultFlatVec);
}
}
} else {
// 1D tensor descriptor.
// `tdesc` represents the base address as an i64.
// The offset is in number of elements, so multiply by the element byte
// size to compute the byte offset:
// byteOffset = offset * elementByteSize
Value offset =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI64Type(), offset);
// Compute element byte size.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
Value byteOffset =
rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
// Final address = basePtr + byteOffset
Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
loc, tdesc,
getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
byteOffset));
// Convert base pointer (i64) to LLVM pointer type.
Value finalPtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If the store value is a scalar, get the value from the op instead of
// the adaptor; the adaptor may have optimized away a single-element
// vector.
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
op, finalPtrLLVM, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
} else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
VectorType resTy = cast<VectorType>(op.getValue().getType());
VectorType loadedTy =
encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
Value load = xevm::BlockLoadOp::create(
rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
if (loadedTy != resTy)
load = vector::BitCastOp::create(rewriter, loc, resTy, load);
rewriter.replaceOp(op, load);
} else {
return rewriter.notifyMatchFailure(
op, "Unsupported operation: xegpu.prefetch_nd with tensor "
"descriptor rank == 1");
}
}
return success();
}
};
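// Sketch of the 2D load path above (illustrative, abbreviated syntax):
//   %v = xegpu.load_nd %td[%y, %x] : !xegpu.tensor_desc<8x16xf16>
//          -> vector<8x16xf16>
// lowers to a xevm.blockload2d with byte-scaled width/pitch, i32 offsets,
// and a flat integer result vector that is bitcast back to the element type.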
// Helper that builds the address computation:
// baseAddr + offset * elemByteSize
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
Location loc, Value baseAddr, Value offset,
int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
}
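// For example, with elemByteSize == 4, an element offset %n becomes the
// address baseAddr + 4 * %n (computed in the integer type of baseAddr).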
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value offset = adaptor.getOffsets();
if (!offset)
return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
Value basePtrI64;
// The load result or store value type can be a vector or a scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
valOrResTy =
this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
bool hasScalarVal = !valOrResVecTy;
int64_t elemBitWidth =
hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
: valOrResVecTy.getElementType().getIntOrFloatBitWidth();
// The element bit width must be a multiple of 8.
if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
int64_t elemByteSize = elemBitWidth / 8;
// Default memory space is global.
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
// If tensor descriptor is available, we use its memory space.
if (tdescTy)
ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
// Base pointer can come from source (load) or dest (store).
// If they are memrefs, we use their memory space.
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
basePtrI64 = adaptor.getSource();
if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
auto addrSpace = memRefTy.getMemorySpaceAsInt();
if (addrSpace != 0)
ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
}
} else {
basePtrI64 = adaptor.getDest();
if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
auto addrSpace = memRefTy.getMemorySpaceAsInt();
if (addrSpace != 0)
ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
}
}
// Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
if (basePtrI64.getType() != rewriter.getI64Type()) {
basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
basePtrI64);
}
Value mask = adaptor.getMask();
if (isa<VectorType>(offset.getType())) {
// The offset needs to be a scalar; a single-element vector is converted
// to a scalar by the type converter.
return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
} else {
// The offset is added to the base pointer.
// It is in number of elements, so multiply by the element byte size.
basePtrI64 =
addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
Value maskForLane;
VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
if (maskVecTy) {
// The mask needs to be a scalar; a single-element vector is converted to
// a scalar by the type converter.
return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
} else
maskForLane = mask;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
maskForLane, true, true);
// If the mask is true (then clause), load from memory and yield.
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
if (!hasScalarVal)
valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
valOrResVecTy.getElementType());
Value loaded =
LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
// Set cache control attribute on the load operation.
loaded.getDefiningOp()->setAttr(
"cache_control", xevm::LoadCacheControlAttr::get(
ctxt, translateLoadXeGPUCacheHint(
op.getL1Hint(), op.getL3Hint())));
scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
// If the mask is false (else clause), yield zeros.
auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
TypedAttr eVal;
if (eTy.isFloat())
eVal = FloatAttr::get(eTy, 0.0);
else
eVal = IntegerAttr::get(eTy, 0);
if (hasScalarVal)
loaded = arith::ConstantOp::create(rewriter, loc, eVal);
else
loaded = arith::ConstantOp::create(
rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
rewriter.replaceOp(op, ifOp.getResult(0));
} else {
// If mask is true, perform the store.
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
auto body = ifOp.getBody();
rewriter.setInsertionPointToStart(body);
auto storeOp =
LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
// Set cache control attribute on the store operation.
storeOp.getOperation()->setAttr(
"cache_control", xevm::StoreCacheControlAttr::get(
ctxt, translateStoreXeGPUCacheHint(
op.getL1Hint(), op.getL3Hint())));
rewriter.eraseOp(op);
}
return success();
}
};
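// Sketch of the per-lane gather/scatter lowering above (illustrative):
//   %r = scf.if %mask -> (T) { llvm.load %ptr ... } else { <zeros> }
// with the translated cache hint attached to the memory op as a
// "cache_control" attribute; a store is guarded by the mask the same way.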
class CreateMemDescOpPattern final
: public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
if (offsets.empty())
return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// The load result or store value type can be a vector or a scalar.
Value data;
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
data = op.getResult();
else
data = adaptor.getData();
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());
if (valOrResVecTy.getShape().size() != 1)
return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
// The element bit width must be a multiple of 8.
if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
int64_t elemByteSize = elemBitWidth / 8;
// Default memory space is SLM.
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
linearOffset, elemByteSize);
// Convert the base pointer (i32) to an LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
if (op.getSubgroupBlockIoAttr()) {
// If the 'subgroup_block_io' attribute is set, lower to xevm.blockload /
// xevm.blockstore.
Type intElemTy = rewriter.getIntegerType(elemBitWidth);
VectorType intVecTy =
VectorType::get(valOrResVecTy.getShape(), intElemTy);
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
Value loadOp =
xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
if (intVecTy != valOrResVecTy) {
loadOp =
vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
}
rewriter.replaceOp(op, loadOp);
} else {
Value dataToStore = adaptor.getData();
if (valOrResVecTy != intVecTy) {
dataToStore =
vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
}
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
nullptr);
rewriter.eraseOp(op);
}
return success();
}
if (valOrResVecTy.getNumElements() >= 1) {
auto chipOpt = xegpu::getChipStr(op);
if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
// The lowering for chunk load/store only works for pvc and bmg.
return rewriter.notifyMatchFailure(
op, "The lowering is specific to pvc or bmg.");
}
}
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
// If the size of valOrResVecTy is 1, lower to a scalar load/store;
// LLVM load/store does not support vectors of size 1, so this case is
// handled separately.
auto scalarTy = valOrResVecTy.getElementType();
LLVM::LoadOp loadOp;
if (valOrResVecTy.getNumElements() == 1)
loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
else
loadOp =
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
rewriter.replaceOp(op, loadOp);
} else {
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
rewriter.eraseOp(op);
}
return success();
}
};
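// Sketch (illustrative): with the subgroup_block_io attribute, load_matrix /
// store_matrix lower to xevm.blockload / xevm.blockstore (bitcast to and
// from an integer vector as needed); otherwise they become a plain
// llvm.load / llvm.store at the linearized SLM offset.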
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
Value basePtrI64 = adaptor.getSource();
// Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
if (basePtrI64.getType() != rewriter.getI64Type())
basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
basePtrI64);
Value offsets = adaptor.getOffsets();
if (offsets) {
VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
if (offsetsVecTy) {
// The offset needs to be a scalar.
return rewriter.notifyMatchFailure(op,
"Expected offsets to be a scalar.");
} else {
int64_t elemBitWidth{0};
int64_t elemByteSize;
// Element byte size can come from three sources:
if (tdescTy) {
// If tensor descriptor is available, we use its element type to
// determine element byte size.
elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
} else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
// If memref is available, we use its element type to
// determine element byte size.
elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
} else {
// Otherwise, we use the provided offset byte alignment.
elemByteSize = *op.getOffsetAlignByte();
}
if (elemBitWidth != 0) {
if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
elemByteSize);
}
}
// Default memory space is global.
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
// If tensor descriptor is available, we use its memory space.
if (tdescTy)
ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
// If source is a memref, we use its memory space.
if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
auto addrSpace = memRefTy.getMemorySpaceAsInt();
if (addrSpace != 0)
ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
}
// Convert base pointer (i64) to LLVM pointer type.
Value ptrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
// Create the prefetch op with cache control attribute.
xevm::PrefetchOp::create(
rewriter, loc, ptrLLVM,
xevm::LoadCacheControlAttr::get(
ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
rewriter.eraseOp(op);
return success();
}
};
class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
switch (op.getFenceScope()) {
case xegpu::FenceScope::Workgroup:
memScope = xevm::MemScope::WORKGROUP;
break;
case xegpu::FenceScope::GPU:
memScope = xevm::MemScope::DEVICE;
break;
}
xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
switch (op.getMemoryKind()) {
case xegpu::MemorySpace::Global:
addrSpace = xevm::AddrSpace::GLOBAL;
break;
case xegpu::MemorySpace::SLM:
addrSpace = xevm::AddrSpace::SHARED;
break;
}
xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
rewriter.eraseOp(op);
return success();
}
};
class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto aTy = cast<VectorType>(op.getLhs().getType());
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;
else if (type == rewriter.getF16Type())
return xevm::ElemType::F16;
else if (type == rewriter.getTF32Type())
return xevm::ElemType::TF32;
else if (type.isInteger(8)) {
if (type.isUnsignedInteger())
return xevm::ElemType::U8;
return xevm::ElemType::S8;
} else if (type == rewriter.getF32Type())
return xevm::ElemType::F32;
else if (type.isInteger(32))
return xevm::ElemType::S32;
llvm_unreachable("add more support for ElemType");
};
xevm::ElemType precATy = encodePrecision(aTy.getElementType());
xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
Value c = op.getAcc();
if (!c) {
auto elementTy = resultType.getElementType();
Attribute initValueAttr;
if (isa<FloatType>(elementTy))
initValueAttr = FloatAttr::get(elementTy, 0.0);
else
initValueAttr = IntegerAttr::get(elementTy, 0);
c = arith::ConstantOp::create(
rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
}
Value aVec = op.getLhs();
Value bVec = op.getRhs();
auto cvecty = cast<VectorType>(c.getType());
xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
VectorType cNty =
VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
if (cvecty != cNty)
c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
Value dpasRes = xevm::MMAOp::create(
rewriter, loc, cNty, aVec, bVec, c,
xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
systolicDepth *
getNumOperandsPerDword(precATy)),
xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
if (cvecty != cNty)
dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
rewriter.replaceOp(op, dpasRes);
return success();
}
private:
static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
switch (pTy) {
case xevm::ElemType::TF32:
return 1;
case xevm::ElemType::BF16:
case xevm::ElemType::F16:
return 2;
case xevm::ElemType::U8:
case xevm::ElemType::S8:
return 4;
default:
llvm_unreachable("unsupported xevm::ElemType");
}
}
};
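// Illustrative mapping (derived from the code above): for f16 operands,
// getNumOperandsPerDword returns 2, so the MMA shape is N = executionSize =
// 16 and K = systolicDepth * 2 = 16, with M given by the number of
// accumulator elements (8 for the common 8x16x16 dpas).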
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
switch (arithKind) {
case arith::AtomicRMWKind::addf:
return LLVM::AtomicBinOp::fadd;
case arith::AtomicRMWKind::addi:
return LLVM::AtomicBinOp::add;
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
case arith::AtomicRMWKind::minu:
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
return std::nullopt;
}
}
class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdesc = op.getTensorDesc().getType();
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
Value basePtrI64 = arith::IndexCastOp::create(
rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
VectorType srcOrDstFlatVecTy = VectorType::get(
srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
Value srcFlatVec = vector::ShapeCastOp::create(
rewriter, loc, srcOrDstFlatVecTy, op.getValue());
auto atomicKind = matchSimpleAtomicOp(op.getKind());
assert(atomicKind.has_value() && "Unsupported atomic RMW kind.");
Value resVec = srcFlatVec;
for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
rewriter.getIndexAttr(i));
Value currPtr =
LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
Value newVal =
LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
val, LLVM::AtomicOrdering::seq_cst);
resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
}
rewriter.replaceOp(op, resVec);
return success();
}
};
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
struct ConvertXeGPUToXeVMPass
: public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
using Base::Base;
void runOnOperation() override {
LLVMTypeConverter typeConverter(&getContext());
typeConverter.addConversion([&](VectorType type) -> Type {
unsigned rank = type.getRank();
auto elemType = type.getElementType();
// If the element type is index, convert it to i64.
if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);
// If the vector is zero-rank or has a single element, return the element
// type.
if (rank < 1 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
int64_t numElements = llvm::product_of(type.getShape());
return VectorType::get(numElements, elemType);
});
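// For example, vector<8x16xf16> converts to vector<128xf16>, and
// vector<1xindex> collapses to i64.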
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
// Scattered descriptors are not supported in XeVM lowering.
if (type.isScattered())
return {};
if (type.getRank() == 1)
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
// Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
return IntegerType::get(&getContext(), 32);
});
typeConverter.addConversion([&](MemRefType type) -> Type {
if (type.getMemorySpaceAsInt() == 3)
return IntegerType::get(&getContext(), 32);
return IntegerType::get(&getContext(), 64);
});
// The LLVM type converter inserts unrealized casts for the cases below;
// add materialization casts to handle them.
// Materialization to convert a memref to its integer address (i64, or i32
// for SLM).
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
Value addr =
memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
return arith::IndexCastUIOp::create(builder, loc, type, addr)
.getResult();
}
return {};
};
// Materialization to convert ui64 to i64
auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (input.getType() == builder.getIntegerType(64, false)) {
Value cast =
index::CastUOp::create(builder, loc, builder.getIndexType(), input)
.getResult();
return arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
}
return {};
};
// Materialization to convert ui32 to i32
auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (input.getType() == builder.getIntegerType(32, false)) {
Value cast =
index::CastUOp::create(builder, loc, builder.getIndexType(), input)
.getResult();
return arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
}
return {};
};
// Materialization to convert:
// - a single-element 1D vector to a scalar
// - a vector to a same-rank vector via bitcast
// - a vector to a different-rank vector with the same element type via
//   shape_cast
auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
if (vecTy.getNumElements() == 1) {
// If the vector has a single element, extract and return that element.
Value cast =
vector::ExtractOp::create(builder, loc, input, 0).getResult();
if (vecTy.getElementType() == builder.getIndexType())
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
return cast;
} else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
// If the target type is a vector of same rank,
// bitcast to the target type.
if (targetVecTy.getRank() == vecTy.getRank())
return vector::BitCastOp::create(builder, loc, targetVecTy, input)
.getResult();
else if (targetVecTy.getElementType() == vecTy.getElementType()) {
// If the target type is a vector of different rank but same element
// type, reshape to the target type.
return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
.getResult();
}
}
}
return {};
};
// If the result type of the original op is a single-element vector and the
// lowered type is a scalar, this materialization cast creates a
// single-element vector by broadcasting the scalar value.
auto singleElementVectorMaterializationCast =
[](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (input.getType().isIntOrIndexOrFloat()) {
// If the input is a scalar and the target type is a single-element
// vector, create the vector by broadcasting.
if (auto vecTy = dyn_cast<VectorType>(type)) {
if (vecTy.getNumElements() == 1) {
return vector::BroadcastOp::create(builder, loc, vecTy, input)
.getResult();
}
}
}
return {};
};
typeConverter.addSourceMaterialization(
singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
typeConverter.addTargetMaterialization(vectorMaterializationCast);
ConversionTarget target(getContext());
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
vector::VectorDialect, arith::ArithDialect,
memref::MemRefDialect, gpu::GPUDialect,
index::IndexDialect>();
target.addIllegalDialect<xegpu::XeGPUDialect>();
RewritePatternSet patterns(&getContext());
populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<CreateNdDescToXeVMPattern,
LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
typeConverter, patterns.getContext());
patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
CreateMemDescOpPattern>(typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}