Files
llvm/mlir/test/python/dialects/transform_xegpu_ext.py
Tuomas Kärnä e9fc393a9e [MLIR][XeGPU][TransformOps] Add slice_dims argument to set_op_layout_attr and set_desc_layout (#168929)
`set_op_layout_attr` and `set_desc_layout` transform ops wrap
`xegpu.layout` in an `xegpu.slice` attribute if `slice_dims` argument is
set.
2025-11-21 10:08:12 +02:00

297 lines
8.8 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
from mlir.dialects.transform import structured, AnyValueType
def run(f):
    """Run the test builder ``f`` and print the module it populates.

    A fresh ``Context`` and an unknown ``Location`` are active for the whole
    call. ``f`` is invoked with the insertion point set to the body of a newly
    created module, after a ``TEST: <function name>`` banner has been printed.
    The module is printed afterwards so its textual form can be matched by
    the check directives that follow each test.

    Returns ``f`` unchanged, so that using ``run`` as a decorator leaves the
    decorated name bound to the original function.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # The banner is printed before the test runs so that it precedes
            # the module dump in the output.
            print("\nTEST:", f.__name__)
            f()
        print(module)
    return f
@run
def getDescOpDefaultIndex():
    """Build ``xegpu.get_desc_op`` on operand 0 of an ``xegpu.dpas`` target."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        # The returned handle is not needed; only the printed op is checked.
        xegpu.get_desc_op(operand)
        transform.YieldOp()
    # CHECK-LABEL: TEST: getDescOpDefaultIndex
    # CHECK: transform.xegpu.get_desc_op %
@run
def setDescLayoutMinimal():
    """Build ``xegpu.set_desc_layout`` with only ``sg_layout`` and ``sg_data``."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.create_nd_tdesc"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
        transform.YieldOp()
    # CHECK-LABEL: TEST: setDescLayoutMinimal
    # CHECK: %0 = transform.xegpu.set_desc_layout %
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
@run
def setDescLayoutInstData():
    """Build ``xegpu.set_desc_layout`` with ``inst_data`` in addition to the
    subgroup layout and data."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.create_nd_tdesc"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_desc_layout(
            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: setDescLayoutInstData
    # CHECK: %0 = transform.xegpu.set_desc_layout %
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
@run
def setDescLayoutSlice():
    """Build ``xegpu.set_desc_layout`` with the ``slice_dims`` argument set."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.create_nd_tdesc"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_desc_layout(
            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: setDescLayoutSlice
    # CHECK: %0 = transform.xegpu.set_desc_layout %
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: slice_dims = [0]
@run
def setOpLayoutAttrOperandMinimal():
    """Build ``xegpu.set_op_layout_attr`` with only ``sg_layout`` and
    ``sg_data``; ``index``, ``result`` and ``inst_data`` are left unset."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_op_layout_attr(
            sequence.bodyTarget,
            sg_layout=[6, 4],
            sg_data=[32, 16],
        )
        transform.YieldOp()
    # The label below spells out the full test name, and the negative
    # expectations use the NOT directive so that they are actually enforced.
    # CHECK-LABEL: TEST: setOpLayoutAttrOperandMinimal
    # CHECK: transform.xegpu.set_op_layout_attr %
    # CHECK-NOT: index = 0
    # CHECK-NOT: result
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK-NOT: inst_data
@run
def setOpLayoutAttrResult():
    """Build ``xegpu.set_op_layout_attr`` with ``result=True``, ``index=0``
    and ``inst_data`` set."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_op_layout_attr(
            sequence.bodyTarget,
            index=0,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
            result=True,
        )
        transform.YieldOp()
    # The negative expectation uses the NOT directive so that it is actually
    # enforced.
    # CHECK-LABEL: TEST: setOpLayoutAttrResult
    # CHECK: transform.xegpu.set_op_layout_attr %
    # CHECK-NOT: index = 0
    # CHECK: result
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
@run
def setOpLayoutAttrResultSlice():
    """Build ``xegpu.set_op_layout_attr`` with ``result=True`` and the
    ``slice_dims`` argument set."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_op_layout_attr(
            sequence.bodyTarget,
            index=0,
            sg_layout=[6, 4],
            sg_data=[32, 16],
            inst_data=[8, 16],
            slice_dims=[0],
            result=True,
        )
        transform.YieldOp()
    # The negative expectation uses the NOT directive so that it is actually
    # enforced.
    # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
    # CHECK: transform.xegpu.set_op_layout_attr %
    # CHECK-NOT: index = 0
    # CHECK: result
    # CHECK: sg_layout = [6, 4]
    # CHECK: sg_data = [32, 16]
    # CHECK: inst_data = [8, 16]
    # CHECK: slice_dims = [0]
@run
def setGPULaunchThreadsOp():
    """Build ``xegpu.set_gpu_launch_threads`` on a ``gpu.launch`` target."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("gpu.launch"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1])
        transform.YieldOp()
    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
    # CHECK: transform.xegpu.set_gpu_launch_threads
    # CHECK: threads = [8, 4, 1]
@run
def insertPrefetch0():
    """Build ``xegpu.insert_prefetch`` on operand 0 of an ``xegpu.dpas``
    target without passing ``nb_prefetch``."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        xegpu.insert_prefetch(
            operand,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: insertPrefetch0
    # CHECK: %[[OPR:.*]] = get_operand
    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
@run
def insertPrefetchNbPrefetch():
    """Build ``xegpu.insert_prefetch`` with ``nb_prefetch`` given as a Python
    integer."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        xegpu.insert_prefetch(
            operand,
            nb_prefetch=2,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
    # CHECK: %[[OPR:.*]] = get_operand
    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
    # CHECK-SAME: nb_prefetch = 2
@run
def insertPrefetchNbPrefetchParam():
    """Build ``xegpu.insert_prefetch`` with ``nb_prefetch`` given as the
    result of a ``transform.param.constant`` op instead of a Python integer."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        # Signless 32-bit integer constant 2, wrapped in a transform param type.
        int32_t = IntegerType.get_signless(32)
        param_int32_t = transform.ParamType.get(int32_t)
        nb_param = transform.ParamConstantOp(
            param_int32_t,
            IntegerAttr.get(int32_t, 2),
        )
        xegpu.insert_prefetch(
            operand,
            nb_prefetch=nb_param,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
    # CHECK: %[[OPR:.*]] = get_operand
    # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
    # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
@run
def ConvertLayoutMinimal():
    """Build ``xegpu.convert_layout`` on operand 0 of an ``xegpu.dpas`` target
    with only subgroup layout and data for input and target."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        xegpu.convert_layout(
            operand,
            input_sg_layout=[6, 4],
            input_sg_data=[32, 16],
            target_sg_layout=[6, 4],
            target_sg_data=[8, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: ConvertLayoutMinimal
    # CHECK: transform.xegpu.convert_layout %
    # CHECK: input_sg_layout = [6, 4]
    # CHECK: input_sg_data = [32, 16]
    # CHECK: target_sg_layout = [6, 4]
    # CHECK: target_sg_data = [8, 16]
@run
def ConvertLayout():
    """Build ``xegpu.convert_layout`` on operand 1 of an ``xegpu.dpas`` target
    with ``inst_data`` given for both the input and the target layout."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
        xegpu.convert_layout(
            operand,
            input_sg_layout=[6, 4],
            input_sg_data=[32, 32],
            input_inst_data=[32, 16],
            target_sg_layout=[6, 4],
            target_sg_data=[32, 32],
            target_inst_data=[8, 16],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: ConvertLayout
    # CHECK: transform.xegpu.convert_layout %
    # CHECK: input_sg_layout = [6, 4]
    # CHECK: input_sg_data = [32, 32]
    # CHECK: input_inst_data = [32, 16]
    # CHECK: target_sg_layout = [6, 4]
    # CHECK: target_sg_data = [32, 32]
    # CHECK: target_inst_data = [8, 16]