diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 9e3f79e64bb1..e60c3f364604 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2247,56 +2247,6 @@ def ConvertConv2DToImg2ColOp : Op, - TransformEachOpTrait, - TransformOpInterface, - ReportTrackingListenerFailuresOpTrait]> { - let description = [{ - Hoists supported tensor subset extract/insert operation pairs out of - immediately enclosing loop iteratively, if the following conditions - are true: - 1. The 2 ops access the same tensor subset. - 2. All operands are invariant under the enclosing loop. - - The supported subset extract/insert operation pairs currently comprise: - - tensor.extract_slice / tensor.insert_slice - - vector.transfer_read / vector.transfer_write on tensors - - Only scf.for loops are currently supported. - - When applied to: - 1. an scf.for loop, hoist out of this loop only. - 2. a non-loop op, apply hoisting to all the contained loop ops. - - #### Return modes: - - The operation always succeeds and returns nothing. 
- }]; - - let arguments = (ins TransformHandleTypeInterface:$target); - let results = (outs); - - let assemblyFormat = [{ - $target - attr-dict - `:` functional-type(operands, results) - }]; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - //===----------------------------------------------------------------------===// // InsertSliceToCopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt index d9fbaee80239..d6c5c975c2e9 100644 --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) +add_subdirectory(LoopExtension) add_subdirectory(PDLExtension) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt new file mode 100644 index 000000000000..8f5e510ad39a --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS LoopExtensionOps.td) +mlir_tablegen(LoopExtensionOps.h.inc -gen-op-decls) +mlir_tablegen(LoopExtensionOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen) + +add_mlir_doc(LoopExtensionOps LoopExtensionOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h new file mode 100644 index 000000000000..7a8ed2075ef1 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtension.h @@ -0,0 +1,16 @@ +//===- LoopExtension.h - Loop extension 
for Transform dialect ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +namespace mlir { +class DialectRegistry; + +namespace transform { +/// Registers the loop extension of the Transform dialect in the given registry. +void registerLoopExtension(DialectRegistry &dialectRegistry); +} // namespace transform +} // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h new file mode 100644 index 000000000000..68cc0699d081 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h @@ -0,0 +1,23 @@ +//===- LoopExtensionOps.h - Loop ext. for Transform dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H +#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc" + +#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H diff --git a/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td new file mode 100644 index 000000000000..78a8c6ad489a --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td @@ -0,0 +1,76 @@ +//===- LoopExtensionOps.td - Transform dialect operations --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS +#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def HoistLoopInvariantSubsetsOp + : TransformDialectOp<"loop.hoist_loop_invariant_subsets", + [TransformOpInterface, TransformEachOpTrait, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + ReportTrackingListenerFailuresOpTrait]> { + let summary = "Hoist loop invariant subset ops"; + let description = [{ + This transform hoists loop-invariant subset ops out of the targeted + loop-like op. It looks for matching subset extraction/insertion op pairs and + hoists them. The loop body operates on a newly introduced region iter_arg. + + Subset ops are hoisted only from the targeted op. If subset ops should be + hoisted from an entire loop nest, this transformation must be applied to + each loop-like op of the loop nest, starting with the innermost loop and + ending with the outermost loop. + + Example: + ``` + %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) { + %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32> + %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>) + %2 = tensor.insert_slice %1 into %t[0][5][1] + : tensor<5xf32> into tensor<?xf32> + scf.yield %2 : tensor<?xf32> + } + ``` + Is transformed to: + ``` + %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32> + %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>, tensor<5xf32>) { + %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>) + scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32> + } + %r = tensor.insert_slice %new_loop#1 into %new_loop#0 + : tensor<5xf32> into tensor<?xf32> + ``` + + Subset ops are hoisted only if there are no conflicting subset ops. 
E.g., + if there were a second overlapping extraction in the above example, no ops + could be hoisted safely. + + This transform reads the target handle and modifies the payload. This + transform does not invalidate any handles, but loop-like ops are replaced + with new loop-like ops when a subset op is hoisted. The transform rewriter + updates all handles accordingly. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + let assemblyFormat = "$target attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::LoopLikeOpInterface loopLikeOp, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td index 16107b3d0869..206a799690aa 100644 --- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td @@ -1,4 +1,4 @@ -//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===// +//===- PDLExtensionOps.td - Transform dialect operations ---*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 8e2ad3a2e34f..c04ce850fb96 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" @@ -74,6 +75,7 @@ inline void registerAllExtensions(DialectRegistry &registry) { scf::registerTransformDialectExtension(registry); sparse_tensor::registerTransformDialectExtension(registry); tensor::registerTransformDialectExtension(registry); + transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h index 579054070f72..3ceef44d799e 100644 --- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h +++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h @@ -18,6 +18,7 @@ namespace mlir { class LoopLikeOpInterface; class Operation; class Region; +class RewriterBase; class Value; /// Given a list of regions, perform loop-invariant code motion. 
An operation is @@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike); /// %r = tensor.insert_slice %new_loop#1 into %new_loop#0 /// : tensor<5xf32> into tensor<?xf32> /// ``` -LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike); +LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, + LoopLikeOpInterface loopLike); } // end namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 87be3bb85b6e..fd8a1657db3a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3163,35 +3163,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( return DiagnosedSilenceableFailure::success(); } -//===----------------------------------------------------------------------===// -// HoistRedundantTensorSubsetsOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::HoistRedundantTensorSubsetsOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { - auto forOp = dyn_cast<scf::ForOp>(target); - if (forOp) { - linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); - return DiagnosedSilenceableFailure::success(); - } - - // TODO: walking in some reverse / inside-out order would be more efficient - // and would capture more cases. 
- target->walk([&](scf::ForOp forOp) { - hoistRedundantSubsetExtractInsert(rewriter, forOp); - }); - return DiagnosedSilenceableFailure::success(); -} - -void transform::HoistRedundantTensorSubsetsOp::getEffects( - SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - transform::onlyReadsHandle(getTarget(), effects); - transform::modifiesPayload(effects); -} - //===----------------------------------------------------------------------===// // InsertSliceToCopyOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt index 9e144eba2571..6898d81df7ca 100644 --- a/mlir/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(IR) +add_subdirectory(LoopExtension) add_subdirectory(PDLExtension) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt new file mode 100644 index 000000000000..9e1abdd1ca17 --- /dev/null +++ b/mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRTransformLoopExtension + LoopExtension.cpp + LoopExtensionOps.cpp + + DEPENDS + MLIRTransformDialectLoopExtensionOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLoopLikeInterface + MLIRTransformDialect + MLIRTransforms +) diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp new file mode 100644 index 000000000000..b33288fd7b99 --- /dev/null +++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp @@ -0,0 +1,34 @@ +//===- LoopExtension.cpp - Loop extension for the Transform dialect -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h" +#include "mlir/IR/DialectRegistry.h" + +using namespace mlir; + +namespace { +/// Loop extension of the Transform dialect. This provides "core" transform +/// operations for loop-like ops. +class LoopExtension + : public transform::TransformDialectExtension<LoopExtension> { +public: + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc" + >(); + } +}; +} // namespace + +void mlir::transform::registerLoopExtension(DialectRegistry &dialectRegistry) { + dialectRegistry.addExtensions<LoopExtension>(); +} diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp new file mode 100644 index 000000000000..c992fd15946f --- /dev/null +++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp @@ -0,0 +1,36 @@ +//===- LoopExtensionOps.cpp - Loop extension for the Transform dialect ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h" + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// HoistLoopInvariantSubsetsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne( + transform::TransformRewriter &rewriter, LoopLikeOpInterface loopLikeOp, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + hoistLoopInvariantSubsets(rewriter, loopLikeOp); + return DiagnosedSilenceableFailure::success(); +} + +void transform::HoistLoopInvariantSubsetsOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp index e6d8af8f0583..02c3ea1ce9b6 100644 --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" @@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() { } void LoopInvariantSubsetHoisting::runOnOperation() { + IRRewriter rewriter(getOperation()->getContext()); // Walk through all loops in a function in innermost-loop-first order. 
This // way, we first hoist from the inner loop, and place the ops in the outer // loop, which in turn can be further hoisted from. getOperation()->walk([&](LoopLikeOpInterface loopLike) { - (void)hoistLoopInvariantSubsets(loopLike); + (void)hoistLoopInvariantSubsets(rewriter, loopLike); }); } diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index 53bdb7aafe41..8f97fd3d9ddf 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -311,12 +311,12 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, /// loop-like op and index into loop-invariant subset locations. Return the /// newly created loop op (that has extra iter_args) or the original loop op if /// nothing was hoisted. -static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike, +static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, + LoopLikeOpInterface loopLike, BlockArgument iterArg) { assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg); int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); - IRRewriter rewriter(loopLike.getContext()); MatchingSubsets subsets; if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg))) return loopLike; @@ -367,11 +367,12 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike, OpResult newLoopResult = loopLike.getLoopResults()->back(); extractionOp->moveBefore(loopLike); insertionOp->moveAfter(loopLike); - insertionOp.getUpdatedDestination().replaceAllUsesWith( - insertionOp.getDestinationOperand().get()); + rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(), + insertionOp.getDestinationOperand().get()); extractionOp.getSourceOperand().set( loopLike.getTiedLoopInit(iterArg)->get()); - 
- loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination()); + rewriter.replaceAllUsesWith(loopResult, + insertionOp.getUpdatedDestination()); insertionOp.getSourceOperand().set(newLoopResult); insertionOp.getDestinationOperand().set(loopResult); } @@ -381,13 +382,15 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike, } LoopLikeOpInterface -mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) { +mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter, + LoopLikeOpInterface loopLike) { // Note: As subset ops are getting hoisted, the number of region iter_args // increases. This can enable further hoisting opportunities on the new // iter_args. for (int64_t i = 0; i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) { - loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]); + loopLike = hoistSubsetAtIterArg(rewriter, loopLike, + loopLike.getRegionIterArgs()[i]); } return loopLike; } diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir new file mode 100644 index 000000000000..425962757f72 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir @@ -0,0 +1,78 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file \ +// RUN: --verify-diagnostics | FileCheck %s + +// UNSUPPORTED: target=aarch64-pc-windows-msvc + +// CHECK-LABEL: func @test_loop_invariant_subset_hoisting( +// CHECK-SAME: %[[arg:.*]]: tensor<?xf32> +func.func @test_loop_invariant_subset_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> { + %lb = "test.foo"() : () -> (index) + %ub = "test.foo"() : () -> (index) + %step = "test.foo"() : () -> (index) + // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]] + // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]]) + // expected-remark @below{{new loop op}} + %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) { + %1 = tensor.extract_slice 
%t[0][5][1] : tensor<?xf32> to tensor<5xf32> + // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]]) + %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>) + // The insertion writes the same subset (offset 0, size 5, stride 1) that was + // extracted above, so the two ops form a matching extraction/insertion pair. + %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32> + // CHECK: scf.yield %[[t]], %[[foo]] + scf.yield %3 : tensor<?xf32> + } + // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0 + // CHECK: return %[[insert]] + return %0 : tensor<?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op + // Make sure that the handles are still valid (and were updated in case of + // the loop). + + // expected-remark @below{{1}} + transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op + transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op + // expected-remark @below{{1}} + transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op + // expected-remark @below{{1}} + transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op + + transform.yield + } +} + +// ----- + +// Checks that transform ops from LoopExtensionOps and SCFTransformOps can be +// used together. 
+ +// CHECK-LABEL: func @test_mixed_loop_extension_scf_transform( +func.func @test_mixed_loop_extension_scf_transform(%arg: tensor<?xf32>) -> tensor<?xf32> { + %lb = "test.foo"() : () -> (index) + %ub = "test.foo"() : () -> (index) + %step = "test.foo"() : () -> (index) + // CHECK: scf.for + // CHECK: scf.for + %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) { + %1 = "test.foo"(%t) : (tensor<?xf32>) -> (tensor<?xf32>) + scf.yield %1 : tensor<?xf32> + } + return %0 : tensor<?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op + transform.loop.unroll %0 { factor = 4 } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt index c7e83d3a7128..436f892a2723 100644 --- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -21,5 +21,6 @@ add_mlir_library(MLIRTestTransformDialect MLIRPDLDialect MLIRTransformDialect MLIRTransformDialectTransforms + MLIRTransformLoopExtension MLIRTransformPDLExtension ) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 2cadd4e0d291..99aa78bb3d3d 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4416,6 +4416,7 @@ cc_library( ":SCFTransformOps", ":SparseTensorTransformOps", ":TensorTransformOps", + ":TransformLoopExtension", ":TransformPDLExtension", ":UBToLLVM", ":VectorTransformOps", @@ -8677,6 +8678,7 @@ cc_library( ":TosaToLinalg", ":TransformDialect", ":TransformDialectTransforms", + ":TransformLoopExtension", ":TransformPDLExtension", ":Transforms", ":TransformsPassIncGen", @@ -11401,6 
+11403,52 @@ cc_library( ], ) +td_library( + name = "TransformLoopExtensionTdFiles", + srcs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.td"]), + deps = [ + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "TransformLoopExtensionOpsIncGen", + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td", + deps = [":TransformLoopExtensionTdFiles"], +) + +cc_library( + name = "TransformLoopExtension", + srcs = glob(["lib/Dialect/Transform/LoopExtension/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.h"]), + deps = [ + ":IR", + ":LoopLikeInterface", + ":Rewrite", + ":SideEffectInterfaces", + ":Support", + ":TransformDialect", + ":TransformLoopExtensionOpsIncGen", + ":Transforms", + "//llvm:Support", + ], +) + td_library( name = "TransformDialectTransformsTdFiles", srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),