Revert "[MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix" (#163684)
Reverts llvm/llvm-project#162780, which breaks build bots; see #162780 for details.
@@ -712,14 +712,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
|
||||
return getAttrs().contains(name);
|
||||
}
|
||||
|
||||
ArrayAttr getStrideAttr() {
|
||||
ArrayAttr getStrides() {
|
||||
return getAttrs().getAs<ArrayAttr>("stride");
|
||||
}
|
||||
|
||||
ArrayAttr getBlockAttr() {
|
||||
return getAttrs().getAs<ArrayAttr>("block");
|
||||
}
|
||||
|
||||
}];
|
||||
|
||||
}
|
||||
|
||||
@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
|
||||
}
|
||||
|
||||
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
|
||||
AllElementTypesMatch<["mem_desc", "res"]>]> {
|
||||
AllElementTypesMatch<["mem_desc", "res"]>,
|
||||
AllRanksMatch<["mem_desc", "res"]>]> {
|
||||
let arguments = (ins XeGPU_MemDesc:$mem_desc,
|
||||
Variadic<Index>: $offsets,
|
||||
DenseI64ArrayAttr: $const_offsets,
|
||||
OptionalAttr<UnitAttr>:$subgroup_block_io,
|
||||
OptionalAttr<DistributeLayoutAttr>:$layout
|
||||
);
|
||||
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
|
||||
let results = (outs XeGPU_ValueType:$res);
|
||||
let assemblyFormat = [{
|
||||
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
|
||||
prop-dict attr-dict `` `:` type(operands) `->` type(results)
|
||||
@@ -1319,9 +1319,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
|
||||
Arguments:
|
||||
- `mem_desc`: the memory descriptor identifying the SLM region.
|
||||
- `offsets`: the coordinates within the matrix to read from.
|
||||
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
|
||||
lowered to a subgroup block load. When this attribute is present,
|
||||
the offsets are subgroup-uniform across all lanes.
|
||||
- `layout`: [optional] An attribute for guiding distributions among
|
||||
subgroups and/or work-items. It currently can accept either
|
||||
LayoutAttr or SliceAttr.
|
||||
@@ -1339,10 +1336,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getDataShape() {
|
||||
auto resTy = getRes().getType();
|
||||
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
|
||||
return vecTy.getShape();
|
||||
return {};
|
||||
return getRes().getType().getShape();
|
||||
}
|
||||
}];
|
||||
|
||||
@@ -1350,13 +1344,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
|
||||
}
|
||||
|
||||
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
|
||||
AllElementTypesMatch<["mem_desc", "data"]>]> {
|
||||
AllElementTypesMatch<["mem_desc", "data"]>,
|
||||
AllRanksMatch<["mem_desc", "data"]>]> {
|
||||
let arguments = (ins
|
||||
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
|
||||
XeGPU_ValueType:$data,
|
||||
XeGPU_MemDesc:$mem_desc,
|
||||
Variadic<Index>: $offsets,
|
||||
DenseI64ArrayAttr: $const_offsets,
|
||||
OptionalAttr<UnitAttr>:$subgroup_block_io,
|
||||
OptionalAttr<DistributeLayoutAttr>:$layout
|
||||
);
|
||||
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
|
||||
@@ -1370,9 +1364,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
|
||||
- `mem_desc`: the memory descriptor specifying the SLM region.
|
||||
- `offsets`: the coordinates within the matrix where the data will be written.
|
||||
- `data`: the values to be stored in the matrix.
|
||||
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
|
||||
lowered to a subgroup block store. When this attribute is present,
|
||||
the offsets are subgroup-uniform across all lanes.
|
||||
- `layout`: [optional] An attribute for guiding distributions among
|
||||
subgroups and/or work-items. It currently can accept either
|
||||
LayoutAttr or SliceAttr.
|
||||
@@ -1387,10 +1378,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getDataShape() {
|
||||
auto DataTy = getData().getType();
|
||||
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
|
||||
return vecTy.getShape();
|
||||
return {};
|
||||
return getData().getType().getShape();
|
||||
}
|
||||
|
||||
}];
|
||||
@@ -1398,4 +1386,41 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
|
||||
[Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
|
||||
let description = [{
|
||||
Creates a subview of a memory descriptor. The resulting memory descriptor can have
|
||||
a lower rank than the source; in this case, the result dimensions correspond to the
|
||||
higher-order dimensions of the source memory descriptor.
|
||||
|
||||
Arguments:
|
||||
- `src` : a memory descriptor.
|
||||
- `offsets` : the coordinates within the matrix the subview will be created from.
|
||||
|
||||
Results:
|
||||
- `res` : a memory descriptor with smaller size.
|
||||
|
||||
}];
|
||||
let arguments = (ins XeGPU_MemDesc:$src,
|
||||
Variadic<Index>:$offsets,
|
||||
DenseI64ArrayAttr:$const_offsets);
|
||||
let results = (outs XeGPU_MemDesc:$res);
|
||||
let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
|
||||
attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
|
||||
let builders = [
|
||||
OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::Value getViewSource() { return getSrc(); }
|
||||
|
||||
SmallVector<OpFoldResult> getMixedOffsets() {
|
||||
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
|
||||
|
||||
@@ -237,11 +237,12 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
|
||||
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
|
||||
}
|
||||
|
||||
ArrayAttr getStrideAttr() {
|
||||
ArrayAttr getStrides() {
|
||||
auto layout = getMemLayout();
|
||||
if (layout && layout.hasAttr("stride")) {
|
||||
return layout.getStrideAttr();
|
||||
return layout.getStrides();
|
||||
}
|
||||
|
||||
// derive and return default strides
|
||||
SmallVector<int64_t> defaultStrides;
|
||||
llvm::append_range(defaultStrides, getShape().drop_front());
|
||||
@@ -249,63 +250,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
|
||||
Builder builder(getContext());
|
||||
return builder.getI64ArrayAttr(defaultStrides);
|
||||
}
|
||||
|
||||
ArrayAttr getBlockAttr() {
|
||||
auto layout = getMemLayout();
|
||||
if (layout && layout.hasAttr("block")) {
|
||||
return layout.getBlockAttr();
|
||||
}
|
||||
Builder builder(getContext());
|
||||
return builder.getI64ArrayAttr({});
|
||||
}
|
||||
|
||||
/// Heuristic to determine if the MemDesc uses column-major layout,
|
||||
/// based on the rank and the value of the first stride dimension.
|
||||
bool isColMajor() {
|
||||
auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
|
||||
return getRank() == 2 && dim0.getInt() == 1;
|
||||
}
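// Illustrative note (not part of the original code): with the heuristic above,
// `mem_desc<32x256xf16, @strides=[1, 32]>` has rank 2 and strides[0] == 1, so it is treated
// as column-major, while `mem_desc<256x32xf16>` with default strides [32, 1] has
// strides[0] == 32 and is treated as row-major.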
|
||||
|
||||
// Get the blocking shape for a MemDescType, which is represented
// as an attribute in MemDescType. By default it is the shape
// of the mdescTy.
SmallVector<int64_t> getBlockShape() {
|
||||
SmallVector<int64_t> size(getShape());
|
||||
ArrayAttr blockAttr = getBlockAttr();
|
||||
if (!blockAttr.empty()) {
|
||||
size.clear();
|
||||
for (auto attr : blockAttr.getValue()) {
|
||||
size.push_back(cast<IntegerAttr>(attr).getInt());
|
||||
}
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
// Get the strides as a vector of integers.
// If the layout contains a block attribute, the returned strides are blocked strides.
//
// The blocking is applied to the base matrix shape derived from the
// memory descriptor's stride information. If the matrix described by
// the memory descriptor is not contiguous, it is assumed that the base
// matrix is contiguous and follows the same memory layout.
//
// It first computes the original matrix shape using the stride info,
// then computes the number of blocks in each dimension of the original shape,
// then computes the outer block shape and stride,
// and finally combines the inner and outer block shapes and strides.
// e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
// for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride=[32, 1]
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
SmallVector<int64_t> getStrideShape();
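// Worked derivation (added for illustration, following the algorithm described above):
// for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`,
//   dims ordered by increasing stride: dim0 (stride 1), then dim1 (stride 32);
//   inner (within-block) strides: [1, 1*16] = [1, 16];
//   base rows derived from the strides: 32/1 = 32, giving 32/16 = 2 blocks along dim0;
//   elements per block: 16*8 = 128, so outer (block) strides: [128, 128*2] = [128, 256];
//   blocked shape = [32/16, 256/8, 16, 8] = [2, 32, 16, 8];
//   blocked strides = outer ++ inner = [128, 256, 1, 16],
// which matches the layout tuple quoted above.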
|
||||
|
||||
/// Generates instructions to compute the linearized offset.
/// If the memory descriptor is blocked, it returns the linearized offset based on the blocked layout.
/// The strides of the memory descriptor are always taken into account, blocked or not.
Value getLinearOffsets(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> offsets);
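// For example (illustrative, reusing the blocked layout example above): with
// `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` and offsets [y, x], the
// linearized offset is (y/16)*128 + (x/8)*256 + (y%16)*1 + (x%8)*16.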
|
||||
|
||||
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = true;
|
||||
|
||||
@@ -21,7 +21,6 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
|
||||
MLIRIndexDialect
|
||||
MLIRSCFDialect
|
||||
MLIRXeGPUDialect
|
||||
MLIRXeGPUUtils
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRSCFTransforms
|
||||
|
||||
@@ -22,7 +22,6 @@
|
||||
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
||||
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@@ -64,7 +63,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
|
||||
case xegpu::MemorySpace::SLM:
|
||||
return static_cast<int>(xevm::AddrSpace::SHARED);
|
||||
}
|
||||
llvm_unreachable("Unknown XeGPU memory space");
|
||||
}
|
||||
|
||||
// Get same bitwidth flat vector type of new element type.
|
||||
@@ -188,7 +186,6 @@ class CreateNdDescToXeVMPattern
|
||||
int64_t rank = mixedSizes.size();
|
||||
if (rank != 2)
|
||||
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
|
||||
|
||||
auto sourceTy = source.getType();
|
||||
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
|
||||
// If source is a memref, we need to extract the aligned pointer as index.
|
||||
@@ -367,11 +364,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
|
||||
|
||||
// Add a builder that creates
|
||||
// offset * elemByteSize + baseAddr
|
||||
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value baseAddr, Value offset,
|
||||
int64_t elemByteSize) {
|
||||
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value baseAddr, Value offset, int64_t elemByteSize) {
|
||||
Value byteSize = arith::ConstantIntOp::create(
|
||||
rewriter, loc, baseAddr.getType(), elemByteSize);
|
||||
rewriter, loc, rewriter.getI64Type(), elemByteSize);
|
||||
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
|
||||
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
|
||||
return newAddr;
|
||||
@@ -447,8 +443,7 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
|
||||
// If offset is provided, we add them to the base pointer.
|
||||
// Offset is in number of elements, we need to multiply by
|
||||
// element byte size.
|
||||
basePtrI64 =
|
||||
addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
|
||||
basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
|
||||
}
|
||||
// Convert base pointer (i64) to LLVM pointer type.
|
||||
Value basePtrLLVM =
|
||||
@@ -511,147 +506,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
|
||||
}
|
||||
};
|
||||
|
||||
// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
|
||||
// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
|
||||
// 32 bits will be converted to 32 bits.
|
||||
class CreateMemDescOpPattern final
|
||||
: public OpConversionPattern<xegpu::CreateMemDescOp> {
|
||||
public:
|
||||
using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto resTy = op.getMemDesc();
|
||||
|
||||
// Create the result MemRefType with the same shape, element type, and
|
||||
// memory space
|
||||
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
|
||||
|
||||
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
|
||||
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
|
||||
op.getSource(), zero, ValueRange());
|
||||
rewriter.replaceOp(op, viewOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpType,
|
||||
typename = std::enable_if_t<llvm::is_one_of<
|
||||
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
|
||||
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
|
||||
using OpConversionPattern<OpType>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
|
||||
if (offsets.empty())
|
||||
return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto ctxt = rewriter.getContext();
|
||||
Value basePtrStruct = adaptor.getMemDesc();
|
||||
Value mdescVal = op.getMemDesc();
|
||||
// The load result or store value type can be a vector or a scalar.
|
||||
Value data;
|
||||
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
|
||||
data = op.getResult();
|
||||
else
|
||||
data = adaptor.getData();
|
||||
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
|
||||
if (!valOrResVecTy)
|
||||
valOrResVecTy = VectorType::get(1, data.getType());
|
||||
|
||||
int64_t elemBitWidth =
|
||||
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
|
||||
// Element type must be multiple of 8 bits.
|
||||
if (elemBitWidth % 8 != 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected element type bit width to be multiple of 8.");
|
||||
int64_t elemByteSize = elemBitWidth / 8;
|
||||
|
||||
// Default memory space is SLM.
|
||||
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
|
||||
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
|
||||
|
||||
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
|
||||
|
||||
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
|
||||
rewriter, loc, basePtrStruct);
|
||||
|
||||
// Convert the base pointer (an index value) to i32.
|
||||
Value basePtrI32 = arith::IndexCastUIOp::create(
|
||||
rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
|
||||
|
||||
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
|
||||
linearOffset = arith::IndexCastUIOp::create(
|
||||
rewriter, loc, rewriter.getI32Type(), linearOffset);
|
||||
basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
|
||||
elemByteSize);
|
||||
|
||||
// convert base pointer (i32) to LLVM pointer type
|
||||
basePtrLLVM =
|
||||
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
|
||||
|
||||
if (op.getSubgroupBlockIoAttr()) {
|
||||
// if the attribute 'subgroup_block_io' is set to true, it lowers to
|
||||
// xevm.blockload
|
||||
|
||||
Type intElemTy = rewriter.getIntegerType(elemBitWidth);
|
||||
VectorType intVecTy =
|
||||
VectorType::get(valOrResVecTy.getShape(), intElemTy);
|
||||
|
||||
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
|
||||
Value loadOp =
|
||||
xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
|
||||
if (intVecTy != valOrResVecTy) {
|
||||
loadOp =
|
||||
vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
|
||||
}
|
||||
rewriter.replaceOp(op, loadOp);
|
||||
} else {
|
||||
Value dataToStore = adaptor.getData();
|
||||
if (valOrResVecTy != intVecTy) {
|
||||
dataToStore =
|
||||
vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
|
||||
}
|
||||
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
|
||||
nullptr);
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (valOrResVecTy.getNumElements() >= 1) {
|
||||
auto chipOpt = xegpu::getChipStr(op);
|
||||
if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
|
||||
// the lowering for chunk load only works for pvc and bmg
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The lowering is specific to pvc or bmg.");
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
|
||||
// if the size of valOrResVecTy is 1, it lowers to a scalar load/store
|
||||
// operation. LLVM load/store does not support vector of size 1, so we
|
||||
// need to handle this case separately.
|
||||
auto scalarTy = valOrResVecTy.getElementType();
|
||||
LLVM::LoadOp loadOp;
|
||||
if (valOrResVecTy.getNumElements() == 1)
|
||||
loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
|
||||
else
|
||||
loadOp =
|
||||
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
|
||||
rewriter.replaceOp(op, loadOp);
|
||||
} else {
|
||||
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
@@ -694,8 +548,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
|
||||
op, "Expected element type bit width to be multiple of 8.");
|
||||
elemByteSize = elemBitWidth / 8;
|
||||
}
|
||||
basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
|
||||
elemByteSize);
|
||||
basePtrI64 =
|
||||
addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
|
||||
}
|
||||
}
|
||||
// Default memory space is global.
|
||||
@@ -932,13 +786,6 @@ struct ConvertXeGPUToXeVMPass
|
||||
auto i32Type = IntegerType::get(&getContext(), 32);
|
||||
return VectorType::get(8, i32Type);
|
||||
});
|
||||
// Convert MemDescType into flattened MemRefType for SLM
|
||||
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
|
||||
Type elemTy = type.getElementType();
|
||||
int numElems = type.getNumElements();
|
||||
return MemRefType::get(numElems, elemTy, AffineMap(), 3);
|
||||
});
|
||||
|
||||
typeConverter.addConversion([&](MemRefType type) -> Type {
|
||||
// Convert MemRefType to i64 type.
|
||||
return IntegerType::get(&getContext(), 64);
|
||||
@@ -1093,9 +940,6 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
|
||||
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
|
||||
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
|
||||
LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
|
||||
CreateMemDescOpPattern>(typeConverter, patterns.getContext());
|
||||
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
@@ -727,152 +727,6 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
|
||||
}
|
||||
printer << ">";
|
||||
}
|
||||
// a helper utility to perform a binary operation on OpFoldResults.
// If both a and b are attributes, it will simply return the result.
// Otherwise, the corresponding arith op will be generated, and a
// constant op will be created if one of them is an attribute.
|
||||
template <typename ArithOp>
|
||||
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
|
||||
OpBuilder &builder) {
|
||||
auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
|
||||
auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
|
||||
return builder.create<ArithOp>(loc, aVal, bVal).getResult();
|
||||
}
|
||||
|
||||
// a helper utility to perform division operation on OpFoldResult and int64_t.
|
||||
#define div(a, b) \
|
||||
genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
|
||||
|
||||
// a helper utility to perform remainder operation on OpFoldResult and int64_t.
|
||||
#define rem(a, b) \
|
||||
genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
|
||||
|
||||
// a helper utility to perform multiply operation on OpFoldResult and int64_t.
|
||||
#define mul(a, b) \
|
||||
genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
|
||||
|
||||
// a helper utility to perform addition operation on two OpFoldResult.
|
||||
#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
|
||||
|
||||
// block the given offsets according to the block shape
|
||||
// say the original offset is [y, x], and the block shape is [By, Bx],
|
||||
// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
|
||||
SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<int64_t> blockShape) {
|
||||
|
||||
assert(offsets.size() == blockShape.size() &&
|
||||
"offsets and blockShape must have the same size");
|
||||
SmallVector<OpFoldResult> blockedOffsets;
|
||||
SmallVector<OpFoldResult> divs, rems;
|
||||
|
||||
for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
|
||||
divs.push_back(div(offset, block));
|
||||
rems.push_back(rem(offset, block));
|
||||
}
|
||||
blockedOffsets.append(divs.begin(), divs.end());
|
||||
blockedOffsets.append(rems.begin(), rems.end());
|
||||
|
||||
return blockedOffsets;
|
||||
}
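// Quick check of the blocking rule above (illustrative values): offsets [19, 5] with
// blockShape [16, 16] become [19/16, 5/16, 19%16, 5%16] = [1, 0, 3, 5].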
|
||||
|
||||
// Get the strides as a vector of integers for MemDesc.
SmallVector<int64_t> MemDescType::getStrideShape() {
|
||||
|
||||
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
|
||||
|
||||
ArrayAttr strideAttr = getStrideAttr();
|
||||
SmallVector<int64_t> strides;
|
||||
for (Attribute attr : strideAttr.getValue()) {
|
||||
strides.push_back(cast<IntegerAttr>(attr).getInt());
|
||||
}
|
||||
|
||||
SmallVector<int64_t> innerBlkShape = getBlockShape();
|
||||
|
||||
// get the permutation from the fastest-changing dim (smallest stride) to the slowest-changing dim:
// perm[i] = the dim with the i-th smallest stride
|
||||
SmallVector<int, 4> perm =
|
||||
llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
|
||||
llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
|
||||
|
||||
assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
|
||||
|
||||
SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
|
||||
innerBlkStride[perm[0]] = 1;
|
||||
for (size_t i = 1; i < perm.size(); ++i)
|
||||
innerBlkStride[perm[i]] =
|
||||
innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
|
||||
|
||||
// compute the original matrix shape using the stride info
|
||||
// and compute the number of blocks in each dimension
|
||||
// The shape of the highest dim can't be derived from the stride info,
// but it doesn't affect the stride computation for the blocked layout.
|
||||
SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
|
||||
SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
|
||||
for (size_t i = 0; i < perm.size() - 1; ++i) {
|
||||
matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
|
||||
BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
|
||||
}
|
||||
|
||||
int64_t innerBlkSize = 1;
|
||||
for (auto s : innerBlkShape)
|
||||
innerBlkSize *= s;
|
||||
|
||||
SmallVector<int64_t> outerBlkStride(matrixShape.size());
|
||||
outerBlkStride[perm[0]] = innerBlkSize;
|
||||
for (size_t i = 0; i < perm.size() - 1; ++i) {
|
||||
outerBlkStride[perm[i + 1]] =
|
||||
outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
|
||||
}
|
||||
|
||||
// combine the inner and outer strides
|
||||
SmallVector<int64_t> blockedStrides;
|
||||
blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
|
||||
blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
|
||||
|
||||
return blockedStrides;
|
||||
}
|
||||
|
||||
// Calculate the linear offset using the blocked offsets and stride
|
||||
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
|
||||
ArrayRef<OpFoldResult> offsets) {
|
||||
|
||||
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
|
||||
SmallVector<int64_t> blockShape = getBlockShape();
|
||||
SmallVector<int64_t> strides = getStrideShape();
|
||||
|
||||
// blockShape equal to matrixShape means no blocking
|
||||
if (llvm::equal(blockShape, matrixShape)) {
|
||||
// remove the outer dims from strides
|
||||
strides.erase(strides.begin(), strides.begin() + matrixShape.size());
|
||||
} else {
|
||||
assert(offsets.size() == blockShape.size() &&
|
||||
"offsets and blockShape must have the same size");
|
||||
// say the original offset is [y, x], and the block shape is [By, Bx],
|
||||
// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
|
||||
SmallVector<OpFoldResult> blockedOffsets;
|
||||
SmallVector<OpFoldResult> divs, rems;
|
||||
|
||||
for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
|
||||
divs.push_back(div(offset, block));
|
||||
rems.push_back(rem(offset, block));
|
||||
}
|
||||
blockedOffsets.append(divs.begin(), divs.end());
|
||||
blockedOffsets.append(rems.begin(), rems.end());
|
||||
|
||||
offsets = blockedOffsets;
|
||||
}
|
||||
|
||||
// Start with initial value as matrix descriptor's base offset.
|
||||
Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
|
||||
for (size_t i = 0; i < offsets.size(); ++i) {
|
||||
OpFoldResult mulResult = mul(offsets[i], strides[i]);
|
||||
Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
|
||||
linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
|
||||
}
|
||||
|
||||
return linearOffset;
|
||||
}
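// Standalone sketch (hypothetical helper, not part of this patch): the same blocking and
// linearization arithmetic as getStrideShape()/getLinearOffsets() above, written with plain
// integers for a contiguous 2-D mem_desc so the math can be checked in isolation.
static int64_t linearOffset2D(int64_t row, int64_t col,   // element offsets
                              int64_t rows, int64_t cols, // mem_desc shape
                              int64_t strideRow,          // stride of dim0 (1 => col-major)
                              int64_t blockRow, int64_t blockCol) {
  bool colMajor = (strideRow == 1);
  // Inner (within-block) strides follow the mem_desc's row/col-major order.
  int64_t innerRow = colMajor ? 1 : blockCol;
  int64_t innerCol = colMajor ? blockRow : 1;
  // Outer (block) strides: each block is stored contiguously.
  int64_t blockElems = blockRow * blockCol;
  int64_t outerRow = colMajor ? blockElems : blockElems * (cols / blockCol);
  int64_t outerCol = colMajor ? blockElems * (rows / blockRow) : blockElems;
  // Blocked offsets [row/blockRow, col/blockCol, row%blockRow, col%blockCol] dotted with
  // the blocked strides [outerRow, outerCol, innerRow, innerCol].
  return (row / blockRow) * outerRow + (col / blockCol) * outerCol +
         (row % blockRow) * innerRow + (col % blockCol) * innerCol;
}
// e.g. for mem_desc<32x64xf16, @block=[16, 16]> (default row-major strides [64, 1]) and
// offsets [19, 5]: linearOffset2D(19, 5, 32, 64, 64, 16, 16) == 1*1024 + 0*256 + 3*16 + 5 == 1077.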
|
||||
|
||||
} // namespace xegpu
|
||||
} // namespace mlir
|
||||
|
||||
@@ -173,49 +173,6 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
|
||||
UnitAttr subgroup_block_io,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
|
||||
if (!dataTy) {
|
||||
if (subgroup_block_io)
|
||||
return emitError() << "subgroup_block_io "
|
||||
"are only allowed when result is a 1D VectorType.";
|
||||
else
|
||||
return success();
|
||||
}
|
||||
|
||||
if (mdescTy.getRank() != 2)
|
||||
return emitError() << "mem_desc must be 2D.";
|
||||
|
||||
ArrayRef<int64_t> dataShape = dataTy.getShape();
|
||||
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
|
||||
|
||||
if (dataShape.size() == 2) {
|
||||
if (subgroup_block_io)
|
||||
return emitError() << "subgroup_block_io "
|
||||
"are only allowed when result is a 1D VectorType.";
|
||||
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
|
||||
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
|
||||
return emitError() << "data shape must not exceed mem_desc shape.";
|
||||
} else {
|
||||
SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
|
||||
// if the subgroup_block_io attribute is set, mdescTy must have block
|
||||
// attribute
|
||||
if (subgroup_block_io && !blockShape.size())
|
||||
return emitError() << "mem_desc must have block attribute when "
|
||||
"subgroup_block_io is set.";
|
||||
// if the subgroup_block_io attribute is set, the memdesc should be row
|
||||
// major
|
||||
if (subgroup_block_io && mdescTy.isColMajor())
|
||||
return emitError() << "mem_desc should be row major when "
|
||||
"subgroup_block_io is set.";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
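// Recap of the checks above (summary only, no new behavior): scalar data is rejected only
// when subgroup_block_io is set; for vector data the mem_desc must be 2-D; 2-D vector data
// must not use subgroup_block_io and must not exceed the mem_desc shape; for 1-D vector
// data, subgroup_block_io requires the mem_desc to have a block attribute and be row-major.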
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XeGPU_CreateNdDescOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1092,20 +1049,23 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
|
||||
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
|
||||
// Call the generated builder with all parameters (including optional ones as
|
||||
// nullptr/empty)
|
||||
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
|
||||
/*subgroup_block_io=*/nullptr, layout);
|
||||
layout);
|
||||
}
|
||||
|
||||
LogicalResult LoadMatrixOp::verify() {
|
||||
|
||||
auto resTy = dyn_cast<VectorType>(getRes().getType());
|
||||
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
|
||||
VectorType resTy = getRes().getType();
|
||||
MemDescType mdescTy = getMemDesc().getType();
|
||||
|
||||
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
|
||||
[&]() { return emitError(); });
|
||||
if (mdescTy.getRank() != 2)
|
||||
return emitOpError("mem_desc must be 2D.");
|
||||
|
||||
ArrayRef<int64_t> valueShape = resTy.getShape();
|
||||
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
|
||||
if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
|
||||
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
|
||||
return emitOpError("result shape must not exceed mem_desc shape.");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1120,16 +1080,57 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
|
||||
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
|
||||
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
|
||||
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
|
||||
/*subgroup_block_io=*/nullptr, layout);
|
||||
layout);
|
||||
}
|
||||
|
||||
LogicalResult StoreMatrixOp::verify() {
|
||||
|
||||
auto dataTy = dyn_cast<VectorType>(getData().getType());
|
||||
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
|
||||
VectorType dataTy = getData().getType();
|
||||
MemDescType mdescTy = getMemDesc().getType();
|
||||
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
|
||||
[&]() { return emitError(); });
|
||||
|
||||
if (mdescTy.getRank() != 2)
|
||||
return emitOpError("mem_desc must be 2D.");
|
||||
|
||||
ArrayRef<int64_t> dataShape = dataTy.getShape();
|
||||
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
|
||||
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
|
||||
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
|
||||
return emitOpError("data shape must not exceed mem_desc shape.");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XeGPU_MemDescSubviewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
|
||||
Type resTy, Value src,
|
||||
llvm::ArrayRef<OpFoldResult> offsets) {
|
||||
llvm::SmallVector<Value> dynamicOffsets;
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
|
||||
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
|
||||
build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
|
||||
}
|
||||
|
||||
LogicalResult MemDescSubviewOp::verify() {
|
||||
MemDescType srcTy = getSrc().getType();
|
||||
MemDescType resTy = getRes().getType();
|
||||
ArrayRef<int64_t> srcShape = srcTy.getShape();
|
||||
ArrayRef<int64_t> resShape = resTy.getShape();
|
||||
|
||||
if (srcTy.getRank() < resTy.getRank())
|
||||
return emitOpError("result rank must not exceed source rank.");
|
||||
|
||||
if (llvm::any_of(
|
||||
llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
|
||||
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
|
||||
return emitOpError("result shape must not exceed source shape.");
|
||||
|
||||
if (srcTy.getStrides() != resTy.getStrides())
|
||||
return emitOpError("result must inherit the source strides.");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -941,9 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
|
||||
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
|
||||
assert(valueTy && "the value type must be vector type!");
|
||||
|
||||
VectorType valueTy = op.getType();
|
||||
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
|
||||
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
|
||||
return failure();
|
||||
@@ -986,8 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
|
||||
assert(valueTy && "the value type must be vector type!");
|
||||
VectorType valueTy = op.getData().getType();
|
||||
ArrayRef<int64_t> shape = valueTy.getShape();
|
||||
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
|
||||
|
||||
|
||||
@@ -991,8 +991,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> wgShape = op.getDataShape();
|
||||
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
|
||||
assert(valueTy && "the value type must be vector type!");
|
||||
VectorType valueTy = op.getRes().getType();
|
||||
Type elemTy = valueTy.getElementType();
|
||||
|
||||
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
|
||||
|
||||
gpu.module @test_kernel {
|
||||
#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
|
||||
#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
|
||||
#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
|
||||
|
||||
gpu.module @load_store_check {
|
||||
// CHECK-LABEL: func.func @dpas(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
|
||||
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
|
||||
// Loads are checked in a separate test.
|
||||
// CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
|
||||
// CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
|
||||
%d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded
|
||||
%d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
|
||||
: vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
|
||||
return %d : vector<8xf32>
|
||||
}
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
|
||||
|
||||
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
|
||||
|
||||
// e.g. for mem_desc<32x32xf32> with default @strides=[32, 1]
// its memory layout tuple is (blocked shape = [1,1,32,32], strides = [1024,1024,32,1])
|
||||
//CHECK-LABEL: load_store_matrix_1
|
||||
gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
|
||||
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
|
||||
|
||||
//CHECK: %[[TID:.*]] = gpu.thread_id x
|
||||
//CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
|
||||
//CHECK: %[[C4:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
|
||||
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
|
||||
|
||||
%tid_x = gpu.thread_id x
|
||||
%c0 = arith.constant 0 : index
|
||||
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
|
||||
|
||||
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
|
||||
|
||||
xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
|
||||
|
||||
gpu.return %1: f32
|
||||
}
|
||||
|
||||
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
|
||||
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
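// (Derivation, for reference: dim0 has stride 1, so the inner strides are [1, 16];
//  a 16x16 block holds 256 elements and there are 32/16 = 2 blocks along dim0,
//  so the outer strides are [256, 512]; the blocked shape is [32/16, 64/16, 16, 16].)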
|
||||
//CHECK-LABEL: load_store_matrix_2
|
||||
gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
|
||||
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
|
||||
//CHECK: %[[c0:.*]] = arith.constant 0 : index
|
||||
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
|
||||
//CHECK: %[[c13:.*]] = arith.constant 13 : index
|
||||
//CHECK: %[[c16:.*]] = arith.constant 16 : index
|
||||
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
|
||||
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
|
||||
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
|
||||
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
|
||||
|
||||
//CHECK: %[[c256:.*]] = arith.constant 256 : index
|
||||
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
|
||||
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
|
||||
//CHECK: %[[c512:.*]] = arith.constant 512 : index
|
||||
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
|
||||
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
|
||||
//CHECK: %[[c1:.*]] = arith.constant 1 : index
|
||||
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
|
||||
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
|
||||
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
|
||||
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
|
||||
|
||||
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
|
||||
|
||||
|
||||
%tid_x = gpu.thread_id x
|
||||
%c13 = arith.constant 13 : index
|
||||
%1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
|
||||
|
||||
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
|
||||
|
||||
xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
|
||||
gpu.return %1: f16
|
||||
}
|
||||
|
||||
|
||||
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
|
||||
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
|
||||
//CHECK-LABEL: load_store_matrix_3
|
||||
gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
|
||||
//CHECK: %[[c0:.*]] = arith.constant 0 : index
|
||||
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
|
||||
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
|
||||
|
||||
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
|
||||
//CHECK: %[[c19:.*]] = arith.constant 19 : index
|
||||
%tid_x = gpu.thread_id x
|
||||
%c19 = arith.constant 19: index
|
||||
|
||||
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
|
||||
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
|
||||
//CHECK: %[[c16:.*]] = arith.constant 16 : index
|
||||
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
|
||||
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
|
||||
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
|
||||
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
|
||||
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
|
||||
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
|
||||
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
|
||||
//CHECK: %[[c256:.*]] = arith.constant 256 : index
|
||||
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
|
||||
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
|
||||
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
|
||||
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
|
||||
//CHECK: %[[c1:.*]] = arith.constant 1 : index
|
||||
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
|
||||
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
|
||||
|
||||
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
|
||||
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
|
||||
|
||||
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
|
||||
xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
|
||||
|
||||
//CHECK: gpu.return %[[loaded]] : f16
|
||||
gpu.return %1: f16
|
||||
}
|
||||
|
||||
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
|
||||
//CHECK-LABEL: load_store_matrix_4
|
||||
gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
|
||||
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
|
||||
|
||||
//CHECK: %[[c0:.*]] = arith.constant 0 : index
|
||||
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
|
||||
|
||||
//CHECK: %[[c16:.*]] = arith.constant 16 : index
|
||||
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
|
||||
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
|
||||
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
|
||||
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
|
||||
|
||||
//CHECK: %[[c256:.*]] = arith.constant 256 : index
|
||||
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
|
||||
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
|
||||
//CHECK: %[[c512:.*]] = arith.constant 512 : index
|
||||
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
|
||||
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
|
||||
//CHECK: %[[c1:.*]] = arith.constant 1 : index
|
||||
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
|
||||
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
|
||||
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
|
||||
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
|
||||
|
||||
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
|
||||
|
||||
%tid_x = gpu.thread_id x
|
||||
%c16 = arith.constant 16 : index
|
||||
%1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
|
||||
|
||||
//CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
|
||||
xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
|
||||
|
||||
gpu.return %1: vector<8xf16>
|
||||
}
|
||||
|
||||
|
||||
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
|
||||
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
|
||||
//CHECK-LABEL: load_store_matrix_5
|
||||
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
|
||||
//CHECK: %[[c0:.*]] = arith.constant 0 : index
|
||||
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
|
||||
|
||||
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
|
||||
|
||||
//CHECK: %[[c16:.*]] = arith.constant 16 : index
|
||||
//CHECK: %[[c48:.*]] = arith.constant 48 : index
|
||||
|
||||
%c16 = arith.constant 16 : index
|
||||
%c48 = arith.constant 48 : index
|
||||
|
||||
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
|
||||
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
|
||||
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
|
||||
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
|
||||
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
|
||||
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
|
||||
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
|
||||
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
|
||||
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
|
||||
//CHECK: %[[c256:.*]] = arith.constant 256 : index
|
||||
//CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
|
||||
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
|
||||
//CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
|
||||
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
|
||||
//CHECK: %[[c1:.*]] = arith.constant 1 : index
|
||||
//CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
|
||||
//CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
|
||||
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
|
||||
//CHECK: %[[c2:.*]] = arith.constant 2 : i32
|
||||
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
|
||||
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
|
||||
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
|
||||
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
|
||||
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
|
||||
|
||||
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
|
||||
|
||||
//CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
|
||||
//CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
|
||||
|
||||
xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
|
||||
|
||||
gpu.return %1: vector<8xf16>
|
||||
}
|
||||
|
||||
}
|
||||
@@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>
|
||||
|
||||
// -----
|
||||
func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// expected-error@+1 {{data shape must not exceed mem_desc shape}}
|
||||
// expected-error@+1 {{result shape must not exceed mem_desc shape}}
|
||||
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
|
||||
return
|
||||
}
|
||||
@@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
|
||||
%data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
|
||||
// expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
|
||||
@@ -900,16 +892,30 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
|
||||
// expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
|
||||
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
|
||||
func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// expected-error@+1 {{result shape must not exceed source shape}}
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
|
||||
// expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
|
||||
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
|
||||
func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
|
||||
// expected-error@+1 {{result must inherit the source strides}}
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// expected-error@+1 {{failed to verify that all of {src, res} have same element type}}
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// expected-error@+1 {{result rank must not exceed source rank}}
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -825,73 +825,53 @@ gpu.func @create_mem_desc_with_stride() {
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
|
||||
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
|
||||
gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
|
||||
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
|
||||
gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
|
||||
// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
|
||||
gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
|
||||
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
|
||||
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>)
|
||||
gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
// CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
|
||||
%data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
|
||||
gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
|
||||
// CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
|
||||
%data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
|
||||
gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
|
||||
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
|
||||
%data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
|
||||
gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
|
||||
// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
|
||||
gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
|
||||
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
|
||||
xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
|
||||
gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
|
||||
// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
|
||||
gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
|
||||
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
|
||||
xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
|
||||
gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
|
||||
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16>
|
||||
xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16>
|
||||
// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
|
||||
gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
|
||||
gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
|
||||
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
|
||||
xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
|
||||
// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
|
||||
gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
|
||||
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
|
||||
gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
|
||||
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
|
||||
xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
|
||||
// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
|
||||
gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
|
||||
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
|
||||
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
|
||||