[MLIR] Introduce utility to hoist affine if/else conditions

This revision introduces a utility to unswitch affine.for/parallel loops
by hoisting affine.if operations past surrounding affine.for/parallel.
The hoisting works for both perfect/imperfect nests and in the presence
of else blocks. The hoisting is currently to as outermost a level as
possible.  Uses a test pass to test the utility.
Add convenience method Operation::getParentWithTrait<Trait>.

Depends on D77487.

Differential Revision: https://reviews.llvm.org/D77870
This commit is contained in:
Uday Bondhugula
2020-04-10 17:12:49 +05:30
parent cece7af586
commit af5e83f569
11 changed files with 561 additions and 2 deletions

View File

@@ -337,13 +337,16 @@ def AffineIfOp : Affine_Op<"if",
/// list of AffineIf is not resizable.
void setConditional(IntegerSet set, ValueRange operands);
/// Returns true if an else block exists.
bool hasElse() { return !elseRegion().empty(); }
Block *getThenBlock() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
return &thenRegion().front();
}
Block *getElseBlock() {
assert(!elseRegion().empty() && "Empty 'else' region.");
assert(hasElse() && "Empty 'else' region.");
return &elseRegion().front();
}
@@ -353,7 +356,7 @@ def AffineIfOp : Affine_Op<"if",
return OpBuilder(&body, std::prev(body.end()));
}
OpBuilder getElseBodyBuilder() {
assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
assert(hasElse() && "No 'else' block");
Block &body = elseRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
@@ -491,6 +494,9 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> {
Block *getBody();
OpBuilder getBodyBuilder();
MutableArrayRef<BlockArgument> getIVs() {
return getBody()->getArguments();
}
void setSteps(ArrayRef<int64_t> newSteps);
static StringRef getLowerBoundsMapAttrName() { return "lowerBoundsMap"; }

View File

@@ -0,0 +1,29 @@
//===- Utils.h - Affine dialect utilities -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file declares a set of utilities for the affine dialect ops.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AFFINE_UTILS_H
#define MLIR_DIALECT_AFFINE_UTILS_H
namespace mlir {
class AffineIfOp;
struct LogicalResult;
/// Hoists out affine.if/else to as high as possible, i.e., past all invariant
/// affine.fors/parallel's. Returns success if any hoisting happened; folded` is
/// set to true if the op was folded or erased. This hoisting could lead to
/// significant code expansion in some cases.
LogicalResult hoistAffineIfOp(AffineIfOp ifOp, bool *folded = nullptr);
} // namespace mlir
#endif // MLIR_DIALECT_AFFINE_UTILS_H

View File

@@ -116,6 +116,12 @@ public:
return getOperation()->getParentOfType<OpTy>();
}
/// Returns the closest surrounding parent operation with trait `Trait`.
template <template <typename T> class Trait>
Operation *getParentWithTrait() {
return getOperation()->getParentWithTrait<Trait>();
}
/// Return the context this operation belongs to.
MLIRContext *getContext() { return getOperation()->getContext(); }

View File

@@ -126,6 +126,16 @@ public:
return OpTy();
}
/// Returns the closest surrounding parent operation with trait `Trait`.
template <template <typename T> class Trait>
Operation *getParentWithTrait() {
Operation *op = this;
while ((op = op->getParentOp()))
if (op->hasTrait<Trait>())
return op;
return nullptr;
}
/// Return true if this operation is a proper ancestor of the `other`
/// operation.
bool isProperAncestor(Operation *other);

View File

@@ -19,3 +19,4 @@ target_link_libraries(MLIRAffine
)
add_subdirectory(Transforms)
add_subdirectory(Utils)

View File

@@ -0,0 +1,11 @@
add_mlir_dialect_library(MLIRAffineUtils
Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
)
target_link_libraries(MLIRAffineUtils
PUBLIC
MLIRAffine
)

View File

@@ -0,0 +1,175 @@
//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation utilities for the Affine
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
/// Promotes the `then` or the `else` block of `ifOp` (depending on whether
/// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
/// the rest of the op.
static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
if (elseBlock)
assert(ifOp.hasElse() && "else block expected");
Block *destBlock = ifOp.getOperation()->getBlock();
Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
destBlock->getOperations().splice(
Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
std::prev(srcBlock->end()));
ifOp.erase();
}
/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
/// on. The `ifOp` could be hoisted and placed right before such an operation.
/// This method assumes that the ifOp has been canonicalized (to be correct and
/// effective).
static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
// Walk up the parents past all for op that this conditional is invariant on.
auto ifOperands = ifOp.getOperands();
auto res = ifOp.getOperation();
while (!isa<FuncOp>(res->getParentOp())) {
auto *parentOp = res->getParentOp();
if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
break;
} else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
for (auto iv : parallelOp.getIVs())
if (llvm::is_contained(ifOperands, iv))
break;
} else if (!isa<AffineIfOp>(parentOp)) {
// Won't walk up past anything other than affine.for/if ops.
break;
}
// You can always hoist up past any affine.if ops.
res = parentOp;
}
return res;
}
/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
/// otherwise the same `ifOp`.
static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
// No hoisting to do.
if (hoistOverOp == ifOp)
return ifOp;
// Create the hoisted 'if' first. Then, clone the op we are hoisting over for
// the else block. Then drop the else block of the original 'if' in the 'then'
// branch while promoting its then block, and analogously drop the 'then'
// block of the original 'if' from the 'else' branch while promoting its else
// block.
BlockAndValueMapping operandMap;
OpBuilder b(hoistOverOp);
auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
ifOp.getOperands(),
/*elseBlock=*/true);
// Create a clone of hoistOverOp to use for the else branch of the hoisted
// conditional. The else block may get optimized away if empty.
Operation *hoistOverOpClone = nullptr;
// We use this unique name to identify/find `ifOp`'s clone in the else
// version.
Identifier idForIfOp = b.getIdentifier("__mlir_if_hoisting");
operandMap.clear();
b.setInsertionPointAfter(hoistOverOp);
// We'll set an attribute to identify this op in a clone of this sub-tree.
ifOp.setAttr(idForIfOp, b.getBoolAttr(true));
hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
// Promote the 'then' block of the original affine.if in the then version.
promoteIfBlock(ifOp, /*elseBlock=*/false);
// Move the then version to the hoisted if op's 'then' block.
auto *thenBlock = hoistedIfOp.getThenBlock();
thenBlock->getOperations().splice(thenBlock->begin(),
hoistOverOp->getBlock()->getOperations(),
Block::iterator(hoistOverOp));
// Find the clone of the original affine.if op in the else version.
AffineIfOp ifCloneInElse;
hoistOverOpClone->walk([&](AffineIfOp ifClone) {
if (!ifClone.getAttr(idForIfOp))
return WalkResult::advance();
ifCloneInElse = ifClone;
return WalkResult::interrupt();
});
assert(ifCloneInElse && "if op clone should exist");
// For the else block, promote the else block of the original 'if' if it had
// one; otherwise, the op itself is to be erased.
if (!ifCloneInElse.hasElse())
ifCloneInElse.erase();
else
promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
// Move the else version into the else block of the hoisted if op.
auto *elseBlock = hoistedIfOp.getElseBlock();
elseBlock->getOperations().splice(
elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
Block::iterator(hoistOverOpClone));
return hoistedIfOp;
}
// Returns success if any hoisting happened.
LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// Apply canonicalization patterns and folding - this is necessary for the
// hoisting check to be correct (operands should be composed), and to be more
// effective (no unused operands). Since the pattern rewriter's folding is
// entangled with application of patterns, we may fold/end up erasing the op,
// in which case we return with `folded` being set.
OwningRewritePatternList patterns;
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
applyOpPatternsAndFold(ifOp, patterns, &erased);
if (erased) {
if (folded)
*folded = true;
return failure();
}
if (folded)
*folded = false;
// The folding above should have ensured this, but the affine.if's
// canonicalization is missing composition of affine.applys into it.
assert(llvm::all_of(ifOp.getOperands(),
[](Value v) {
return isTopLevelValue(v) || isForInductionVar(v);
}) &&
"operands not composed");
// We are going hoist as high as possible.
// TODO: this could be customized in the future.
auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
// Nothing to hoist over.
if (hoistedIfOp == ifOp)
return failure();
// Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
// a sequence of affine.fors that are all perfectly nested).
applyPatternsAndFoldGreedily(
hoistedIfOp.getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
std::move(patterns));
return success();
}

View File

@@ -0,0 +1,258 @@
// RUN: mlir-opt %s -split-input-file -test-affine-loop-unswitch | FileCheck %s
// CHECK-DAG: #[[SET:.*]] = affine_set<(d0) : (d0 - 2 >= 0)>
// CHECK-LABEL: func @if_else_imperfect
func @if_else_imperfect(%A : memref<100xi32>, %B : memref<100xi32>, %v : i32) {
// CHECK: %[[A:.*]]: memref<100xi32>, %[[B:.*]]: memref
affine.for %i = 0 to 100 {
affine.load %A[%i] : memref<100xi32>
affine.for %j = 0 to 100 {
affine.load %A[%j] : memref<100xi32>
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
affine.load %B[%j] : memref<100xi32>
}
call @external() : () -> ()
}
affine.load %A[%i] : memref<100xi32>
}
return
}
func @external()
// CHECK: affine.for %[[I:.*]] = 0 to 100 {
// CHECK-NEXT: affine.load %[[A]][%[[I]]]
// CHECK-NEXT: affine.if #[[SET]](%[[I]]) {
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 100 {
// CHECK-NEXT: affine.load %[[A]][%[[J]]]
// CHECK-NEXT: affine.load %[[B]][%[[J]]]
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: affine.for %[[JJ:.*]] = 0 to 100 {
// CHECK-NEXT: affine.load %[[A]][%[[JJ]]]
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: affine.load %[[A]][%[[I]]]
// CHECK-NEXT: }
// CHECK-NEXT: return
// -----
func @foo()
func @bar()
func @abc()
func @xyz()
// CHECK-LABEL: func @if_then_perfect
func @if_then_perfect(%A : memref<100xi32>, %v : i32) {
affine.for %i = 0 to 100 {
affine.for %j = 0 to 100 {
affine.for %k = 0 to 100 {
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%i) {
affine.load %A[%i] : memref<100xi32>
}
}
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NOT: else
// CHECK-LABEL: func @if_else_perfect
func @if_else_perfect(%A : memref<100xi32>, %v : i32) {
affine.for %i = 0 to 99 {
affine.for %j = 0 to 100 {
affine.for %k = 0 to 100 {
call @foo() : () -> ()
affine.if affine_set<(d0, d1) : (d0 - 2 >= 0, -d1 + 80 >= 0)>(%i, %j) {
affine.load %A[%i] : memref<100xi32>
call @abc() : () -> ()
} else {
affine.load %A[%i + 1] : memref<100xi32>
call @xyz() : () -> ()
}
call @bar() : () -> ()
}
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: call @foo
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}]
// CHECK-NEXT: call @abc
// CHECK-NEXT: call @bar
// CHECK-NEXT: }
// CHECK-NEXT: else
// CHECK-NEXT: affine.for
// CHECK-NEXT: call @foo
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} + 1]
// CHECK-NEXT: call @xyz
// CHECK-NEXT: call @bar
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-LABEL: func @if_then_imperfect
func @if_then_imperfect(%A : memref<100xi32>, %N : index) {
affine.for %i = 0 to 100 {
affine.load %A[0] : memref<100xi32>
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%N) {
affine.load %A[%i] : memref<100xi32>
}
}
return
}
// CHECK: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: affine.load
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
// Check if unused operands are dropped: hence, hoisting is possible.
// CHECK-LABEL: func @hoist_after_canonicalize
func @hoist_after_canonicalize() {
affine.for %i = 0 to 100 {
affine.for %j = 0 to 100 {
affine.if affine_set<(d0) : (d0 - 2 >= 0)>(%j) {
affine.if affine_set<(d0, d1) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%i, %j) {
// The call to external is to avoid DCE on affine.if.
call @foo() : () -> ()
}
}
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.if
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-LABEL: func @handle_dead_if
func @handle_dead_if(%N : index) {
affine.for %i = 0 to 100 {
affine.if affine_set<(d0) : (d0 - 1 >= 0, -d0 + 99 >= 0)>(%N) {
}
}
return
}
// CHECK-NEXT: affine.for
// CHECK-NEXT: }
// CHECK-NEXT: return
// -----
// A test case with affine.parallel.
#flb1 = affine_map<(d0) -> (d0 * 3)>
#fub1 = affine_map<(d0) -> (d0 * 3 + 3)>
#flb0 = affine_map<(d0) -> (d0 * 16)>
#fub0 = affine_map<(d0) -> (d0 * 16 + 16)>
#pub1 = affine_map<(d0)[s0] -> (s0, d0 * 3 + 3)>
#pub0 = affine_map<(d0)[s0] -> (s0, d0 * 16 + 16)>
#lb1 = affine_map<(d0) -> (d0 * 480)>
#ub1 = affine_map<(d0)[s0] -> (s0, d0 * 480 + 480)>
#lb0 = affine_map<(d0) -> (d0 * 110)>
#ub0 = affine_map<(d0)[s0] -> (d0 * 110 + 110, s0 floordiv 3)>
#set0 = affine_set<(d0, d1)[s0, s1] : (d0 * -16 + s0 - 16 >= 0, d1 * -3 + s1 - 3 >= 0)>
// CHECK-LABEL: func @perfect_if_else
func @perfect_if_else(%arg0 : memref<?x?xf64>, %arg1 : memref<?x?xf64>, %arg4 : index,
%arg5 : index, %arg6 : index, %sym : index) {
affine.for %arg7 = #lb0(%arg5) to min #ub0(%arg5)[%sym] {
affine.parallel (%i0, %j0) = (0, 0) to (symbol(%sym), 100) step (10, 10) {
affine.for %arg8 = #lb1(%arg4) to min #ub1(%arg4)[%sym] {
affine.if #set0(%arg6, %arg7)[%sym, %sym] {
affine.for %arg9 = #flb0(%arg6) to #fub0(%arg6) {
affine.for %arg10 = #flb1(%arg7) to #fub1(%arg7) {
affine.load %arg0[0, 0] : memref<?x?xf64>
}
}
} else {
affine.for %arg9 = #lb0(%arg6) to min #pub0(%arg6)[%sym] {
affine.for %arg10 = #lb1(%arg7) to min #pub1(%arg7)[%sym] {
affine.load %arg0[0, 0] : memref<?x?xf64>
}
}
}
}
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.parallel
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: affine.parallel
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.load
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// With multiple if ops in a function, the test pass just looks for the first if
// op that it is able to successfully hoist.
// CHECK-LABEL: func @multiple_if
func @multiple_if(%N : index) {
affine.if affine_set<() : (0 == 0)>() {
call @external() : () -> ()
}
affine.for %i = 0 to 100 {
affine.if affine_set<()[s0] : (s0 >= 0)>()[%N] {
call @external() : () -> ()
}
}
return
}
// CHECK: affine.if
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: affine.if
// CHECK-NEXT: affine.for
// CHECK-NEXT: call
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
func @external()

View File

@@ -1,5 +1,6 @@
add_llvm_library(MLIRAffineTransformsTestPasses
TestAffineDataCopy.cpp
TestAffineLoopUnswitching.cpp
TestLoopPermutation.cpp
TestParallelismDetection.cpp
TestVectorizationUtils.cpp

View File

@@ -0,0 +1,60 @@
//===- TestAffineLoopUnswitching.cpp - Test affine if/else hoisting -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to hoist affine if/else structures.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#define PASS_NAME "test-affine-loop-unswitch"
using namespace mlir;
namespace {
/// This pass applies the permutation on the first maximal perfect nest.
struct TestAffineLoopUnswitching
: public PassWrapper<TestAffineLoopUnswitching, FunctionPass> {
TestAffineLoopUnswitching() = default;
TestAffineLoopUnswitching(const TestAffineLoopUnswitching &pass) {}
void runOnFunction() override;
/// The maximum number of iterations to run this for.
constexpr static unsigned kMaxIterations = 5;
};
} // end anonymous namespace
void TestAffineLoopUnswitching::runOnFunction() {
// Each hoisting invalidates a lot of IR around. Just stop the walk after the
// first if/else hoisting, and repeat until no more hoisting can be done, or
// the maximum number of iterations have been run.
auto func = getFunction();
unsigned i = 0;
do {
auto walkFn = [](AffineIfOp op) {
return succeeded(hoistAffineIfOp(op)) ? WalkResult::interrupt()
: WalkResult::advance();
};
if (func.walk(walkFn).wasInterrupted())
break;
} while (++i < kMaxIterations);
}
namespace mlir {
void registerTestAffineLoopUnswitchingPass() {
PassRegistration<TestAffineLoopUnswitching>(
PASS_NAME, "Tests affine loop unswitching / if/else hoisting");
}
} // namespace mlir

View File

@@ -40,6 +40,7 @@ void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAllReduceLoweringPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestLinalgMatmulToVectorPass();
void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
@@ -103,6 +104,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
registerTestAffineLoopUnswitchingPass();
registerTestLinalgMatmulToVectorPass();
registerTestLoopPermutationPass();
registerTestCallGraphPass();