mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir][transform] LISH: Add transform op (#70630)
Add a transform op for loop-invariant subset hoisting. Delete the old transform op from the Linalg dialect.
This commit is contained in:
committed by
GitHub
parent
f0535c72bf
commit
b9fe461e73
@@ -2247,56 +2247,6 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HoistRedundantTensorSubsetsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HoistRedundantTensorSubsetsOp :
|
||||
Op<Transform_Dialect, "structured.hoist_redundant_tensor_subsets",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
TransformEachOpTrait,
|
||||
TransformOpInterface,
|
||||
ReportTrackingListenerFailuresOpTrait]> {
|
||||
let description = [{
|
||||
Hoists supported tensor subset extract/insert operation pairs out of
|
||||
immediately enclosing loop iteratively, if the following conditions
|
||||
are true:
|
||||
1. The 2 ops access the same tensor subset.
|
||||
2. All operands are invariant under the enclosing loop.
|
||||
|
||||
The supported subset extract/insert operation pairs currently comprise:
|
||||
- tensor.extract_slice / tensor.insert_slice
|
||||
- vector.transfer_read / vector.transfer_write on tensors
|
||||
|
||||
Only scf.for loops are currently supported.
|
||||
|
||||
When applied to:
|
||||
1. an scf.for loop, hoist out of this loop only.
|
||||
2. a non-loop op, apply hoisting to all the contained loop ops.
|
||||
|
||||
#### Return modes:
|
||||
|
||||
The operation always succeeds and returns nothing.
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$target
|
||||
attr-dict
|
||||
`:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::transform::TransformRewriter &rewriter,
|
||||
::mlir::Operation *target,
|
||||
::mlir::transform::ApplyToEachResultList &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertSliceToCopyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(LoopExtension)
|
||||
add_subdirectory(PDLExtension)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
set(LLVM_TARGET_DEFINITIONS LoopExtensionOps.td)
|
||||
mlir_tablegen(LoopExtensionOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(LoopExtensionOps.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(MLIRTransformDialectLoopExtensionOpsIncGen)
|
||||
|
||||
add_mlir_doc(LoopExtensionOps LoopExtensionOps Dialects/ -gen-op-doc)
|
||||
@@ -0,0 +1,16 @@
|
||||
//===- LoopExtension.h - Loop extension for Transform dialect ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace transform {
|
||||
/// Registers the loop extension of the Transform dialect in the given registry.
|
||||
void registerLoopExtension(DialectRegistry &dialectRegistry);
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
||||
@@ -0,0 +1,23 @@
|
||||
//===- LoopExtensionOps.h - Loop ext. for Transform dialect -----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
|
||||
#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
|
||||
|
||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS_H
|
||||
@@ -0,0 +1,76 @@
|
||||
//===- LoopExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
|
||||
#define MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
|
||||
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def HoistLoopInvariantSubsetsOp
|
||||
: TransformDialectOp<"loop.hoist_loop_invariant_subsets",
|
||||
[TransformOpInterface, TransformEachOpTrait,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
ReportTrackingListenerFailuresOpTrait]> {
|
||||
let summary = "Hoist loop invariant subset ops";
|
||||
let description = [{
|
||||
This transform hoists loop-invariant subset ops out of the targeted
|
||||
loop-like op. It looks for matching subset extraction/insertion op pairs and
|
||||
hoists them. The loop body operates on a newly introduced region iter_arg.
|
||||
|
||||
Subset ops are hoisted only from the targeted op. If subset ops should be
|
||||
hoisted from an entire loop nest, this transformation must be applied to
|
||||
each loop-like op of the loop nest, starting with the innermost loop and
|
||||
ending with the outermost loop.
|
||||
|
||||
Example:
|
||||
```
|
||||
%r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
|
||||
%0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
|
||||
%1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
|
||||
%2 = tensor.insert_slice %1 into %t[0][5][1]
|
||||
: tensor<5xf32> into tensor<?xf32>
|
||||
scf.yield %2 : tensor<?xf32>
|
||||
}
|
||||
```
|
||||
Is transformed to:
|
||||
```
|
||||
%0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
|
||||
%new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0) -> (tensor<?xf32>) {
|
||||
%1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
|
||||
scf.yield %t, %2 : tensor<?xf32>, tensor<5xf32>
|
||||
}
|
||||
%r = tensor.insert_slice %new_loop#1 into %new_loop#0
|
||||
: tensor<5xf32> into tensor<?xf32>
|
||||
```
|
||||
|
||||
Subset ops are hoisted only if there are no conflicting subset ops. E.g.,
|
||||
if there were a second overlapping extraction in the above example, no ops
|
||||
could be hoisted safely.
|
||||
|
||||
This transform reads the target handle and modifies the payload. This
|
||||
transform does not invalidate any handles, but loop-like ops are replaced
|
||||
with new loop-like ops when a subset op is hoisted. The transform rewriter
|
||||
updates all handles accordingly.
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs);
|
||||
let assemblyFormat = "$target attr-dict `:` type($target)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::transform::TransformRewriter &rewriter,
|
||||
::mlir::LoopLikeOpInterface loopLikeOp,
|
||||
::mlir::transform::ApplyToEachResultList &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
|
||||
//===- PDLExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
|
||||
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
|
||||
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
|
||||
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
|
||||
@@ -74,6 +75,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
|
||||
scf::registerTransformDialectExtension(registry);
|
||||
sparse_tensor::registerTransformDialectExtension(registry);
|
||||
tensor::registerTransformDialectExtension(registry);
|
||||
transform::registerLoopExtension(registry);
|
||||
transform::registerPDLExtension(registry);
|
||||
vector::registerTransformDialectExtension(registry);
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ namespace mlir {
|
||||
class LoopLikeOpInterface;
|
||||
class Operation;
|
||||
class Region;
|
||||
class RewriterBase;
|
||||
class Value;
|
||||
|
||||
/// Given a list of regions, perform loop-invariant code motion. An operation is
|
||||
@@ -108,7 +109,8 @@ size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
|
||||
/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
|
||||
/// : tensor<5xf32> into tensor<?xf32>
|
||||
/// ```
|
||||
LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
|
||||
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter,
|
||||
LoopLikeOpInterface loopLike);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
|
||||
@@ -3163,35 +3163,6 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HoistRedundantTensorSubsetsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::HoistRedundantTensorSubsetsOp::applyToOne(
|
||||
transform::TransformRewriter &rewriter, Operation *target,
|
||||
transform::ApplyToEachResultList &results,
|
||||
transform::TransformState &state) {
|
||||
auto forOp = dyn_cast<scf::ForOp>(target);
|
||||
if (forOp) {
|
||||
linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
// TODO: walking in some reverse / inside-out order would be more efficient
|
||||
// and would capture more cases.
|
||||
target->walk([&](scf::ForOp forOp) {
|
||||
hoistRedundantSubsetExtractInsert(rewriter, forOp);
|
||||
});
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::HoistRedundantTensorSubsetsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::onlyReadsHandle(getTarget(), effects);
|
||||
transform::modifiesPayload(effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertSliceToCopyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(LoopExtension)
|
||||
add_subdirectory(PDLExtension)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Utils)
|
||||
|
||||
13
mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
Normal file
13
mlir/lib/Dialect/Transform/LoopExtension/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
add_mlir_dialect_library(MLIRTransformLoopExtension
|
||||
LoopExtension.cpp
|
||||
LoopExtensionOps.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRTransformDialectLoopExtensionOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLoopLikeInterface
|
||||
MLIRTransformDialect
|
||||
MLIRTransforms
|
||||
)
|
||||
34
mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
Normal file
34
mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
//===- LoopExtension.cpp - Loop extension for the Transform dialect -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Loop extension of the Transform dialect. This provides "core" transform
|
||||
/// operations for loop-like ops.
|
||||
class LoopExtension
|
||||
: public transform::TransformDialectExtension<LoopExtension> {
|
||||
public:
|
||||
void init() {
|
||||
registerTransformOps<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::transform::registerLoopExtension(DialectRegistry &dialectRegistry) {
|
||||
dialectRegistry.addExtensions<LoopExtension>();
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
//===- LoopExtensionOps.cpp - Loop extension for the Transform dialect ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
|
||||
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HoistLoopInvariantSubsetsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure transform::HoistLoopInvariantSubsetsOp::applyToOne(
|
||||
transform::TransformRewriter &rewriter, LoopLikeOpInterface loopLikeOp,
|
||||
transform::ApplyToEachResultList &results,
|
||||
transform::TransformState &state) {
|
||||
hoistLoopInvariantSubsets(rewriter, loopLikeOp);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::HoistLoopInvariantSubsetsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::onlyReadsHandle(getTarget(), effects);
|
||||
transform::modifiesPayload(effects);
|
||||
}
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
|
||||
@@ -47,11 +48,12 @@ void LoopInvariantCodeMotion::runOnOperation() {
|
||||
}
|
||||
|
||||
void LoopInvariantSubsetHoisting::runOnOperation() {
|
||||
IRRewriter rewriter(getOperation()->getContext());
|
||||
// Walk through all loops in a function in innermost-loop-first order. This
|
||||
// way, we first hoist from the inner loop, and place the ops in the outer
|
||||
// loop, which in turn can be further hoisted from.
|
||||
getOperation()->walk([&](LoopLikeOpInterface loopLike) {
|
||||
(void)hoistLoopInvariantSubsets(loopLike);
|
||||
(void)hoistLoopInvariantSubsets(rewriter, loopLike);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -311,12 +311,12 @@ MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
|
||||
/// loop-like op and index into loop-invariant subset locations. Return the
|
||||
/// newly created loop op (that has extra iter_args) or the original loop op if
|
||||
/// nothing was hoisted.
|
||||
static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
|
||||
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
|
||||
LoopLikeOpInterface loopLike,
|
||||
BlockArgument iterArg) {
|
||||
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
|
||||
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
|
||||
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
|
||||
IRRewriter rewriter(loopLike.getContext());
|
||||
MatchingSubsets subsets;
|
||||
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
|
||||
return loopLike;
|
||||
@@ -367,11 +367,12 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
|
||||
OpResult newLoopResult = loopLike.getLoopResults()->back();
|
||||
extractionOp->moveBefore(loopLike);
|
||||
insertionOp->moveAfter(loopLike);
|
||||
insertionOp.getUpdatedDestination().replaceAllUsesWith(
|
||||
insertionOp.getDestinationOperand().get());
|
||||
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
|
||||
insertionOp.getDestinationOperand().get());
|
||||
extractionOp.getSourceOperand().set(
|
||||
loopLike.getTiedLoopInit(iterArg)->get());
|
||||
loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
|
||||
rewriter.replaceAllUsesWith(loopResult,
|
||||
insertionOp.getUpdatedDestination());
|
||||
insertionOp.getSourceOperand().set(newLoopResult);
|
||||
insertionOp.getDestinationOperand().set(loopResult);
|
||||
}
|
||||
@@ -381,13 +382,15 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
|
||||
}
|
||||
|
||||
LoopLikeOpInterface
|
||||
mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
|
||||
mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
|
||||
LoopLikeOpInterface loopLike) {
|
||||
// Note: As subset ops are getting hoisted, the number of region iter_args
|
||||
// increases. This can enable further hoisting opportunities on the new
|
||||
// iter_args.
|
||||
for (int64_t i = 0;
|
||||
i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
|
||||
loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
|
||||
loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
|
||||
loopLike.getRegionIterArgs()[i]);
|
||||
}
|
||||
return loopLike;
|
||||
}
|
||||
|
||||
78
mlir/test/Dialect/Transform/test-loop-transforms.mlir
Normal file
78
mlir/test/Dialect/Transform/test-loop-transforms.mlir
Normal file
@@ -0,0 +1,78 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter --split-input-file \
|
||||
// RUN: --verify-diagnostics | FileCheck %s
|
||||
|
||||
// UNSUPPORTED: target=aarch64-pc-windows-msvc
|
||||
|
||||
// CHECK-LABEL: func @test_loop_invariant_subset_hoisting(
|
||||
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
|
||||
func.func @test_loop_invariant_subset_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%lb = "test.foo"() : () -> (index)
|
||||
%ub = "test.foo"() : () -> (index)
|
||||
%step = "test.foo"() : () -> (index)
|
||||
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
|
||||
// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
|
||||
// expected-remark @below{{new loop op}}
|
||||
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
|
||||
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
|
||||
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
|
||||
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
|
||||
// Obfuscate the IR by inserting at offset %sub instead of 0; both of them
|
||||
// have the same value.
|
||||
%3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
|
||||
// CHECK: scf.yield %[[t]], %[[foo]]
|
||||
scf.yield %3 : tensor<?xf32>
|
||||
}
|
||||
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
|
||||
// CHECK: return %[[insert]]
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%1 = transform.structured.match ops{["tensor.extract_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%2 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
|
||||
transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op
|
||||
// Make sure that the handles are still valid (and were updated in case of
|
||||
// the loop).
|
||||
|
||||
// expected-remark @below{{1}}
|
||||
transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
|
||||
transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
|
||||
// expected-remark @below{{1}}
|
||||
transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
|
||||
// expected-remark @below{{1}}
|
||||
transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
|
||||
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Checks that transform ops from LoopExtensionOps and SCFTransformOps can be
|
||||
// used together.
|
||||
|
||||
// CHECK-LABEL: func @test_mixed_loop_extension_scf_transform(
|
||||
func.func @test_mixed_loop_extension_scf_transform(%arg: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%lb = "test.foo"() : () -> (index)
|
||||
%ub = "test.foo"() : () -> (index)
|
||||
%step = "test.foo"() : () -> (index)
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
|
||||
%1 = "test.foo"(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
|
||||
scf.yield %1 : tensor<?xf32>
|
||||
}
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.loop.hoist_loop_invariant_subsets %0 : !transform.any_op
|
||||
transform.loop.unroll %0 { factor = 4 } : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
@@ -21,5 +21,6 @@ add_mlir_library(MLIRTestTransformDialect
|
||||
MLIRPDLDialect
|
||||
MLIRTransformDialect
|
||||
MLIRTransformDialectTransforms
|
||||
MLIRTransformLoopExtension
|
||||
MLIRTransformPDLExtension
|
||||
)
|
||||
|
||||
@@ -4416,6 +4416,7 @@ cc_library(
|
||||
":SCFTransformOps",
|
||||
":SparseTensorTransformOps",
|
||||
":TensorTransformOps",
|
||||
":TransformLoopExtension",
|
||||
":TransformPDLExtension",
|
||||
":UBToLLVM",
|
||||
":VectorTransformOps",
|
||||
@@ -8677,6 +8678,7 @@ cc_library(
|
||||
":TosaToLinalg",
|
||||
":TransformDialect",
|
||||
":TransformDialectTransforms",
|
||||
":TransformLoopExtension",
|
||||
":TransformPDLExtension",
|
||||
":Transforms",
|
||||
":TransformsPassIncGen",
|
||||
@@ -11401,6 +11403,52 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "TransformLoopExtensionTdFiles",
|
||||
srcs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.td"]),
|
||||
deps = [
|
||||
":TransformDialectTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "TransformLoopExtensionOpsIncGen",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-op-decls",
|
||||
],
|
||||
"include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h.inc",
|
||||
),
|
||||
(
|
||||
[
|
||||
"-gen-op-defs",
|
||||
],
|
||||
"include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td",
|
||||
deps = [":TransformLoopExtensionTdFiles"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "TransformLoopExtension",
|
||||
srcs = glob(["lib/Dialect/Transform/LoopExtension/*.cpp"]),
|
||||
hdrs = glob(["include/mlir/Dialect/Transform/LoopExtension/*.h"]),
|
||||
deps = [
|
||||
":IR",
|
||||
":LoopLikeInterface",
|
||||
":Rewrite",
|
||||
":SideEffectInterfaces",
|
||||
":Support",
|
||||
":TransformDialect",
|
||||
":TransformLoopExtensionOpsIncGen",
|
||||
":Transforms",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "TransformDialectTransformsTdFiles",
|
||||
srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),
|
||||
|
||||
Reference in New Issue
Block a user