diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 0399a5a9afa5..fef9e4bd1721 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -46,16 +46,22 @@ def LoopOutlineOp : Op]> { let summary = "Outlines a loop into a named function"; let description = [{ - Moves the loop into a separate function with the specified name and - replaces the loop in the Payload IR with a call to that function. Takes - care of forwarding values that are used in the loop as function arguments. - If the operand is associated with more than one loop, each loop will be - outlined into a separate function. The provided name is used as a _base_ - for forming actual function names following SymbolTable auto-renaming - scheme to avoid duplicate symbols. Expects that all ops in the Payload IR - have a SymbolTable ancestor (typically true because of the top-level - module). Returns the handle to the list of outlined functions in the same - order as the operand handle. + Moves the loop into a separate function with the specified name and replaces + the loop in the Payload IR with a call to that function. Takes care of + forwarding values that are used in the loop as function arguments. If the + operand is associated with more than one loop, each loop will be outlined + into a separate function. The provided name is used as a _base_ for forming + actual function names following `SymbolTable` auto-renaming scheme to avoid + duplicate symbols. Expects that all ops in the Payload IR have a + `SymbolTable` ancestor (typically true because of the top-level module). + + #### Return Modes + + Returns a handle to the list of outlined functions and a handle to the + corresponding function call operations in the same order as the operand + handle. + + Produces a definite failure if outlining failed for any of the targets. }]; // Note that despite the name of the transform operation and related utility @@ -63,7 +69,8 @@ def LoopOutlineOp : Op transformed; + SmallVector functions; + SmallVector calls; DenseMap symbolTables; for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); @@ -112,9 +113,11 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results, symbolTable.insert(*outlined); call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); } - transformed.push_back(*outlined); + functions.push_back(*outlined); + calls.push_back(call); } - results.set(getTransformed().cast(), transformed); + results.set(getFunction().cast(), functions); + results.set(getCall().cast(), calls); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index a275ea615378..10079d32fd92 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -39,7 +39,8 @@ class LoopOutlineOp: def __init__( self, - result_type: Type, + function_type: Type, + call_type: Type, target: Union[Operation, Value], *, func_name: Union[str, StringAttr], @@ -47,7 +48,8 @@ class LoopOutlineOp: loc=None, ): super().__init__( - result_type, + function_type, + call_type, _get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir index da040fff8273..2e15abdd260d 100644 --- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir @@ -54,8 +54,8 @@ func.func @loop_outline_op_multi_region() { } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!pdl.operation) -> !pdl.operation +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{failed to outline}} - transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation + transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) } diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir index d876d0f6be9c..8fe9eddf6c19 100644 --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -75,11 +75,11 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) { } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> // CHECK: = transform.loop.outline %{{.*}} - transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation + transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op) } // ----- diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py index 02d35f628c17..067a8b60d4f8 100644 --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -33,7 +33,7 @@ def loopOutline(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.OperationType.get("scf.for")) with InsertionPoint(sequence.body): - loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo") + loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo") transform.YieldOp() # CHECK-LABEL: TEST: loopOutline # CHECK: = transform.loop.outline %