mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
[MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (#166865)
Adds `transform.xegpu.set_gpu_launch_threads` that overrides `gpu.launch` operation threads.
This commit is contained in:
@@ -161,4 +161,43 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
|
||||
}];
|
||||
}
|
||||
|
||||
// Transform op that rewrites the x, y, z thread-count (block size) operands
// of a targeted `gpu.launch` op in-place. Threads are given as a mixed
// static/dynamic index list (constants in `static_threads`, SSA values or
// transform params in `threads`).
def SetGPULaunchThreadsOp
    : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface
    ]> {

  let summary = "Set number of threads for a given gpu.launch operation";
  let description = [{
    Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target,
                       Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
  );
  let results = (outs);
  let builders = [
    OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
  ];

  let assemblyFormat = [{
    $target
    `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
    attr-dict `:` qualified(type(operands))
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Recombines the static and dynamic thread operands into a single
    // mixed OpFoldResult list.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
      Builder b(getContext());
      return getMixedValues(getStaticThreads(), getThreads(), b);
    }
  }];
}

#endif // XEGPU_TRANSFORM_OPS
@@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
||||
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
|
||||
@@ -341,6 +342,69 @@ void transform::SetOpLayoutAttrOp::getEffects(
|
||||
modifiesPayload(effects);
|
||||
}
|
||||
|
||||
/// Builder taking the thread counts as a mixed (static/dynamic) index list.
/// Splits `mixedThreads` into the `static_threads` attribute and the
/// variadic `threads` operands expected by the generated builder.
void transform::SetGPULaunchThreadsOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedThreads) {
  SmallVector<int64_t> staticThreads;
  SmallVector<Value> dynamicThreads;
  // Constants go to staticThreads; SSA values go to dynamicThreads.
  dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*threads=*/dynamicThreads,
        /*static_threads=*/staticThreads);
}
/// Applies the transform: resolves the mixed thread list to three integers
/// and replaces the block-size (threads) operands of the single targeted
/// `gpu.launch` op with freshly created `arith.constant` index values.
/// Fails definitely if the handle is not a singleton, silenceably if the
/// payload is not a `gpu.launch` or the thread list does not have 3 entries.
DiagnosedSilenceableFailure
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  auto launchOp = dyn_cast<gpu::LaunchOp>(target);
  if (!launchOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a gpu.launch op, but got: " << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Resolve static values and transform params/handles to plain integers.
  SmallVector<int32_t> threads;
  DiagnosedSilenceableFailure status =
      convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
  if (!status.succeeded())
    return status;

  if (threads.size() != 3) {
    return emitSilenceableFailure(getLoc())
           << "Expected threads argument to consist of three values (got "
           << threads.size() << ")";
  }

  // New constants are created immediately before the launch so they
  // dominate the operand uses.
  rewriter.setInsertionPoint(launchOp);
  auto createConstValue = [&](int value) {
    return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
  };

  // Replace threads in-place.
  launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
  launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
  launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));

  return DiagnosedSilenceableFailure::success();
}
/// Declares memory effects: both handles are only read; the payload IR is
/// modified (the launch op's operands are rewritten in-place).
void transform::SetGPULaunchThreadsOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getThreadsMutable(), effects);
  modifiesPayload(effects);
}
||||
namespace {
|
||||
class XeGPUTransformDialectExtension
|
||||
: public transform::TransformDialectExtension<
|
||||
|
||||
@@ -132,3 +132,39 @@ class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
    """Specialization for SetGPULaunchThreadsOp class."""

    def __init__(
        self,
        launch_op: Union[Operation, Value],
        threads: MixedValues,
        *,
        loc=None,
        ip=None,
    ):
        # Split the mixed thread values into dynamic SSA operands and the
        # static_threads attribute expected by the generated op builder.
        (
            dynamic_threads,
            static_threads,
            _,
        ) = _dispatch_dynamic_index_list(threads)

        super().__init__(
            _get_op_result_or_value(launch_op),
            dynamic_threads,
            static_threads=static_threads,
            loc=loc,
            ip=ip,
        )
|
||||
|
||||
def set_gpu_launch_threads(
    launch_op: Union[Operation, Value],
    threads: MixedValues,
    *,
    loc=None,
    ip=None,
) -> SetGPULaunchThreadsOp:
    """Creates a transform.xegpu.set_gpu_launch_threads op.

    `threads` is a mixed list of Python ints and/or SSA values giving the
    x, y, z thread counts to set on the targeted `gpu.launch`.
    """
    return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
||||
|
||||
@@ -71,3 +71,56 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----

// Diagnostic test: target is not a gpu.launch op.
func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
  %c32 = arith.constant 32 : index // expected-note {{target op}}
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}

// -----

// Diagnostic test: handle maps to more than one payload op.
func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
  %c32 = arith.constant 32 : index
  %c64 = arith.constant 64 : index
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}

// -----

// Diagnostic test: threads list must have exactly three entries.
func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Expected threads argument to consist of three values (got 2)}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
    transform.yield
  }
}
||||
@@ -230,6 +230,7 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @set_op_layout_attr_operand1
|
||||
@@ -252,3 +253,58 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @set_gpu_launch_threads
func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[C16:.+]] = arith.constant 16 : index
  %c16 = arith.constant 16 : index
  // CHECK: %[[C8:.+]] = arith.constant 8 : index
  // CHECK: %[[C4:.+]] = arith.constant 4 : index
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  }
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @set_gpu_launch_threads_param
func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  %c1 = arith.constant 1 : index
  // CHECK: %[[C16:.+]] = arith.constant 16 : index
  %c16 = arith.constant 16 : index
  // CHECK: %[[C8:.+]] = arith.constant 8 : index
  // CHECK: %[[C4:.+]] = arith.constant 4 : index
  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
    gpu.terminator
  }
  return
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
    %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
    transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
    transform.yield
  }
}
||||
@@ -113,3 +113,18 @@ def setOpLayoutAttrResult():
|
||||
# CHECK: sg_layout = [6, 4]
|
||||
# CHECK: sg_data = [32, 16]
|
||||
# CHECK: inst_data = [8, 16]
|
||||
|
||||
|
||||
@run
def setGPULaunchThreadsOp():
    # Round-trip test: build the transform op via the Python bindings and
    # check its printed form.
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("gpu.launch"),
    )
    with InsertionPoint(sequence.body):
        xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1])
        transform.YieldOp()
    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
    # CHECK: transform.xegpu.set_gpu_launch_threads
    # CHECK: threads = [8, 4, 1]
||||
Reference in New Issue
Block a user