[MLIR][XeGPU] Allow load/store/prefetch uses [memref+offset] instead of tdesc (#150576)
Add variants of load/store/prefetch that take an explicit offset. The new
xegpu.load variant accepts memref+offset; the existing tdesc operand will be
removed in a future PR.
The semantics are the combination of "create scattered_tdesc" followed by
"xegpu.load with scattered_tdesc". The current xegpu.load takes a tdesc
operand that encapsulates "memref+offset"; this PR folds "memref+offset"
directly into xegpu.load, replacing the tdesc. create_tdesc will then be
removed, since a scatter_tdesc holds only the base address once the offsets
are taken out of it, so there is no point in keeping it.
```mlir
// wi level code example
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
xegpu.store %val, %src[%offsets], %mask: vector<1xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
xegpu.prefetch %src[%0] : ui64, vector<1xindex>
```
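As a rough before/after sketch (hypothetical buffer, offsets, mask, and tensor_desc shapes; not taken from the patch itself), the new form folds the former two-op sequence into a single load:

```mlir
// Hypothetical setup: a 1D buffer, per-lane offsets, and an all-true mask.
%src = memref.alloc() : memref<256xf32>
%offsets = vector.step : vector<16xindex>
%mask = arith.constant dense<true> : vector<16xi1>

// Before: offsets are first baked into a scattered tensor_desc.
%tdesc = xegpu.create_tdesc %src, %offsets : memref<256xf32>, vector<16xindex>
    -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
%old = xegpu.load %tdesc, %mask
    : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1> -> vector<16x8xf32>

// After: the base memref and the offsets feed xegpu.load directly.
%new = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
    : memref<256xf32>, vector<16xindex>, vector<16xi1> -> vector<16x8xf32>
```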
@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
    As compared to prefetch_nd, which works on non-scattered TensorDesc,
    it works on scattered TensorDesc instead.

    Example:
    Example 1:
    ```mlir
      xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<cached>,
                             l3_hint = #xegpu.cache_hint<cached>}
        : !xegpu.tensor_desc<16xf16>
    ```

    Example 2:
    A variant accepts a memref as the base pointer plus an offset instead of a scattered TensorDesc.
    It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
    The source operand can also be a raw pointer (uint64_t).
    Please refer to create_tdesc for the restrictions on the memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
      xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<cached>,
                             l3_hint = #xegpu.cache_hint<cached>}
        : memref<1024xf32>, vector<4xindex>
    ```

  }];

  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
                       Optional<XeGPU_OffsetType>: $offsets,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

  let extraClassDeclaration = extraBaseClassDeclaration # [{
    Type getSourceType() {
      return getSource().getType();
    }

    TypedValue<xegpu::TensorDescType> getTensorDesc() {
      if (auto tdescType = getTensorDescType()) {
        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
      }
      return TypedValue<xegpu::TensorDescType>();
    }

    xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
      return dyn_cast<xegpu::TensorDescType>(getSourceType());
    }
  }];

  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
  let assemblyFormat = [{
    $source
    (`[` $offsets^ `]`)?
    prop-dict
    attr-dict `:` type(operands)
  }];

  let builders = [
    OpBuilder<(ins "Value": $source,
                   "xegpu::CachePolicyAttr": $l1_hint,
                   "xegpu::CachePolicyAttr": $l2_hint,
                   "xegpu::CachePolicyAttr": $l3_hint)>
  ];

  let hasVerifier = 1;
}
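With the optional `$offsets` and `type(operands)` above, one format covers both syntaxes. A minimal sketch of the two printed forms, assuming hypothetical values:

```mlir
// Scattered tensor_desc source, no offsets (existing form).
xegpu.prefetch %tdesc <{l1_hint = #xegpu.cache_hint<cached>}>
    : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>

// Plain 1D memref source with explicit per-lane offsets (new form).
xegpu.prefetch %buf[%offsets] <{l1_hint = #xegpu.cache_hint<cached>}>
    : memref<256xf16>, vector<16xindex>
```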

def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
  AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
]> {
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
  let summary = "load a set of scattered data points from memory.";

  let description = [{ It (aka. load) loads data per work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
        : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
          vector<16xi1> -> vector<16x8xf32>
  ```

  Example 3 (SIMT mode):
  ```mlir
    %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
          : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
            vector<16xi1> -> vector<8xf32>
  ```

  Example 4:
  A variant accepts a memref as the base pointer plus an offset instead of a scattered TensorDesc.
  It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
  The source operand can also be a raw pointer (uint64_t). Please refer to create_tdesc
  for the restrictions on the memref.
  ```mlir
    %a = memref.alloc() : memref<1024xf32>
    %offsets = vector.step : vector<16xindex>
    %mask = vector.constant_mask [16]: vector<16xi1>
    %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
                                           l2_hint = #xegpu.cache_hint<cached>,
                                           l3_hint = #xegpu.cache_hint<cached>}
          : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
  ```

  }];

  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
                       Optional<XeGPU_OffsetType>: $offsets,
                       XeGPU_MaskType: $mask,
                       OptionalAttr<I64Attr>: $chunk_size,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
  let results = (outs XeGPU_ValueType: $value);

  let extraClassDeclaration = extraBaseClassDeclaration # [{

    Type getSourceType() {
      return getSource().getType();
    }

    TypedValue<xegpu::TensorDescType> getTensorDesc() {
      if (auto tdescType = getTensorDescType()) {
        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
      }
      return TypedValue<xegpu::TensorDescType>();
    }

    xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
      return dyn_cast<xegpu::TensorDescType>(getSourceType());
    }

    mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [

  }];

  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
  let assemblyFormat = [{
    $source
    (`[` $offsets^ `]`)? `,`
    $mask prop-dict
    attr-dict `:` type(operands) `->` type($value)
  }];

  let builders = [
    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
                   "xegpu::CachePolicyAttr": $l1_hint,
                   "xegpu::CachePolicyAttr": $l2_hint,
                   "xegpu::CachePolicyAttr": $l3_hint)>
  ];

  let hasVerifier = 1;
}

def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
  AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
  let summary = "store data to scattered memory locations.";
  let description = [{ It (aka. store) stores data to scattered memory locations. The value is
    typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
                                 l3_hint = #xegpu.cache_hint<write_through>}>
        : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
  ```

  Example 4:
  A variant accepts a memref as the base pointer plus an offset instead of a scattered TensorDesc.
  It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
  The dest operand can also be a raw pointer (uint64_t).
  Please refer to create_tdesc for the restrictions on the memref.
  ```mlir
    %a = memref.alloc() : memref<1024xf32>
    %val = arith.constant dense<0.0> : vector<16xf32>
    %offsets = vector.step : vector<16xindex>
    %mask = vector.constant_mask [16]: vector<16xi1>
    xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
                                           l2_hint = #xegpu.cache_hint<cached>,
                                           l3_hint = #xegpu.cache_hint<cached>}
      : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
  ```

  }];

  let arguments = (ins
    XeGPU_ValueType: $value,
    XeGPU_TensorDesc: $TensorDesc,
    XeGPU_GatherScatterSourceType: $dest,
    Optional<XeGPU_OffsetType>: $offsets,
    XeGPU_MaskType: $mask,
    OptionalAttr<I64Attr>: $chunk_size,
    OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
    OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
    OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

  let extraClassDeclaration = extraBaseClassDeclaration # [{
    Type getDestType() {
      return getDest().getType();
    }

    TypedValue<xegpu::TensorDescType> getTensorDesc() {
      if (auto tdescType = getTensorDescType()) {
        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
      }
      return TypedValue<xegpu::TensorDescType>();
    }

    xegpu::TensorDescType getTensorDescType() {
      return getTensorDesc().getType();
      return dyn_cast<xegpu::TensorDescType>(getDestType());
    }

    VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
    }
  }];

  let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
      `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
  let assemblyFormat = [{
    $value `,`
    $dest
    (`[` $offsets^ `]`)? `,`
    $mask
    prop-dict
    attr-dict `:` type(operands)
  }];

  let builders = [
    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
                   "xegpu::CachePolicyAttr": $l1_hint,
                   "xegpu::CachePolicyAttr": $l2_hint,
                   "xegpu::CachePolicyAttr": $l3_hint)>
  ];

  let hasVerifier = 1;
}

@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
  let genVerifyDecl = 1;
}

def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;

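For illustration (hypothetical IR, not part of the patch), the three source kinds this constraint admits all feed the same gather-style load:

```mlir
// 1) Scattered tensor_desc (existing path).
%v0 = xegpu.load %tdesc, %mask
    : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1> -> vector<16x8xf32>

// 2) 1D memref plus per-lane offsets (new path).
%v1 = xegpu.load %buf[%offsets], %mask <{chunk_size = 8}>
    : memref<256xf32>, vector<16xindex>, vector<16xi1> -> vector<16x8xf32>

// 3) Raw ui64 base pointer plus offsets (new path).
%v2 = xegpu.load %ptr[%offsets], %mask <{chunk_size = 8}>
    : ui64, vector<16xindex>, vector<16xi1> -> vector<16x8xf32>
```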
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
  let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";

@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
  return success();
}

static LogicalResult
isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
                                 int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);

  // a valid shape for SIMT case
  if (valueTy.getRank() == 1) {
    if (valueTy.getNumElements() != chunkSize)
      return emitError() << "value elements must match chunk size " << chunkSize
                         << " for SIMT code.";
    return success();
  }

  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

  return success();
}

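A minimal sketch (hypothetical shapes) of the rule this helper enforces: a rank-1 value is treated as SIMT and its element count must equal `chunk_size`; otherwise the mask shape must equal the value shape with the trailing chunk-size dimension dropped.

```mlir
// Subgroup case: value vector<16x8xf32> with chunk_size = 8,
// so the mask must be vector<16xi1> (value shape minus the chunk dim).
%v0 = xegpu.load %buf[%offs16], %mask16 <{chunk_size = 8}>
    : memref<256xf32>, vector<16xindex>, vector<16xi1> -> vector<16x8xf32>

// SIMT case: rank-1 value whose length equals chunk_size (2 here).
%v1 = xegpu.load %buf[%off1], %mask1 <{chunk_size = 2}>
    : memref<256xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
```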
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!tdescTy && getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

@@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() {
  return success();
}

void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() {
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!tdescTy && getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

@@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() {
  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    [&]() { return emitOpError(); });
  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });
  auto srcTy = getSourceType();
  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(srcTy);

  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
                                          [&]() { return emitOpError(); });
}

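As a hypothetical illustration of the element-type check above (not one of the patch's test cases), a load whose result element type differs from the memref element type would be rejected:

```mlir
// Invalid (sketch): f16 result elements from an f32 memref trip
// "Value should have the same element type as MemRef."
%bad = xegpu.load %buf[%offsets], %mask <{chunk_size = 2}>
    : memref<256xf32>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
```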
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
@@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() {
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!tdescTy && getRankOf(getDest()) > 1)
    return emitOpError(
        "Expecting the dest is a 1D memref or pointer (uint64_t).");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

@@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() {
  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    [&]() { return emitOpError(); });
  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto destTy = getDestType();
  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
                                          [&]() { return emitOpError(); });
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//

@@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
    // TODO: handle the unstructured source case (!tdescTy)
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
    // TODO: handle the unstructured source case (!tdescTy)
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
    // TODO: handle the unstructured source case (!tdescTy)
    if (!tdescTy || op.getOffsets())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

@@ -384,6 +384,74 @@ func.func @load_gather_vc_3(%src: ui64) {
  return
}

// -----
func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
  xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
  return
}

// -----
func.func @load_gather_offset_sg(%src: memref<?xf16>) {
  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %mask = arith.constant dense<1>: vector<8xi1>
  // expected-error@+1 {{Mask should match value except the chunk size dim}}
  %2 = xegpu.load %src[%offsets], %mask
       : memref<?xf16>, vector<4xindex>, vector<8xi1>
       -> vector<4x2xf16>
  return
}

// -----
func.func @load_gather_offset_wi(%src: ui64) {
  %mask = arith.constant dense<1>: vector<1xi1>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  // expected-error@+1 {{value elements must match chunk size}}
  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
  return
}

// -----
func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
  %val = arith.constant dense<2.9>: vector<4xf16>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  %mask = arith.constant dense<1>: vector<1xi1>
  // expected-error@+1 {{value elements must match chunk size}}
  xegpu.store %val, %src[%offsets], %mask
       : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
  return
}

// -----
func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
  %val = arith.constant dense<2.9>: vector<4xf16>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  %mask = arith.constant dense<1>: vector<1xi1>
  // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
  xegpu.store %val, %src[%offsets], %mask
       : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
  return
}

// -----
func.func @load_gather_offset_wi_2(%src: ui64) {
  %mask = arith.constant dense<1>: vector<1xi1>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  // expected-error@+1 {{value elements must match chunk size}}
  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
  return
}

// -----
func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
  %mask = arith.constant dense<1>: vector<1xi1>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
  return
}

// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
  %0 = arith.constant dense<1>: vector<4xi1>

@@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
  gpu.return
}

// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %mask = arith.constant dense<1>: vector<4xi1>
  //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
  %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
      : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
  gpu.return
}

// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
  gpu.return
}

// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
  %val = arith.constant dense<2.9>: vector<4x2xf16>
  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %mask = arith.constant dense<1>: vector<4xi1>
  //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
  xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
      : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
  gpu.return
}

// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) {
  gpu.return
}

// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
gpu.func @prefetch_offset(%src: ui64) {
  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
  xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
  gpu.return
}

// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_update_tdesc(%src: ui64) {