[MLIR][XeGPU][Transform] add xegpu.set_desc_layout transform op (#165615)

Adds the first XeGPU transform op, `xegpu.set_desc_layout`, which attachs a `xegpu.layout` attribute to the descriptor that a `xegpu.create_nd_tdesc` op returns.
This commit is contained in:
Tuomas Kärnä
2025-11-06 16:25:34 +02:00
committed by GitHub
parent 9d1b578a22
commit 3a68751190
16 changed files with 584 additions and 7 deletions

View File

@@ -0,0 +1,51 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
from mlir.dialects.transform import structured
def run(f):
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
f()
print(module)
return f
@run
def setDescLayoutMinimal():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.create_nd_tdesc"),
)
with InsertionPoint(sequence.body):
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
transform.YieldOp()
# CHECK-LABEL: TEST: setDescLayoutMinimal
# CHECK: %0 = transform.xegpu.set_desc_layout %
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
@run
def setDescLayoutInstData():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.create_nd_tdesc"),
)
with InsertionPoint(sequence.body):
xegpu.SetDescLayoutOp(
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
)
transform.YieldOp()
# CHECK-LABEL: TEST: setDescLayoutInstData
# CHECK: %0 = transform.xegpu.set_desc_layout %
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]