[MLIR][XeGPU] Retain anchor op layouts for XeGPU nD ops (#170934)

This PR adds support for retaining the anchor op layouts (after dropping
the fields that are no longer required) for XeGPU nD ops during the
workgroup-to-subgroup and unroll transformations.
Author: Nishant Patel
Date: 2025-12-05 21:49:13 -08:00
Committed by: GitHub
Parent: 8fe38c4c9c
Commit: 5fc8e87fe2
8 changed files with 64 additions and 22 deletions
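
At a high level, every rewrite pattern touched here now reads the anchor layout off the original op, drops the fields that no longer apply after the transformation, and forwards the remainder when recreating the op. A minimal sketch of that pattern, using names from this diff (the trailing builder argument is the new one; `tdesc`, `offsets`, and the hint values are placeholders):

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData(); // wg-to-sg; unroll patterns use dropInstData()
    // ... then recreate the op, forwarding the trimmed layout, e.g.:
    // xegpu::PrefetchNdOp::create(rewriter, loc, tdesc, offsets,
    //                             l1Hint, l2Hint, l3Hint, layout);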


@@ -329,7 +329,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
   let hasVerifier = 1;
@@ -453,7 +454,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
   let hasVerifier = 1;
@@ -564,7 +566,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
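
Each of these builder declarations gains a trailing `$layout` parameter of type `xegpu::DistributeLayoutAttr`, and the call sites below are updated to match. A usage sketch mirroring the VectorToXeGPU call sites (here `layoutAttr` is a placeholder; passing `nullptr` preserves the old behavior):

    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                          /*packed=*/nullptr, /*transpose=*/nullptr,
                                          /*l1_hint=*/hint, /*l2_hint=*/hint,
                                          /*l3_hint=*/hint, /*layout=*/layoutAttr);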


@@ -567,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                          /*layout=*/nullptr);
     rewriter.replaceOp(readOp, loadOp);
     return success();
@@ -621,7 +622,8 @@ struct TransferWriteLowering
     auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                             ndDesc, indices,
                                             /*l1_hint=*/hint,
-                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                            /*layout=*/nullptr);
     rewriter.replaceOp(writeOp, storeOp);
     return success();
@@ -725,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
         xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                 /*packed=*/nullptr, /*transpose=*/nullptr,
                                 /*l1_hint=*/hint,
-                                /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                /*layout=*/nullptr);
     rewriter.replaceOp(loadOp, loadNdOp);
     return success();
@@ -763,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto storeNdOp =
         xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                 /*layout=*/nullptr);
     rewriter.replaceOp(storeOp, storeNdOp);

@@ -472,7 +472,8 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                          Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                          xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
-                         xegpu::CachePolicyAttr l3_hint) {
+                         xegpu::CachePolicyAttr l3_hint,
+                         xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -480,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
   build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
-        l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l2_hint, l3_hint, /*anchor_layout=*/layout);
 }
 LogicalResult PrefetchNdOp::verify() {
@@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                      UnitAttr packed, DenseI64ArrayAttr transpose,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
-                     xegpu::CachePolicyAttr l3_hint) {
+                     xegpu::CachePolicyAttr l3_hint,
+                     xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -536,7 +538,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
   build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
         packed, transpose, l1_hint, l2_hint, l3_hint,
-        /*anchor_layout=*/nullptr);
+        /*anchor_layout=*/layout);
 }
 LogicalResult LoadNdOp::verify() {
@@ -647,7 +649,8 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                       Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
-                      xegpu::CachePolicyAttr l3_hint) {
+                      xegpu::CachePolicyAttr l3_hint,
+                      xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -655,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
   build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
-        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
 }
 LogicalResult StoreNdOp::verify() {
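
All three convenience builders share the same offset plumbing before delegating to the generated builder; the only behavioral change is forwarding `layout` instead of a hardcoded `nullptr` for `anchor_layout`. For reference, a commented sketch of that shared offset handling (`dispatchIndexOpFoldResults` is the upstream MLIR helper):

    SmallVector<Value> dynamicOffsets;
    SmallVector<int64_t> staticOffsets;
    // Each OpFoldResult is split: an IntegerAttr becomes a static offset, while
    // a Value goes to dynamicOffsets and leaves a ShapedType::kDynamic sentinel
    // in the static array, so the two arrays stay index-aligned.
    dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
    auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);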


@@ -528,7 +528,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
     xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                                 newDescOp.getResult(),
                                 getPrefetchOffsets(initForOp.getInductionVar()),
-                                readCacheHint, readCacheHint, readCacheHint);
+                                readCacheHint, readCacheHint, readCacheHint,
+                                /*layout=*/nullptr);
   // Insert prefetch op in main loop.
   // Calculate prefetch offset after the init prefetches have been issued.
@@ -539,7 +540,7 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
     xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                                 newDescOp.getResult(),
                                 getPrefetchOffsets(prefetchOffset), readCacheHint,
-                                readCacheHint, readCacheHint);
+                                readCacheHint, readCacheHint, /*layout=*/nullptr);
   // Unroll the init loop.
   if (failed(loopUnrollFull(initForOp)))


@@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
         newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
         origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
         origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
-        origLoadOp.getL3HintAttr());
+        origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
     // Set the layout for the loadOp.
     auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
     xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);


@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     if (!targetShape)
       return failure();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
       xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
       // return dummy Value to satisfy function's signature
       return nullptr;
     };
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     if (!targetShape)
       return failure();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
       return xegpu::LoadNdOp::create(
           rewriter, loc, newValueTy, convertedTdescs[0], offsets,
           op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
     };
     newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                     *targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     if (!targetShape)
       return failure();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
       xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                convertedTdescs[0], offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
       // return dummy Value to satisfy function's signature
       return nullptr;
     };
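
Each unroll pattern trims the anchor layout with dropInstData() before recreating the tiled ops, since inst_data describes exactly the granularity that unrolling materializes. An illustrative before/after (the layout values here are assumed, not taken from this diff):

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropInstData();
    // e.g. #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>
    // would become #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>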


@@ -317,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     SmallVector<Value> newOps;
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -326,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       auto newOp = xegpu::LoadNdOp::create(
           rewriter, op.getLoc(), newResTy, tdesc, offsets,
           /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
       newOps.push_back(newOp);
     }
     rewriter.replaceOpWithMultiple(op, {newOps});
@@ -347,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [v, tdesc, offsets] :
          llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
       xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);
@@ -371,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
       xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);
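
Symmetrically, the workgroup-to-subgroup patterns call dropSgLayoutAndData(), since sg_layout and sg_data are consumed by the distribution itself, while inst_data and the lane fields remain relevant for later stages. In attribute terms, using the layout from the test below:

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    if (layout)
      layout = layout.dropSgLayoutAndData();
    // #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16],
    //               lane_layout = [1, 16], lane_data = [1, 1]>
    // becomes
    // #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>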


@@ -633,4 +633,17 @@ gpu.module @test_distribution {
         #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_anchor_layout
+  gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
 }