mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 20:53:29 +08:00
[MLIR][XeGPU][TransformOps] Add convert_layout op (#167342)
Adds `transform.xegpu.convert_layout` transform op that inserts an `xegpu.convert_layout` op for a given `Value`.
This commit is contained in:
@@ -244,4 +244,90 @@ def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
|
||||
}];
|
||||
}
|
||||
|
||||
// Transform op that inserts an `xegpu.convert_layout` op for a payload value.
//
// Each layout component (`sg_layout`, `sg_data`, `inst_data`, for both the
// input and the target layout) is declared as a pair: a variadic list of
// dynamic operands (`$<name>`) and a `DenseI64ArrayAttr` (`$static_<name>`).
// The assembly format prints and parses each pair through
// `custom<DynamicIndexList>`.
def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
  AttrSizedOperandSegments,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
  TransformOpInterface
]> {

  let summary = "Convert xegpu.layout attribute for a value.";
  let description = [{
    Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
    of a value. The input and target layouts are defined by the `*sg_layout`,
    `*sg_data` and optional `*inst_data` attributes. Returns a handle to the
    emitted `xegpu.convert_layout` op.
  }];

  // Operand order matters: it defines the operand segments and the positional
  // order used by the generated builders.
  let arguments = (ins
    TransformValueHandleTypeInterface:$target,
    Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_layout,
    Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_data,
    Variadic<TransformAnyParamTypeOrAnyHandle>:$input_inst_data,
    Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_layout,
    Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_data,
    Variadic<TransformAnyParamTypeOrAnyHandle>:$target_inst_data,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_layout,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_data,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_inst_data,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_layout,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_data,
    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data
  );

  let results = (outs TransformHandleTypeInterface:$newConvertOp);

  // Convenience builder taking each layout component as a mixed list of
  // static and dynamic entries.
  let builders = [
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<OpFoldResult>":$mixedInputSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedInputSgData,
                   "ArrayRef<OpFoldResult>":$mixedInputInstData,
                   "ArrayRef<OpFoldResult>":$mixedTargetSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedTargetSgData,
                   "ArrayRef<OpFoldResult>":$mixedTargetInstData
                   )>,
  ];

  // The two `*_inst_data` groups are optional in the textual form.
  let assemblyFormat = [{
    $target
    `input_sg_layout` `=` custom<DynamicIndexList>($input_sg_layout, $static_input_sg_layout)
    `input_sg_data` `=` custom<DynamicIndexList>($input_sg_data, $static_input_sg_data)
    (`input_inst_data` `=` custom<DynamicIndexList>($input_inst_data, $static_input_inst_data)^)?
    `target_sg_layout` `=` custom<DynamicIndexList>($target_sg_layout, $static_target_sg_layout)
    `target_sg_data` `=` custom<DynamicIndexList>($target_sg_data, $static_target_sg_data)
    (`target_inst_data` `=` custom<DynamicIndexList>($target_inst_data, $static_target_inst_data)^)?
    attr-dict `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Accessors that merge the static attribute and the dynamic operands of
    // each input-layout component into one list of OpFoldResults.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgLayout() {
      Builder builder(getContext());
      return getMixedValues(getStaticInputSgLayout(), getInputSgLayout(),
                            builder);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgData() {
      Builder builder(getContext());
      return getMixedValues(getStaticInputSgData(), getInputSgData(),
                            builder);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputInstData() {
      Builder builder(getContext());
      return getMixedValues(getStaticInputInstData(), getInputInstData(),
                            builder);
    }

    // Same accessors for the target-layout components.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgLayout() {
      Builder builder(getContext());
      return getMixedValues(getStaticTargetSgLayout(), getTargetSgLayout(),
                            builder);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgData() {
      Builder builder(getContext());
      return getMixedValues(getStaticTargetSgData(), getTargetSgData(),
                            builder);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetInstData() {
      Builder builder(getContext());
      return getMixedValues(getStaticTargetInstData(), getTargetInstData(),
                            builder);
    }
  }];
}
|
||||
|
||||
#endif // XEGPU_TRANSFORM_OPS
|
||||
|
||||
@@ -537,6 +537,110 @@ void transform::InsertPrefetchOp::getEffects(
|
||||
modifiesPayload(effects);
|
||||
}
|
||||
|
||||
/// Builds a `transform.xegpu.convert_layout` op from mixed static/dynamic
/// layout components.
///
/// Every `mixed*` list is split into its dynamic values (forwarded as
/// operands) and its static integers (forwarded as the matching `static_*`
/// attribute) before delegating to the generated builder.
///
/// The result type is `!transform.any_op`: the op result is declared as
/// `TransformHandleTypeInterface` (an operation handle), whereas `target` is a
/// value handle, so the target's type cannot be reused for the result.
void transform::ConvertLayoutOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedInputSgLayout,
    ArrayRef<OpFoldResult> mixedInputSgData,
    ArrayRef<OpFoldResult> mixedInputInstData,
    ArrayRef<OpFoldResult> mixedTargetSgLayout,
    ArrayRef<OpFoldResult> mixedTargetSgData,
    ArrayRef<OpFoldResult> mixedTargetInstData) {
  // Input layout components.
  SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
      staticInputInstData;
  SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
      dynamicInputInstData;
  dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
                             staticInputSgLayout);
  dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
                             staticInputSgData);
  dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
                             staticInputInstData);

  // Target layout components.
  SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
      staticTargetInstData;
  SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
      dynamicTargetInstData;
  dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
                             staticTargetSgLayout);
  dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
                             staticTargetSgData);
  dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
                             staticTargetInstData);

  build(builder, ostate,
        /*newConvertOp=*/transform::AnyOpType::get(builder.getContext()),
        /*target=*/target,
        /*input_sg_layout=*/dynamicInputSgLayout,
        /*input_sg_data=*/dynamicInputSgData,
        /*input_inst_data=*/dynamicInputInstData,
        /*target_sg_layout=*/dynamicTargetSgLayout,
        /*target_sg_data=*/dynamicTargetSgData,
        /*target_inst_data=*/dynamicTargetInstData,
        /*static_input_sg_layout=*/staticInputSgLayout,
        /*static_input_sg_data=*/staticInputSgData,
        /*static_input_inst_data=*/staticInputInstData,
        /*static_target_sg_layout=*/staticTargetSgLayout,
        /*static_target_sg_data=*/staticTargetSgData,
        /*static_target_inst_data=*/staticTargetInstData);
}
|
||||
|
||||
/// Inserts an `xegpu.convert_layout` op for the single payload value bound to
/// the target handle and redirects the value's uses to the converted result.
///
/// Fails definitively when the handle does not map to exactly one payload
/// value, returns any failure reported by `getLayoutAttrFromOperands`, and
/// fails silenceably when the value has no users. On success the result handle
/// is set to the newly created `xegpu.convert_layout` op.
DiagnosedSilenceableFailure
transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto targetValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(targetValues))
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(targetValues) << ")";
  auto value = *targetValues.begin();

  // Construct layout attributes.
  xegpu::LayoutAttr inputLayoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedInputSgLayout(),
      getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
  if (!status.succeeded())
    return status;

  xegpu::LayoutAttr targetLayoutAttr = nullptr;
  status = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedTargetSgLayout(),
      getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
  if (!status.succeeded())
    return status;

  if (value.use_empty())
    return emitSilenceableFailure(getLoc())
           << "Value has no users to insert layout conversion.";

  // Pick an insertion point that precedes every user. The use list is not
  // kept in program order, so its head is not necessarily the earliest user;
  // inserting before a later user while rewriting all uses would leave the
  // earlier users reading a value defined after them.
  Operation *earliestUser = nullptr;
  bool usersShareBlock = true;
  for (Operation *user : value.getUsers()) {
    if (!earliestUser) {
      earliestUser = user;
      continue;
    }
    if (user->getBlock() != earliestUser->getBlock()) {
      usersShareBlock = false;
      break;
    }
    if (user->isBeforeInBlock(earliestUser))
      earliestUser = user;
  }

  // Emit convert_layout op.
  if (usersShareBlock)
    rewriter.setInsertionPoint(earliestUser);
  else
    // Users are spread over several blocks: insert right after the value's
    // definition, which is the one position known to precede all of them.
    rewriter.setInsertionPointAfterValue(value);
  auto convLayoutOp =
      xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
                                     value, inputLayoutAttr, targetLayoutAttr);

  // Redirect every use of the original value to the converted value, except
  // the use by the convert_layout op itself.
  rewriter.replaceUsesWithIf(
      value, convLayoutOp.getResult(), [&](OpOperand &use) {
        return use.getOwner() != convLayoutOp.getOperation();
      });

  results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
  return DiagnosedSilenceableFailure::success();
}
|
||||
|
||||
// Declares the memory effects of `transform.xegpu.convert_layout`: the target
// handle and all six layout operand groups are marked via `onlyReadsHandle`,
// the op results via `producesHandle`, and the payload via `modifiesPayload`.
void transform::ConvertLayoutOp::getEffects(
|
||||
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getTargetMutable(), effects);
|
||||
onlyReadsHandle(getInputSgLayoutMutable(), effects);
|
||||
onlyReadsHandle(getInputSgDataMutable(), effects);
|
||||
onlyReadsHandle(getInputInstDataMutable(), effects);
|
||||
onlyReadsHandle(getTargetSgLayoutMutable(), effects);
|
||||
onlyReadsHandle(getTargetSgDataMutable(), effects);
|
||||
onlyReadsHandle(getTargetInstDataMutable(), effects);
|
||||
// The single result is the handle to the emitted convert_layout op.
producesHandle(getOperation()->getOpResults(), effects);
|
||||
modifiesPayload(effects);
|
||||
}
|
||||
|
||||
namespace {
|
||||
class XeGPUTransformDialectExtension
|
||||
: public transform::TransformDialectExtension<
|
||||
|
||||
@@ -42,6 +42,15 @@ class GetDescOp(GetDescOp):
|
||||
)
|
||||
|
||||
|
||||
def get_desc_op(
    target: Value,
    *,
    loc=None,
    ip=None,
) -> OpResult:
    """Create a `GetDescOp` for `target` and return the op's result."""
    desc_op = GetDescOp(target, loc=loc, ip=ip)
    return desc_op.result
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class SetDescLayoutOp(SetDescLayoutOp):
|
||||
"""Specialization for SetDescLayoutOp class."""
|
||||
@@ -88,6 +97,25 @@ class SetDescLayoutOp(SetDescLayoutOp):
|
||||
)
|
||||
|
||||
|
||||
def set_desc_layout(
    target: Union[Operation, Value],
    sg_layout: MixedValues,
    sg_data: MixedValues,
    *,
    inst_data: Optional[MixedValues] = None,
    loc=None,
    ip=None,
) -> OpResult:
    """Create a `SetDescLayoutOp` and return the op's result.

    All arguments are forwarded unchanged to the `SetDescLayoutOp`
    constructor; `inst_data` is passed as a keyword argument.
    """
    layout_op = SetDescLayoutOp(
        target, sg_layout, sg_data, inst_data=inst_data, loc=loc, ip=ip
    )
    return layout_op.result
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
|
||||
"""Specialization for SetOpLayoutAttrOp class."""
|
||||
@@ -135,6 +163,29 @@ class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
|
||||
)
|
||||
|
||||
|
||||
def set_op_layout_attr(
    target: Union[Operation, Value],
    sg_layout: MixedValues,
    sg_data: MixedValues,
    *,
    inst_data: Optional[MixedValues] = None,
    index: Optional[Union[int, Attribute]] = None,
    result: Optional[Union[bool, Attribute]] = None,
    loc=None,
    ip=None,
) -> SetOpLayoutAttrOp:
    """Create a `SetOpLayoutAttrOp` and return the op itself.

    All arguments are forwarded unchanged to the `SetOpLayoutAttrOp`
    constructor; `inst_data`, `index` and `result` are passed as keyword
    arguments.
    """
    optional_kwargs = dict(inst_data=inst_data, index=index, result=result)
    return SetOpLayoutAttrOp(
        target, sg_layout, sg_data, **optional_kwargs, loc=loc, ip=ip
    )
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
|
||||
"""Specialization for SetGPULaunchThreadsOp class."""
|
||||
@@ -210,4 +261,98 @@ def insert_prefetch(
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> OpResult:
|
||||
return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
|
||||
return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
class ConvertLayoutOp(ConvertLayoutOp):
    """Specialization for ConvertLayoutOp class."""

    def __init__(
        self,
        target: Value,
        input_sg_layout: MixedValues,
        input_sg_data: MixedValues,
        target_sg_layout: MixedValues,
        target_sg_data: MixedValues,
        *,
        input_inst_data: Optional[MixedValues] = None,
        target_inst_data: Optional[MixedValues] = None,
        loc=None,
        ip=None,
    ):
        # An omitted inst_data list is treated as an empty list.
        if input_inst_data is None:
            input_inst_data = []
        if target_inst_data is None:
            target_inst_data = []

        # Insertion order below is the positional operand order expected by
        # the generated base-class constructor.
        mixed_lists = {
            "input_sg_layout": input_sg_layout,
            "input_sg_data": input_sg_data,
            "input_inst_data": input_inst_data,
            "target_sg_layout": target_sg_layout,
            "target_sg_data": target_sg_data,
            "target_inst_data": target_inst_data,
        }

        # Split each mixed list into its dynamic operands and its static
        # attribute; the third value returned by the helper is not used.
        dynamic_operands = {}
        static_attrs = {}
        for name, mixed in mixed_lists.items():
            (
                dynamic_operands[name],
                static_attrs[f"static_{name}"],
                _,
            ) = _dispatch_dynamic_index_list(mixed)

        super().__init__(
            transform.AnyOpType.get(),
            target,
            *dynamic_operands.values(),
            **static_attrs,
            loc=loc,
            ip=ip,
        )
|
||||
|
||||
|
||||
def convert_layout(
    target: Value,
    input_sg_layout: MixedValues,
    input_sg_data: MixedValues,
    target_sg_layout: MixedValues,
    target_sg_data: MixedValues,
    *,
    input_inst_data: Optional[MixedValues] = None,
    target_inst_data: Optional[MixedValues] = None,
    loc=None,
    ip=None,
) -> OpResult:
    """Create a `ConvertLayoutOp` and return the op's result.

    All arguments are forwarded unchanged to the `ConvertLayoutOp`
    constructor. The return value is the op's `.result`, not the op object,
    hence the `OpResult` annotation.
    """
    return ConvertLayoutOp(
        target,
        input_sg_layout,
        input_sg_data,
        target_sg_layout,
        target_sg_data,
        input_inst_data=input_inst_data,
        target_inst_data=target_inst_data,
        loc=loc,
        ip=ip,
    ).result
|
||||
|
||||
@@ -400,3 +400,72 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test with fully static layouts. The transform script below targets
// operand 0 of the `xegpu.dpas` op, i.e. the result of the first
// `xegpu.load_nd`. The expected output routes that value through a new
// `xegpu.convert_layout` whose target layout differs from the input layout
// only in `inst_data` ([32, 16] -> [8, 16]), and feeds the converted value
// into `xegpu.dpas`.
// CHECK-LABEL: @convert_layout_a
|
||||
func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
|
||||
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
|
||||
// CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
|
||||
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
|
||||
// CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
|
||||
// CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
|
||||
// CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
|
||||
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
|
||||
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
|
||||
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
|
||||
%5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
|
||||
// CHECK: = xegpu.dpas %[[V2]]
|
||||
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
|
||||
return
|
||||
}
|
||||
|
||||
// Transform script: match the `xegpu.dpas` op, take a value handle to its
// operand 0, and request the layout conversion using static integer lists.
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
|
||||
// CHECK: transform.xegpu.convert_layout %{{.*}}
|
||||
transform.xegpu.convert_layout %1
|
||||
input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16]
|
||||
target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16]
|
||||
: (!transform.any_value) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Same payload and expected output as the fully static variant, but the
// transform script supplies the first `sg_layout` entry of both the input and
// the target layout through a transform parameter (`%layout0`, constant 8)
// instead of a static integer.
// CHECK-LABEL: @convert_layout_a_sg_param
|
||||
func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
|
||||
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
|
||||
// CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
|
||||
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
|
||||
// CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
|
||||
// CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
|
||||
// CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
|
||||
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
|
||||
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
|
||||
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
|
||||
%5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
|
||||
// CHECK: = xegpu.dpas %[[V2]]
|
||||
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
|
||||
return
|
||||
}
|
||||
|
||||
// Transform script: `%layout0` is used twice, so the op's functional type
// lists two `!transform.param<i64>` operands after the value handle.
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
|
||||
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
|
||||
// CHECK: transform.xegpu.convert_layout %{{.*}}
|
||||
transform.xegpu.convert_layout %1
|
||||
input_sg_layout = [%layout0, 4] input_sg_data = [32, 32] input_inst_data = [32, 16]
|
||||
target_sg_layout = [%layout0, 4] target_sg_data = [32, 32] target_inst_data = [8, 16]
|
||||
: (!transform.any_value, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ def getDescOpDefaultIndex():
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
|
||||
desc_handle = xegpu.GetDescOp(operand)
|
||||
desc_handle = xegpu.get_desc_op(operand)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: getDescOpDefaultIndex
|
||||
# CHECK: transform.xegpu.get_desc_op %
|
||||
@@ -39,7 +39,7 @@ def setDescLayoutMinimal():
|
||||
transform.OperationType.get("xegpu.create_nd_tdesc"),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
|
||||
xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: setDescLayoutMinimal
|
||||
# CHECK: %0 = transform.xegpu.set_desc_layout %
|
||||
@@ -55,7 +55,7 @@ def setDescLayoutInstData():
|
||||
transform.OperationType.get("xegpu.create_nd_tdesc"),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
xegpu.SetDescLayoutOp(
|
||||
xegpu.set_desc_layout(
|
||||
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
|
||||
)
|
||||
transform.YieldOp()
|
||||
@@ -74,7 +74,7 @@ def setOpLayoutAttrOperandMinimal():
|
||||
transform.OperationType.get("xegpu.dpas"),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
xegpu.SetOpLayoutAttrOp(
|
||||
xegpu.set_op_layout_attr(
|
||||
sequence.bodyTarget,
|
||||
sg_layout=[6, 4],
|
||||
sg_data=[32, 16],
|
||||
@@ -97,7 +97,7 @@ def setOpLayoutAttrResult():
|
||||
transform.OperationType.get("xegpu.dpas"),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
xegpu.SetOpLayoutAttrOp(
|
||||
xegpu.set_op_layout_attr(
|
||||
sequence.bodyTarget,
|
||||
index=0,
|
||||
sg_layout=[6, 4],
|
||||
@@ -193,3 +193,57 @@ def insertPrefetchNbPrefetchParam():
|
||||
# CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
|
||||
# CHECK: transform.xegpu.insert_prefetch %[[OPR]]
|
||||
# CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
|
||||
|
||||
|
||||
@run
def ConvertLayoutMinimal():
    # Builds a transform sequence over `xegpu.dpas` and requests a layout
    # conversion for operand 0 using only the mandatory sg_layout/sg_data
    # lists (no inst_data).
    dpas_sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(dpas_sequence.body):
        first_operand = transform.GetOperandOp(
            AnyValueType.get(), dpas_sequence.bodyTarget, [0]
        )
        layout_kwargs = dict(
            input_sg_layout=[6, 4],
            input_sg_data=[32, 16],
            target_sg_layout=[6, 4],
            target_sg_data=[8, 16],
        )
        xegpu.convert_layout(first_operand, **layout_kwargs)
        transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: ConvertLayoutMinimal
|
||||
# CHECK: transform.xegpu.convert_layout %
|
||||
# CHECK: input_sg_layout = [6, 4]
|
||||
# CHECK: input_sg_data = [32, 16]
|
||||
# CHECK: target_sg_layout = [6, 4]
|
||||
# CHECK: target_sg_data = [8, 16]
|
||||
|
||||
|
||||
@run
def ConvertLayout():
    # Builds a transform sequence over `xegpu.dpas` and requests a layout
    # conversion for operand 1, passing inst_data for both the input and the
    # target layout.
    dpas_sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(dpas_sequence.body):
        second_operand = transform.GetOperandOp(
            AnyValueType.get(), dpas_sequence.bodyTarget, [1]
        )
        layout_kwargs = dict(
            input_sg_layout=[6, 4],
            input_sg_data=[32, 32],
            input_inst_data=[32, 16],
            target_sg_layout=[6, 4],
            target_sg_data=[32, 32],
            target_inst_data=[8, 16],
        )
        xegpu.convert_layout(second_operand, **layout_kwargs)
        transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: ConvertLayout
|
||||
# CHECK: transform.xegpu.convert_layout %
|
||||
# CHECK: input_sg_layout = [6, 4]
|
||||
# CHECK: input_sg_data = [32, 32]
|
||||
# CHECK: input_inst_data = [32, 16]
|
||||
# CHECK: target_sg_layout = [6, 4]
|
||||
# CHECK: target_sg_data = [32, 32]
|
||||
# CHECK: target_inst_data = [8, 16]
|
||||
|
||||
Reference in New Issue
Block a user