mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[MLIR][SCF] Canonicalize redundant scf.if from scf.while before region into after region (#169892)
When a `scf.if` directly precedes a `scf.condition` in the before region of a `scf.while` and both share the same condition, move the if into the after region of the loop. This helps simplify the control flow to enable uplifting `scf.while` to `scf.for`.
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
@@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() {
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Move a scf.if op that is directly before the scf.condition op in the while
|
||||
/// before region, and whose condition matches the condition of the
|
||||
/// scf.condition op, down into the while after region.
|
||||
///
|
||||
/// scf.while (..) : (...) -> ... {
|
||||
/// %additional_used_values = ...
|
||||
/// %cond = ...
|
||||
/// ...
|
||||
/// %res = scf.if %cond -> (...) {
|
||||
/// use(%additional_used_values)
|
||||
/// ... // then block
|
||||
/// scf.yield %then_value
|
||||
/// } else {
|
||||
/// scf.yield %else_value
|
||||
/// }
|
||||
/// scf.condition(%cond) %res, ...
|
||||
/// } do {
|
||||
/// ^bb0(%res_arg, ...):
|
||||
/// use(%res_arg)
|
||||
/// ...
|
||||
///
|
||||
/// becomes
|
||||
/// scf.while (..) : (...) -> ... {
|
||||
/// %additional_used_values = ...
|
||||
/// %cond = ...
|
||||
/// ...
|
||||
/// scf.condition(%cond) %else_value, ..., %additional_used_values
|
||||
/// } do {
|
||||
/// ^bb0(%res_arg ..., %additional_args): :
|
||||
/// use(%additional_args)
|
||||
/// ... // if then block
|
||||
/// use(%then_value)
|
||||
/// ...
|
||||
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
|
||||
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::WhileOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto conditionOp = op.getConditionOp();
|
||||
|
||||
// Only support ifOp right before the condition at the moment. Relaxing this
|
||||
// would require to:
|
||||
// - check that the body does not have side-effects conflicting with
|
||||
// operations between the if and the condition.
|
||||
// - check that results of the if operation are only used as arguments to
|
||||
// the condition.
|
||||
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
|
||||
|
||||
// Check that the ifOp is directly before the conditionOp and that it
|
||||
// matches the condition of the conditionOp. Also ensure that the ifOp has
|
||||
// no else block with content, as that would complicate the transformation.
|
||||
// TODO: support else blocks with content.
|
||||
if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
|
||||
(ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
|
||||
return failure();
|
||||
|
||||
assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
|
||||
*ifOp->user_begin() == conditionOp) &&
|
||||
"ifOp has unexpected uses");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Replace uses of ifOp results in the conditionOp with the yielded values
|
||||
// from the ifOp branches.
|
||||
for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
|
||||
auto it = llvm::find(ifOp->getResults(), arg);
|
||||
if (it != ifOp->getResults().end()) {
|
||||
size_t ifOpIdx = it.getIndex();
|
||||
Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
|
||||
Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
|
||||
|
||||
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
|
||||
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect additional used values from before region.
|
||||
SetVector<Value> additionalUsedValuesSet;
|
||||
visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
|
||||
if (&op.getBefore() == operand->get().getParentRegion())
|
||||
additionalUsedValuesSet.insert(operand->get());
|
||||
});
|
||||
|
||||
// Create new whileOp with additional used values as results.
|
||||
auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
|
||||
auto additionalValueTypes = llvm::map_to_vector(
|
||||
additionalUsedValues, [](Value val) { return val.getType(); });
|
||||
size_t additionalValueSize = additionalUsedValues.size();
|
||||
SmallVector<Type> newResultTypes(op.getResultTypes());
|
||||
newResultTypes.append(additionalValueTypes);
|
||||
|
||||
auto newWhileOp =
|
||||
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
|
||||
|
||||
rewriter.modifyOpInPlace(newWhileOp, [&] {
|
||||
newWhileOp.getBefore().takeBody(op.getBefore());
|
||||
newWhileOp.getAfter().takeBody(op.getAfter());
|
||||
newWhileOp.getAfter().addArguments(
|
||||
additionalValueTypes,
|
||||
SmallVector<Location>(additionalValueSize, loc));
|
||||
});
|
||||
|
||||
rewriter.modifyOpInPlace(conditionOp, [&] {
|
||||
conditionOp.getArgsMutable().append(additionalUsedValues);
|
||||
});
|
||||
|
||||
// Replace uses of additional used values inside the ifOp then region with
|
||||
// the whileOp after region arguments.
|
||||
rewriter.replaceUsesWithIf(
|
||||
additionalUsedValues,
|
||||
newWhileOp.getAfterArguments().take_back(additionalValueSize),
|
||||
[&](OpOperand &use) {
|
||||
return ifOp.getThenRegion().isAncestor(
|
||||
use.getOwner()->getParentRegion());
|
||||
});
|
||||
|
||||
// Inline ifOp then region into new whileOp after region.
|
||||
rewriter.eraseOp(ifOp.thenYield());
|
||||
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
|
||||
newWhileOp.getAfterBody()->begin());
|
||||
rewriter.eraseOp(ifOp);
|
||||
rewriter.replaceOp(op,
|
||||
newWhileOp->getResults().drop_back(additionalValueSize));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace uses of the condition within the do block with true, since otherwise
|
||||
/// the block would not be evaluated.
|
||||
///
|
||||
@@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
|
||||
RemoveLoopInvariantValueYielded, WhileConditionTruth,
|
||||
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
|
||||
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
|
||||
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_move_if_down
|
||||
func.func @while_move_if_down() -> i32 {
|
||||
%defined_outside = "test.get_some_value0" () : () -> (i32)
|
||||
%0 = scf.while () : () -> (i32) {
|
||||
%used_value = "test.get_some_value1" () : () -> (i32)
|
||||
%used_by_subregion = "test.get_some_value2" () : () -> (i32)
|
||||
%else_value = "test.get_some_value3" () : () -> (i32)
|
||||
%condition = "test.condition"() : () -> i1
|
||||
%res = scf.if %condition -> (i32) {
|
||||
"test.use0" (%defined_outside) : (i32) -> ()
|
||||
"test.use1" (%used_value) : (i32) -> ()
|
||||
test.alloca_scope_region {
|
||||
"test.use2" (%used_by_subregion) : (i32) -> ()
|
||||
}
|
||||
%then_value = "test.get_some_value4" () : () -> (i32)
|
||||
scf.yield %then_value : i32
|
||||
} else {
|
||||
scf.yield %else_value : i32
|
||||
}
|
||||
scf.condition(%condition) %res : i32
|
||||
} do {
|
||||
^bb0(%res_arg: i32):
|
||||
"test.use3" (%res_arg) : (i32) -> ()
|
||||
scf.yield
|
||||
}
|
||||
return %0 : i32
|
||||
}
|
||||
// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32
|
||||
// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) {
|
||||
// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32
|
||||
// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32
|
||||
// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32
|
||||
// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1
|
||||
// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32
|
||||
// CHECK: } do {
|
||||
// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32):
|
||||
// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> ()
|
||||
// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> ()
|
||||
// CHECK: test.alloca_scope_region {
|
||||
// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> ()
|
||||
// CHECK: }
|
||||
// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32
|
||||
// CHECK: "test.use3"(%[[then_value]]) : (i32) -> ()
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
// CHECK: return %[[WHILE_RES]]#0 : i32
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_cond_true
|
||||
func.func @while_cond_true() -> i1 {
|
||||
%0 = scf.while () : () -> i1 {
|
||||
|
||||
Reference in New Issue
Block a user