[MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)

`set_op_layout_attr` and `set_desc_layout` transform ops wrap
`xegpu.layout` in an `xegpu.slice` attribute if `slice_dims` argument is
set.
This commit is contained in:
Tuomas Kärnä
2025-11-21 10:08:12 +02:00
committed by GitHub
parent 4c81b92e60
commit e9fc393a9e
5 changed files with 139 additions and 15 deletions

View File

@@ -66,6 +66,25 @@ def setDescLayoutInstData():
# CHECK: inst_data = [8, 16]
@run
def setDescLayoutSlice():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.create_nd_tdesc"),
)
with InsertionPoint(sequence.body):
xegpu.set_desc_layout(
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
)
transform.YieldOp()
# CHECK-LABEL: TEST: setDescLayoutSlice
# CHECK: %0 = transform.xegpu.set_desc_layout %
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: slice_dims = [0]
@run
def setOpLayoutAttrOperandMinimal():
sequence = transform.SequenceOp(
@@ -106,7 +125,7 @@ def setOpLayoutAttrResult():
result=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttr
# CHECK-LABEL: TEST: setOpLayoutAttrResult
# CHECK: transform.xegpu.set_op_layout_attr %
# NO-CHECK: index = 0
# CHECK: result
@@ -115,6 +134,34 @@ def setOpLayoutAttrResult():
# CHECK: inst_data = [8, 16]
@run
def setOpLayoutAttrResultSlice():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.dpas"),
)
with InsertionPoint(sequence.body):
xegpu.set_op_layout_attr(
sequence.bodyTarget,
index=0,
sg_layout=[6, 4],
sg_data=[32, 16],
inst_data=[8, 16],
slice_dims=[0],
result=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
# CHECK: transform.xegpu.set_op_layout_attr %
# NO-CHECK: index = 0
# CHECK: result
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
# CHECK: slice_dims = [0]
@run
def setGPULaunchThreadsOp():
sequence = transform.SequenceOp(