mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[MLIR] Introduce utility to hoist affine if/else conditions
This revision introduces a utility to unswitch affine.for/parallel loops by hoisting affine.if operations past surrounding affine.for/parallel. The hoisting works for both perfect/imperfect nests and in the presence of else blocks. The hoisting is currently to as outermost a level as possible. Uses a test pass to test the utility. Add convenience method Operation::getParentWithTrait<Trait>. Depends on D77487. Differential Revision: https://reviews.llvm.org/D77870
This commit is contained in:
@@ -337,13 +337,16 @@ def AffineIfOp : Affine_Op<"if",
|
||||
/// list of AffineIf is not resizable.
|
||||
void setConditional(IntegerSet set, ValueRange operands);
|
||||
|
||||
/// Returns true if an else block exists.
|
||||
bool hasElse() { return !elseRegion().empty(); }
|
||||
|
||||
Block *getThenBlock() {
|
||||
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
|
||||
return &thenRegion().front();
|
||||
}
|
||||
|
||||
Block *getElseBlock() {
|
||||
assert(!elseRegion().empty() && "Empty 'else' region.");
|
||||
assert(hasElse() && "Empty 'else' region.");
|
||||
return &elseRegion().front();
|
||||
}
|
||||
|
||||
@@ -353,7 +356,7 @@ def AffineIfOp : Affine_Op<"if",
|
||||
return OpBuilder(&body, std::prev(body.end()));
|
||||
}
|
||||
OpBuilder getElseBodyBuilder() {
|
||||
assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
|
||||
assert(hasElse() && "No 'else' block");
|
||||
Block &body = elseRegion().front();
|
||||
return OpBuilder(&body, std::prev(body.end()));
|
||||
}
|
||||
@@ -491,6 +494,9 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
|
||||
|
||||
Block *getBody();
|
||||
OpBuilder getBodyBuilder();
|
||||
MutableArrayRef<BlockArgument> getIVs() {
|
||||
return getBody()->getArguments();
|
||||
}
|
||||
void setSteps(ArrayRef<int64_t> newSteps);
|
||||
|
||||
static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }
|
||||
|
||||
29
mlir/include/mlir/Dialect/Affine/Utils.h
Normal file
29
mlir/include/mlir/Dialect/Affine/Utils.h
Normal file
@@ -0,0 +1,29 @@
|
||||
//===- Utils.h - Affine dialect utilities -----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file declares a set of utilities for the affine dialect ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_AFFINE_UTILS_H
|
||||
#define MLIR_DIALECT_AFFINE_UTILS_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineIfOp;
|
||||
struct LogicalResult;
|
||||
|
||||
/// Hoists out affine.if/else to as high as possible, i.e., past all invariant
|
||||
/// affine.fors/parallel's. Returns success if any hoisting happened; `folded` is
|
||||
/// set to true if the op was folded or erased. This hoisting could lead to
|
||||
/// significant code expansion in some cases.
|
||||
LogicalResult hoistAffineIfOp(AffineIfOp ifOp, bool *folded = nullptr);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_AFFINE_UTILS_H
|
||||
@@ -116,6 +116,12 @@ public:
|
||||
return getOperation()->getParentOfType<OpTy>();
|
||||
}
|
||||
|
||||
/// Returns the closest surrounding parent operation with trait `Trait`.
|
||||
template <template <typename T> class Trait>
|
||||
Operation *getParentWithTrait() {
|
||||
return getOperation()->getParentWithTrait<Trait>();
|
||||
}
|
||||
|
||||
/// Return the context this operation belongs to.
|
||||
MLIRContext *getContext() { return getOperation()->getContext(); }
|
||||
|
||||
|
||||
@@ -126,6 +126,16 @@ public:
|
||||
return OpTy();
|
||||
}
|
||||
|
||||
/// Returns the closest surrounding parent operation with trait `Trait`.
|
||||
template <template <typename T> class Trait>
|
||||
Operation *getParentWithTrait() {
|
||||
Operation *op = this;
|
||||
while ((op = op->getParentOp()))
|
||||
if (op->hasTrait<Trait>())
|
||||
return op;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Return true if this operation is a proper ancestor of the `other`
|
||||
/// operation.
|
||||
bool isProperAncestor(Operation *other);
|
||||
|
||||
@@ -19,3 +19,4 @@ target_link_libraries(MLIRAffine
|
||||
)
|
||||
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Utils)
|
||||
|
||||
11
mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
Normal file
11
mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
add_mlir_dialect_library(MLIRAffineUtils
|
||||
Utils.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
|
||||
|
||||
)
|
||||
target_link_libraries(MLIRAffineUtils
|
||||
PUBLIC
|
||||
MLIRAffine
|
||||
)
|
||||
175
mlir/lib/Dialect/Affine/Utils/Utils.cpp
Normal file
175
mlir/lib/Dialect/Affine/Utils/Utils.cpp
Normal file
@@ -0,0 +1,175 @@
|
||||
//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements miscellaneous transformation utilities for the Affine
|
||||
// dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/Utils.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Promotes the `then` or the `else` block of `ifOp` (depending on whether
|
||||
/// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
|
||||
/// the rest of the op.
|
||||
static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
|
||||
if (elseBlock)
|
||||
assert(ifOp.hasElse() && "else block expected");
|
||||
|
||||
Block *destBlock = ifOp.getOperation()->getBlock();
|
||||
Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
|
||||
destBlock->getOperations().splice(
|
||||
Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
|
||||
std::prev(srcBlock->end()));
|
||||
ifOp.erase();
|
||||
}
|
||||
|
||||
/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
|
||||
/// on. The `ifOp` could be hoisted and placed right before such an operation.
|
||||
/// This method assumes that the ifOp has been canonicalized (to be correct and
|
||||
/// effective).
|
||||
static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
|
||||
// Walk up the parents past all for op that this conditional is invariant on.
|
||||
auto ifOperands = ifOp.getOperands();
|
||||
auto res = ifOp.getOperation();
|
||||
while (!isa<FuncOp>(res->getParentOp())) {
|
||||
auto *parentOp = res->getParentOp();
|
||||
if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
|
||||
if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
|
||||
break;
|
||||
} else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
|
||||
for (auto iv : parallelOp.getIVs())
|
||||
if (llvm::is_contained(ifOperands, iv))
|
||||
break;
|
||||
} else if (!isa<AffineIfOp>(parentOp)) {
|
||||
// Won't walk up past anything other than affine.for/if ops.
|
||||
break;
|
||||
}
|
||||
// You can always hoist up past any affine.if ops.
|
||||
res = parentOp;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
|
||||
/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
|
||||
/// otherwise the same `ifOp`.
|
||||
static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
|
||||
// No hoisting to do.
|
||||
if (hoistOverOp == ifOp)
|
||||
return ifOp;
|
||||
|
||||
// Create the hoisted 'if' first. Then, clone the op we are hoisting over for
|
||||
// the else block. Then drop the else block of the original 'if' in the 'then'
|
||||
// branch while promoting its then block, and analogously drop the 'then'
|
||||
// block of the original 'if' from the 'else' branch while promoting its else
|
||||
// block.
|
||||
BlockAndValueMapping operandMap;
|
||||
OpBuilder b(hoistOverOp);
|
||||
auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
|
||||
ifOp.getOperands(),
|
||||
/*elseBlock=*/true);
|
||||
|
||||
// Create a clone of hoistOverOp to use for the else branch of the hoisted
|
||||
// conditional. The else block may get optimized away if empty.
|
||||
Operation *hoistOverOpClone = nullptr;
|
||||
// We use this unique name to identify/find `ifOp`'s clone in the else
|
||||
// version.
|
||||
Identifier idForIfOp = b.getIdentifier("__mlir_if_hoisting");
|
||||
operandMap.clear();
|
||||
b.setInsertionPointAfter(hoistOverOp);
|
||||
// We'll set an attribute to identify this op in a clone of this sub-tree.
|
||||
ifOp.setAttr(idForIfOp, b.getBoolAttr(true));
|
||||
hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
|
||||
|
||||
// Promote the 'then' block of the original affine.if in the then version.
|
||||
promoteIfBlock(ifOp, /*elseBlock=*/false);
|
||||
|
||||
// Move the then version to the hoisted if op's 'then' block.
|
||||
auto *thenBlock = hoistedIfOp.getThenBlock();
|
||||
thenBlock->getOperations().splice(thenBlock->begin(),
|
||||
hoistOverOp->getBlock()->getOperations(),
|
||||
Block::iterator(hoistOverOp));
|
||||
|
||||
// Find the clone of the original affine.if op in the else version.
|
||||
AffineIfOp ifCloneInElse;
|
||||
hoistOverOpClone->walk([&](AffineIfOp ifClone) {
|
||||
if (!ifClone.getAttr(idForIfOp))
|
||||
return WalkResult::advance();
|
||||
ifCloneInElse = ifClone;
|
||||
return WalkResult::interrupt();
|
||||
});
|
||||
assert(ifCloneInElse && "if op clone should exist");
|
||||
// For the else block, promote the else block of the original 'if' if it had
|
||||
// one; otherwise, the op itself is to be erased.
|
||||
if (!ifCloneInElse.hasElse())
|
||||
ifCloneInElse.erase();
|
||||
else
|
||||
promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
|
||||
|
||||
// Move the else version into the else block of the hoisted if op.
|
||||
auto *elseBlock = hoistedIfOp.getElseBlock();
|
||||
elseBlock->getOperations().splice(
|
||||
elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
|
||||
Block::iterator(hoistOverOpClone));
|
||||
|
||||
return hoistedIfOp;
|
||||
}
|
||||
|
||||
// Returns success if any hoisting happened.
|
||||
LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
|
||||
// Apply canonicalization patterns and folding - this is necessary for the
|
||||
// hoisting check to be correct (operands should be composed), and to be more
|
||||
// effective (no unused operands). Since the pattern rewriter's folding is
|
||||
// entangled with application of patterns, we may fold/end up erasing the op,
|
||||
// in which case we return with `folded` being set.
|
||||
OwningRewritePatternList patterns;
|
||||
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
|
||||
bool erased;
|
||||
applyOpPatternsAndFold(ifOp, patterns, &erased);
|
||||
if (erased) {
|
||||
if (folded)
|
||||
*folded = true;
|
||||
return failure();
|
||||
}
|
||||
if (folded)
|
||||
*folded = false;
|
||||
|
||||
// The folding above should have ensured this, but the affine.if's
|
||||
// canonicalization is missing composition of affine.applys into it.
|
||||
assert(llvm::all_of(ifOp.getOperands(),
|
||||
[](Value v) {
|
||||
return isTopLevelValue(v) || isForInductionVar(v);
|
||||
}) &&
|
||||
"operands not composed");
|
||||
|
||||
// We are going hoist as high as possible.
|
||||
// TODO: this could be customized in the future.
|
||||
auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
|
||||
|
||||
AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
|
||||
// Nothing to hoist over.
|
||||
if (hoistedIfOp == ifOp)
|
||||
return failure();
|
||||
|
||||
// Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
|
||||
// a sequence of affine.fors that are all perfectly nested).
|
||||
applyPatternsAndFoldGreedily(
|
||||
hoistedIfOp.getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
|
||||
std::move(patterns));
|
||||
|
||||
return success();
|
||||
}
|
||||
258
mlir/test/Dialect/Affine/loop-unswitch.mlir
Normal file
258
mlir/test/Dialect/Affine/loop-unswitch.mlir
Normal file
@@ -0,0 +1,258 @@
|
||||
// RUN: mlir-opt %s -split-input-file -test-affine-loop-unswitch | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[SET:.*]] = affine_set<(d0) : (d0 - 2 >= 0)>
|
||||
|
||||
// CHECK-LABEL: func @if_else_imperfect
|
||||
func @if_else_imperfect(%A : memref<100xi32>, %B : memref<100xi32>, %v : i32) {
|
||||
// CHECK: %[[A:.*]]: memref<100xi32>, %[[B:.*]]: memref
|
||||
affine.for %i = 0 to 100 {
|
||||
affine.load %A[%i] : memref<100xi32>
|
||||
affine.for %j = 0 to 100 {
|
||||
affine.load %A[%j] : memref<100xi32>
|
||||
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
|
||||
affine.load %B[%j] : memref<100xi32>
|
||||
}
|
||||
call @external() : () -> ()
|
||||
}
|
||||
affine.load %A[%i] : memref<100xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
func @external()
|
||||
|
||||
// CHECK: affine.for %[[I:.*]] = 0 to 100 {
|
||||
// CHECK-NEXT: affine.load %[[A]][%[[I]]]
|
||||
// CHECK-NEXT: affine.if #[[SET]](%[[I]]) {
|
||||
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 100 {
|
||||
// CHECK-NEXT: affine.load %[[A]][%[[J]]]
|
||||
// CHECK-NEXT: affine.load %[[B]][%[[J]]]
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: affine.for %[[JJ:.*]] = 0 to 100 {
|
||||
// CHECK-NEXT: affine.load %[[A]][%[[JJ]]]
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: affine.load %[[A]][%[[I]]]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// -----
|
||||
|
||||
func @foo()
|
||||
func @bar()
|
||||
func @abc()
|
||||
func @xyz()
|
||||
|
||||
// CHECK-LABEL: func @if_then_perfect
|
||||
func @if_then_perfect(%A : memref<100xi32>, %v : i32) {
|
||||
affine.for %i = 0 to 100 {
|
||||
affine.for %j = 0 to 100 {
|
||||
affine.for %k = 0 to 100 {
|
||||
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
|
||||
affine.load %A[%i] : memref<100xi32>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.if
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NOT: else
|
||||
|
||||
|
||||
// CHECK-LABEL: func @if_else_perfect
|
||||
func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
|
||||
affine.for %i = 0 to 99 {
|
||||
affine.for %j = 0 to 100 {
|
||||
affine.for %k = 0 to 100 {
|
||||
call @foo() : () -> ()
|
||||
affine.if affine_set<(d0, d1) : (d0 - 2 >= 0, -d1 + 80 >= 0)>(%i, %j) {
|
||||
affine.load %A[%i] : memref<100xi32>
|
||||
call @abc() : () -> ()
|
||||
} else {
|
||||
affine.load %A[%i + 1] : memref<100xi32>
|
||||
call @xyz() : () -> ()
|
||||
}
|
||||
call @bar() : () -> ()
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.if
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: call @foo
|
||||
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}]
|
||||
// CHECK-NEXT: call @abc
|
||||
// CHECK-NEXT: call @bar
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: else
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: call @foo
|
||||
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} + 1]
|
||||
// CHECK-NEXT: call @xyz
|
||||
// CHECK-NEXT: call @bar
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-LABEL: func @if_then_imperfect
|
||||
func @if_then_imperfect(%A : memref<100xi32>, %N : index) {
|
||||
affine.for %i = 0 to 100 {
|
||||
affine.load %A[0] : memref<100xi32>
|
||||
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%N) {
|
||||
affine.load %A[%i] : memref<100xi32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: affine.if
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// Check if unused operands are dropped: hence, hoisting is possible.
|
||||
// CHECK-LABEL: func @hoist_after_canonicalize
|
||||
func @hoist_after_canonicalize() {
|
||||
affine.for %i = 0 to 100 {
|
||||
affine.for %j = 0 to 100 {
|
||||
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%j) {
|
||||
affine.if affine_set<(d0, d1) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%i, %j) {
|
||||
// The call to @foo is to avoid DCE on affine.if.
|
||||
call @foo() : () -> ()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.if
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.if
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// CHECK-LABEL: func @handle_dead_if
|
||||
func @handle_dead_if(%N : index) {
|
||||
affine.for %i = 0 to 100 {
|
||||
affine.if affine_set<(d0) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%N) {
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// -----
|
||||
|
||||
// A test case with affine.parallel.
|
||||
|
||||
#flb1 = affine_map<(d0) -> (d0 * 3)>
|
||||
#fub1 = affine_map<(d0) -> (d0 * 3 + 3)>
|
||||
#flb0 = affine_map<(d0) -> (d0 * 16)>
|
||||
#fub0 = affine_map<(d0) -> (d0 * 16 + 16)>
|
||||
#pub1 = affine_map<(d0)[s0] -> (s0, d0 * 3 + 3)>
|
||||
#pub0 = affine_map<(d0)[s0] -> (s0, d0 * 16 + 16)>
|
||||
#lb1 = affine_map<(d0) -> (d0 * 480)>
|
||||
#ub1 = affine_map<(d0)[s0] -> (s0, d0 * 480 + 480)>
|
||||
#lb0 = affine_map<(d0) -> (d0 * 110)>
|
||||
#ub0 = affine_map<(d0)[s0] -> (d0 * 110 + 110, s0 floordiv 3)>
|
||||
|
||||
#set0 = affine_set<(d0, d1)[s0, s1] : (d0 * -16 + s0 - 16 >= 0, d1 * -3 + s1 - 3 >= 0)>
|
||||
|
||||
// CHECK-LABEL: func @perfect_if_else
|
||||
func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 : index,
|
||||
%arg5 : index, %arg6 : index, %sym : index) {
|
||||
affine.for %arg7 = #lb0(%arg5) to min #ub0(%arg5)[%sym] {
|
||||
affine.parallel (%i0, %j0) = (0, 0) to (symbol(%sym), 100) step (10, 10) {
|
||||
affine.for %arg8 = #lb1(%arg4) to min #ub1(%arg4)[%sym] {
|
||||
affine.if #set0(%arg6, %arg7)[%sym, %sym] {
|
||||
affine.for %arg9 = #flb0(%arg6) to #fub0(%arg6) {
|
||||
affine.for %arg10 = #flb1(%arg7) to #fub1(%arg7) {
|
||||
affine.load %arg0[0, 0] : memref<?x?xf64>
|
||||
}
|
||||
}
|
||||
} else {
|
||||
affine.for %arg9 = #lb0(%arg6) to min #pub0(%arg6)[%sym] {
|
||||
affine.for %arg10 = #lb1(%arg7) to min #pub1(%arg7)[%sym] {
|
||||
affine.load %arg0[0, 0] : memref<?x?xf64>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.if
|
||||
// CHECK-NEXT: affine.parallel
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: affine.parallel
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.load
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// With multiple if ops in a function, the test pass just looks for the first if
|
||||
// op that it is able to successfully hoist.
|
||||
|
||||
// CHECK-LABEL: func @multiple_if
|
||||
func @multiple_if(%N : index) {
|
||||
affine.if affine_set<() : (0 == 0)>() {
|
||||
call @external() : () -> ()
|
||||
}
|
||||
affine.for %i = 0 to 100 {
|
||||
affine.if affine_set<()[s0] : (s0 >= 0)>()[%N] {
|
||||
call @external() : () -> ()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: affine.if
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: affine.if
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: call
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
func @external()
|
||||
@@ -1,5 +1,6 @@
|
||||
add_llvm_library(MLIRAffineTransformsTestPasses
|
||||
TestAffineDataCopy.cpp
|
||||
TestAffineLoopUnswitching.cpp
|
||||
TestLoopPermutation.cpp
|
||||
TestParallelismDetection.cpp
|
||||
TestVectorizationUtils.cpp
|
||||
|
||||
60
mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
Normal file
60
mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
//===- TestAffineLoopUnswitching.cpp - Test affine if/else hoisting -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a pass to hoist affine if/else structures.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/Utils.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#define PASS_NAME "test-affine-loop-unswitch"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// This pass tests affine.if hoisting (loop unswitching) via hoistAffineIfOp.
|
||||
struct TestAffineLoopUnswitching
|
||||
: public PassWrapper<TestAffineLoopUnswitching, FunctionPass> {
|
||||
TestAffineLoopUnswitching() = default;
|
||||
TestAffineLoopUnswitching(const TestAffineLoopUnswitching &pass) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
/// The maximum number of iterations to run this for.
|
||||
constexpr static unsigned kMaxIterations = 5;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestAffineLoopUnswitching::runOnFunction() {
|
||||
// Each hoisting invalidates a lot of IR around. Just stop the walk after the
|
||||
// first if/else hoisting, and repeat until no more hoisting can be done, or
|
||||
// the maximum number of iterations have been run.
|
||||
auto func = getFunction();
|
||||
unsigned i = 0;
|
||||
do {
|
||||
auto walkFn = [](AffineIfOp op) {
|
||||
return succeeded(hoistAffineIfOp(op)) ? WalkResult::interrupt()
|
||||
: WalkResult::advance();
|
||||
};
|
||||
if (func.walk(walkFn).wasInterrupted())
|
||||
break;
|
||||
} while (++i < kMaxIterations);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestAffineLoopUnswitchingPass() {
|
||||
PassRegistration<TestAffineLoopUnswitching>(
|
||||
PASS_NAME, "Tests affine loop unswitching / if/else hoisting");
|
||||
}
|
||||
} // namespace mlir
|
||||
@@ -40,6 +40,7 @@ void registerSimpleParametricTilingPass();
|
||||
void registerSymbolTestPasses();
|
||||
void registerTestAffineDataCopyPass();
|
||||
void registerTestAllReduceLoweringPass();
|
||||
void registerTestAffineLoopUnswitchingPass();
|
||||
void registerTestLinalgMatmulToVectorPass();
|
||||
void registerTestLoopPermutationPass();
|
||||
void registerTestCallGraphPass();
|
||||
@@ -103,6 +104,7 @@ void registerTestPasses() {
|
||||
registerSymbolTestPasses();
|
||||
registerTestAffineDataCopyPass();
|
||||
registerTestAllReduceLoweringPass();
|
||||
registerTestAffineLoopUnswitchingPass();
|
||||
registerTestLinalgMatmulToVectorPass();
|
||||
registerTestLoopPermutationPass();
|
||||
registerTestCallGraphPass();
|
||||
|
||||
Reference in New Issue
Block a user