[MLIR][XeGPU] Allow load/store/prefetch uses [memref+offset] instead of tdesc (#150576)

Add variants of load/store/prefetch that take offsets. The new xegpu.load
variant accepts memref+offset; the existing tdesc operand will be
removed in a future PR.

The semantics are the combination of "create scattered_tdesc" + "xegpu.load
with scattered_tdesc". The current xegpu.load accepts a tdesc operand,
which encapsulates "memref+offset". This PR folds "memref+offset"
directly into xegpu.load, replacing "tdesc". create_tdesc will be removed
later as well: once the offsets are taken away, a scatter_tdesc only carries
the base address, so there is no point in keeping it.

```mlir 
    // work-item (SIMT) level code examples
    %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
    xegpu.store %val, %src[%offsets], %mask: vector<1xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
    xegpu.prefetch %src[%0] : ui64, vector<1xindex>
```
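To make the folding concrete, a rough before/after sketch at subgroup level (the create_tdesc line follows that op's documented style and is illustrative; the folded form mirrors the tests added later in this commit):

```mlir
// Before: materialize a scattered TensorDesc, then gather through it.
%tdesc = xegpu.create_tdesc %src, %offsets : memref<?xf16>, vector<4xindex>
    -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
%v0 = xegpu.load %tdesc, %mask
    : !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf16>

// After: fold the base memref and offsets directly into the load.
%v1 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
    : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
```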
Jianhui Li
2025-07-30 16:00:40 -07:00
committed by GitHub
parent b9a627e6fb
commit e6f360b0ab
6 changed files with 338 additions and 26 deletions


@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.
Example:
Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```
Example 2:
A variant accepts a memref as the base pointer plus an offset vector instead of a scattered TensorDesc.
It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
The source operand may also be a raw pointer (uint64_t).
Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<4xindex>
```
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
return getSource().getType();
}
TypedValue<xegpu::TensorDescType> getTensorDesc() {
if (auto tdescType = getTensorDescType()) {
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
}
return TypedValue<xegpu::TensorDescType>();
}
xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
}];
let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)?
prop-dict
attr-dict `:` type(operands)
}];
let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
let hasVerifier = 1;
}
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
]> {
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a set of scattered data points from memory.";
let description = [{ It (aka. load) loads data for each work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```
Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```
Example 4:
A variant accepts a memref as the base pointer plus an offset vector instead of a scattered TensorDesc.
It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
The source operand may also be a raw pointer (uint64_t). Please refer to create_tdesc
for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16]: vector<16xi1>
%val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
```
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let results = (outs XeGPU_ValueType: $value);
let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
return getSource().getType();
}
TypedValue<xegpu::TensorDescType> getTensorDesc() {
if (auto tdescType = getTensorDescType()) {
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
}
return TypedValue<xegpu::TensorDescType>();
}
xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
let hasVerifier = 1;
}
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
```
Example 4:
A variant accepts a memref as the base pointer plus an offset vector instead of a scattered TensorDesc.
It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
The dest operand may also be a raw pointer (uint64_t).
Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16]: vector<16xi1>
xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
```
}];
let arguments = (ins
XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
XeGPU_GatherScatterSourceType: $dest,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getDestType() {
return getDest().getType();
}
TypedValue<xegpu::TensorDescType> getTensorDesc() {
if (auto tdescType = getTensorDescType()) {
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
}
return TypedValue<xegpu::TensorDescType>();
}
xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getDestType());
}
VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
}
}];
let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
let assemblyFormat = [{
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
$mask
prop-dict
attr-dict `:` type(operands)
}];
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
let hasVerifier = 1;
}
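For symmetry with the load example above, the folded prefetch and store forms look like this at subgroup level (names, hints, and shapes are illustrative and mirror the tests added later in this commit):

```mlir
// Prefetch through a raw pointer: the offsets select the gather locations.
xegpu.prefetch %ptr[%offsets] <{l1_hint = #xegpu.cache_hint<cached>}>
    : ui64, vector<4xindex>

// Scatter a value back through the same base + offsets; with chunk_size = 2
// each lane writes 2 contiguous elements, so the value carries a trailing
// chunk dimension that the mask does not.
xegpu.store %val, %dst[%offsets], %mask <{chunk_size = 2}>
    : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
```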


@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";


@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
static LogicalResult
isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
if (!valueTy)
return emitError() << "Expecting a vector type result.";
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
// a valid shape for SIMT case
if (valueTy.getRank() == 1) {
if (valueTy.getNumElements() != chunkSize)
return emitError() << "value elements must match chunk size " << chunkSize
<< " for SIMT code.";
return success();
}
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
return emitError() << "Mask should match value except the chunk size dim.";
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
if (!tdescTy.isScattered())
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (!tdescTy && getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
if (!tdescTy && getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
if (tdescTy)
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
auto srcTy = getSourceType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
if (memTy && (valueTy.getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Type valueType, Value source, Value mask,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
l1_hint, l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
@@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (!tdescTy && getRankOf(getDest()) > 1)
return emitOpError(
"Expecting the dest is a 1D memref or pointer (uint64_t).");
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
if (tdescTy)
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
auto destTy = getDestType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
if (memTy && (valueTy.getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
Value value, Value dest, Value mask,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
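To make the new buffer-path shape rule concrete, here is a hypothetical pair of forms that the checks above accept (shapes invented for illustration, not taken from this patch's tests):

```mlir
// Subgroup level: the value is rank-2 [4, 8] with chunk_size = 8, so the
// expected mask shape drops the trailing chunk dim -> vector<4xi1>.
%v = xegpu.load %src[%offsets], %mask <{chunk_size = 8}>
    : memref<?xf32>, vector<4xindex>, vector<4xi1> -> vector<4x8xf32>

// SIMT level: the value is rank-1, so its element count must equal chunk_size.
%w = xegpu.load %src[%off], %m <{chunk_size = 8}>
    : memref<?xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
```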


@@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
if (!tdescTy.isScattered())
// TODO: handle the unstructured source case (!tdescTy)
if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
if (!tdescTy.isScattered())
// TODO: handle the unstructured source case (!tdescTy)
if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
if (!tdescTy.isScattered())
// TODO: handle the unstructured source case (!tdescTy)
if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);


@@ -384,6 +384,74 @@ func.func @load_gather_vc_3(%src: ui64) {
return
}
// -----
func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error@+1 {{Expecting the source to be a 1D memref or pointer}}
xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
return
}
// -----
func.func @load_gather_offset_sg(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%mask = arith.constant dense<1>: vector<8xi1>
// expected-error@+1 {{Mask should match value except the chunk size dim}}
%2 = xegpu.load %src[%offsets], %mask
: memref<?xf16>, vector<4xindex>, vector<8xi1>
-> vector<4x2xf16>
return
}
// -----
func.func @load_gather_offset_wi(%src: ui64) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error@+1 {{value elements must match chunk size}}
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
return
}
// -----
func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{value elements must match chunk size}}
xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
// -----
func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{Expecting the dest to be a 1D memref or pointer}}
xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
return
}
// -----
func.func @load_gather_offset_wi_2(%src: ui64) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error@+1 {{value elements must match chunk size}}
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
return
}
// -----
func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error@+1 {{Expecting the source to be a 1D memref or pointer}}
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
return
}
// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
%0 = arith.constant dense<1>: vector<4xi1>


@@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
gpu.return
}
// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
%offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%mask = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
%val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
: memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
gpu.return
}
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
gpu.return
}
// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
%val = arith.constant dense<2.9>: vector<4x2xf16>
%offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%mask = arith.constant dense<1>: vector<4xi1>
//CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
: vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
gpu.return
}
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) {
gpu.return
}
// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
gpu.func @prefetch_offset(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
gpu.return
}
// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_update_tdesc(%src: ui64) {