Mirror of https://github.com/intel/llvm.git (synced 2026-01-25 10:55:58 +08:00)
[MLIR][XeGPU] Retain anchor op layouts for XeGPU nD ops (#170934)
This PR adds support for retaining the anchor op layouts (after dropping the components that are no longer required) on XeGPU nD ops during the workgroup-to-subgroup and unroll transformations.
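In short, each affected rewrite now reads the layout off the original op, drops the fields the target level no longer needs, and passes the result to the builder of the new op. A minimal sketch of that idiom (surrounding names such as `rewriter`, `loc`, `newValueTy`, `tdesc`, and `offsets` are illustrative; the calls themselves appear in the diff below):

    // Sketch of the retention idiom used by the patterns in this PR.
    // `op` is the original xegpu.load_nd being rewritten.
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData(); // unroll case; wg-to-sg uses dropSgLayoutAndData()
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, loc, newValueTy, tdesc, offsets,
        op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
        op.getL2HintAttr(), op.getL3HintAttr(), layout); // layout retained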
@@ -329,7 +329,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];

   let hasVerifier = 1;
@@ -453,7 +454,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];

   let hasVerifier = 1;
@@ -564,7 +566,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
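For reference, the first of these ODS builder declarations corresponds to the C++ builder implemented further down in this diff; its extended signature is:

    void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                             Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                             xegpu::CachePolicyAttr l1_hint,
                             xegpu::CachePolicyAttr l2_hint,
                             xegpu::CachePolicyAttr l3_hint,
                             xegpu::DistributeLayoutAttr layout);

LoadNdOp and StoreNdOp gain the same trailing `layout` parameter.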
@@ -567,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                          /*layout=*/nullptr);
     rewriter.replaceOp(readOp, loadOp);

     return success();
@@ -621,7 +622,8 @@ struct TransferWriteLowering
     auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                             ndDesc, indices,
                                             /*l1_hint=*/hint,
-                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                            /*layout=*/nullptr);
     rewriter.replaceOp(writeOp, storeOp);

     return success();
@@ -725,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
         xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                 /*packed=*/nullptr, /*transpose=*/nullptr,
                                 /*l1_hint=*/hint,
-                                /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                /*layout=*/nullptr);
     rewriter.replaceOp(loadOp, loadNdOp);

     return success();
@@ -763,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto storeNdOp =
         xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                 /*layout=*/nullptr);

     rewriter.replaceOp(storeOp, storeNdOp);
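These vector-dialect lowerings pass /*layout=*/nullptr because a vector.transfer_read, vector.transfer_write, vector.load, or vector.store carries no XeGPU layout to retain. A layout can still be attached after lowering; a hypothetical follow-up (the `someLayout` value is a placeholder, not part of this PR) could use the helper that already appears later in this diff:

    // Attach a distribute layout to the lowered load's result after the fact.
    // `loadOp` is the xegpu::LoadNdOp created above; `someLayout` is assumed.
    xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), someLayout);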
@@ -472,7 +472,8 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                          Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                          xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
-                         xegpu::CachePolicyAttr l3_hint) {
+                         xegpu::CachePolicyAttr l3_hint,
+                         xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -480,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

   build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
-        l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l2_hint, l3_hint, /*anchor_layout=*/layout);
 }

 LogicalResult PrefetchNdOp::verify() {
@@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                      UnitAttr packed, DenseI64ArrayAttr transpose,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
-                     xegpu::CachePolicyAttr l3_hint) {
+                     xegpu::CachePolicyAttr l3_hint,
+                     xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -536,7 +538,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,

   build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
         packed, transpose, l1_hint, l2_hint, l3_hint,
-        /*anchor_layout=*/nullptr);
+        /*anchor_layout=*/layout);
 }

 LogicalResult LoadNdOp::verify() {
@@ -647,7 +649,8 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                       Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
-                      xegpu::CachePolicyAttr l3_hint) {
+                      xegpu::CachePolicyAttr l3_hint,
+                      xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -655,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

   build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
-        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
 }

 LogicalResult StoreNdOp::verify() {
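With the extended convenience builders, a caller supplies the anchor layout directly instead of patching it onto the op afterwards. A hypothetical call site (all values are placeholders for illustration):

    // `desc`, `hint`, and `layout` are placeholders, not names from this PR.
    xegpu::PrefetchNdOp::create(rewriter, loc, desc,
                                /*offsets=*/{rewriter.getIndexAttr(0),
                                             rewriter.getIndexAttr(0)},
                                /*l1_hint=*/hint, /*l2_hint=*/hint,
                                /*l3_hint=*/hint, layout);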
@@ -528,7 +528,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
     xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                                 newDescOp.getResult(),
                                 getPrefetchOffsets(initForOp.getInductionVar()),
-                                readCacheHint, readCacheHint, readCacheHint);
+                                readCacheHint, readCacheHint, readCacheHint,
+                                /*layout=*/nullptr);

   // Insert prefetch op in main loop.
   // Calculate prefetch offset after the init prefetches have been issued.
@@ -539,7 +540,7 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
   xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                               newDescOp.getResult(),
                               getPrefetchOffsets(prefetchOffset), readCacheHint,
-                              readCacheHint, readCacheHint);
+                              readCacheHint, readCacheHint, /*layout=*/nullptr);

   // Unroll the init loop.
   if (failed(loopUnrollFull(initForOp)))
@@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
         newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
         origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
         origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
-        origLoadOp.getL3HintAttr());
+        origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
     // Set the layout for the loadOp.
     auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
     xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     if (!targetShape)
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
       xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
       // return dummy Value to satisfy function's signature
       return nullptr;
     };
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     if (!targetShape)
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
       return xegpu::LoadNdOp::create(
           rewriter, loc, newValueTy, convertedTdescs[0], offsets,
           op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
     };
     newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                     *targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     if (!targetShape)
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
       xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                convertedTdescs[0], offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
       // return dummy Value to satisfy function's signature
       return nullptr;
     };
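The unroll patterns drop only the inst_data piece of the anchor layout: unrolling materializes the instruction-level tiling, so that field is consumed, while the lane-level description stays meaningful on each unrolled op. Illustratively (assuming dropInstData() removes exactly the inst_data field, as its name suggests):

    // Before: #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>
    // After dropInstData():
    //         #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();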
@@ -317,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     SmallVector<Value> newOps;
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -326,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       auto newOp = xegpu::LoadNdOp::create(
           rewriter, op.getLoc(), newResTy, tdesc, offsets,
           /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
       newOps.push_back(newOp);
     }
     rewriter.replaceOpWithMultiple(op, {newOps});
@@ -347,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [v, tdesc, offsets] :
          llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
       xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);

@@ -371,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();

+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
       xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);
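Symmetrically, the workgroup-to-subgroup patterns strip the subgroup-level fields and keep the rest, which is exactly what the new test below checks. For the layout used in that test:

    // Before (workgroup level, from the test below):
    //   #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
    //                 inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>
    // After dropSgLayoutAndData() (subgroup level, per the CHECK line):
    //   #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();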
@@ -633,4 +633,17 @@ gpu.module @test_distribution {
         #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_anchor_layout
+  gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
 }