mirror of
https://github.com/intel/llvm.git
synced 2026-02-05 22:17:23 +08:00
[MLIR] Add SCF.if Condition Canonicalizations
Add two canoncalizations for scf.if.
1) A canonicalization that allows users of a condition within an if to assume the condition
is true if in the true region, etc.
2) A canonicalization that removes yielded statements that are equivalent to the condition
or its negation
Differential Revision: https://reviews.llvm.org/D101012
This commit is contained in:
@@ -1106,12 +1106,172 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Allow the true region of an if to assume the condition is true
|
||||
// and vice versa. For example:
|
||||
//
|
||||
// scf.if %cmp {
|
||||
// print(%cmp)
|
||||
// }
|
||||
//
|
||||
// becomes
|
||||
//
|
||||
// scf.if %cmp {
|
||||
// print(true)
|
||||
// }
|
||||
//
|
||||
struct ConditionPropagation : public OpRewritePattern<IfOp> {
|
||||
using OpRewritePattern<IfOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Early exit if the condition is constant since replacing a constant
|
||||
// in the body with another constant isn't a simplification.
|
||||
if (op.condition().getDefiningOp<ConstantOp>())
|
||||
return failure();
|
||||
|
||||
bool changed = false;
|
||||
mlir::Type i1Ty = rewriter.getI1Type();
|
||||
|
||||
// These variables serve to prevent creating duplicate constants
|
||||
// and hold constant true or false values.
|
||||
Value constantTrue = nullptr;
|
||||
Value constantFalse = nullptr;
|
||||
|
||||
for (OpOperand &use :
|
||||
llvm::make_early_inc_range(op.condition().getUses())) {
|
||||
if (op.thenRegion().isAncestor(use.getOwner()->getParentRegion())) {
|
||||
changed = true;
|
||||
|
||||
if (!constantTrue)
|
||||
constantTrue = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&]() { use.set(constantTrue); });
|
||||
} else if (op.elseRegion().isAncestor(
|
||||
use.getOwner()->getParentRegion())) {
|
||||
changed = true;
|
||||
|
||||
if (!constantFalse)
|
||||
constantFalse = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&]() { use.set(constantFalse); });
|
||||
}
|
||||
}
|
||||
|
||||
return success(changed);
|
||||
}
|
||||
};
|
||||
|
||||
/// Remove any statements from an if that are equivalent to the condition
|
||||
/// or its negation. For example:
|
||||
///
|
||||
/// %res:2 = scf.if %cmp {
|
||||
/// yield something(), true
|
||||
/// } else {
|
||||
/// yield something2(), false
|
||||
/// }
|
||||
/// print(%res#1)
|
||||
///
|
||||
/// becomes
|
||||
/// %res = scf.if %cmp {
|
||||
/// yield something()
|
||||
/// } else {
|
||||
/// yield something2()
|
||||
/// }
|
||||
/// print(%cmp)
|
||||
///
|
||||
/// Additionally if both branches yield the same value, replace all uses
|
||||
/// of the result with the yielded value
|
||||
///
|
||||
/// %res:2 = scf.if %cmp {
|
||||
/// yield something(), %arg1
|
||||
/// } else {
|
||||
/// yield something2(), %arg1
|
||||
/// }
|
||||
/// print(%res#1)
|
||||
///
|
||||
/// becomes
|
||||
/// %res = scf.if %cmp {
|
||||
/// yield something()
|
||||
/// } else {
|
||||
/// yield something2()
|
||||
/// }
|
||||
// print(%arg1)
|
||||
struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
|
||||
using OpRewritePattern<IfOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Early exit if there are no results that could be replaced.
|
||||
if (op.getNumResults() == 0)
|
||||
return failure();
|
||||
|
||||
auto trueYield = cast<scf::YieldOp>(op.thenRegion().back().getTerminator());
|
||||
auto falseYield =
|
||||
cast<scf::YieldOp>(op.elseRegion().back().getTerminator());
|
||||
|
||||
rewriter.setInsertionPoint(op->getBlock(),
|
||||
op.getOperation()->getIterator());
|
||||
bool changed = false;
|
||||
Type i1Ty = rewriter.getI1Type();
|
||||
for (auto tup :
|
||||
llvm::zip(trueYield.results(), falseYield.results(), op.results())) {
|
||||
Value trueResult, falseResult, opResult;
|
||||
std::tie(trueResult, falseResult, opResult) = tup;
|
||||
|
||||
if (trueResult == falseResult) {
|
||||
if (!opResult.use_empty()) {
|
||||
opResult.replaceAllUsesWith(trueResult);
|
||||
changed = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto trueYield = trueResult.getDefiningOp<ConstantOp>();
|
||||
if (!trueYield)
|
||||
continue;
|
||||
|
||||
if (!trueYield.getType().isInteger(1))
|
||||
continue;
|
||||
|
||||
auto falseYield = falseResult.getDefiningOp<ConstantOp>();
|
||||
if (!falseYield)
|
||||
continue;
|
||||
|
||||
bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
|
||||
bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
|
||||
if (!trueVal && falseVal) {
|
||||
if (!opResult.use_empty()) {
|
||||
Value notCond = rewriter.create<XOrOp>(
|
||||
op.getLoc(), op.condition(),
|
||||
rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
|
||||
opResult.replaceAllUsesWith(notCond);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (trueVal && !falseVal) {
|
||||
if (!opResult.use_empty()) {
|
||||
opResult.replaceAllUsesWith(op.condition());
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(changed);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<RemoveUnusedResults, RemoveStaticCondition,
|
||||
ConvertTrivialIfToSelect>(context);
|
||||
results
|
||||
.add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
|
||||
ConditionPropagation, ReplaceIfYieldWithConditionOrValue>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -103,22 +103,25 @@ func private @side_effect()
|
||||
func @one_unused(%cond: i1) -> (index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
%0, %1 = scf.if %cond -> (index, index) {
|
||||
call @side_effect() : () -> ()
|
||||
scf.yield %c0, %c1 : index, index
|
||||
} else {
|
||||
scf.yield %c0, %c1 : index, index
|
||||
scf.yield %c2, %c3 : index, index
|
||||
}
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @one_unused
|
||||
// CHECK: [[C0:%.*]] = constant 1 : index
|
||||
// CHECK: [[C3:%.*]] = constant 3 : index
|
||||
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
|
||||
// CHECK: call @side_effect() : () -> ()
|
||||
// CHECK: scf.yield [[C0]] : index
|
||||
// CHECK: } else
|
||||
// CHECK: scf.yield [[C0]] : index
|
||||
// CHECK: scf.yield [[C3]] : index
|
||||
// CHECK: }
|
||||
// CHECK: return [[V0]] : index
|
||||
|
||||
@@ -128,12 +131,14 @@ func private @side_effect()
|
||||
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
%0, %1 = scf.if %cond1 -> (index, index) {
|
||||
%2, %3 = scf.if %cond2 -> (index, index) {
|
||||
call @side_effect() : () -> ()
|
||||
scf.yield %c0, %c1 : index, index
|
||||
} else {
|
||||
scf.yield %c0, %c1 : index, index
|
||||
scf.yield %c2, %c3 : index, index
|
||||
}
|
||||
scf.yield %2, %3 : index, index
|
||||
} else {
|
||||
@@ -144,12 +149,13 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
|
||||
|
||||
// CHECK-LABEL: func @nested_unused
|
||||
// CHECK: [[C0:%.*]] = constant 1 : index
|
||||
// CHECK: [[C3:%.*]] = constant 3 : index
|
||||
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
|
||||
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
|
||||
// CHECK: call @side_effect() : () -> ()
|
||||
// CHECK: scf.yield [[C0]] : index
|
||||
// CHECK: } else
|
||||
// CHECK: scf.yield [[C0]] : index
|
||||
// CHECK: scf.yield [[C3]] : index
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield [[V1]] : index
|
||||
// CHECK: } else
|
||||
@@ -610,3 +616,111 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
|
||||
%res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
|
||||
return %res : tensor<1024x1024xf32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CHECK-LABEL: @cond_prop
|
||||
func @cond_prop(%arg0 : i1) -> index {
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
%c4 = constant 4 : index
|
||||
%res = scf.if %arg0 -> index {
|
||||
%res1 = scf.if %arg0 -> index {
|
||||
%v1 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %c1 : index
|
||||
} else {
|
||||
%v2 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %c2 : index
|
||||
}
|
||||
scf.yield %res1 : index
|
||||
} else {
|
||||
%res2 = scf.if %arg0 -> index {
|
||||
%v3 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %c3 : index
|
||||
} else {
|
||||
%v4 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %c4 : index
|
||||
}
|
||||
scf.yield %res2 : index
|
||||
}
|
||||
return %res : index
|
||||
}
|
||||
// CHECK-DAG: %[[c1:.+]] = constant 1 : index
|
||||
// CHECK-DAG: %[[c4:.+]] = constant 4 : index
|
||||
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
|
||||
// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[c1]] : index
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[c4]] : index
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[if]] : index
|
||||
// CHECK-NEXT:}
|
||||
|
||||
// CHECK-LABEL: @replace_if_with_cond1
|
||||
func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
|
||||
%true = constant true
|
||||
%false = constant false
|
||||
%res:2 = scf.if %arg0 -> (i32, i1) {
|
||||
%v = "test.get_some_value"() : () -> i32
|
||||
scf.yield %v, %true : i32, i1
|
||||
} else {
|
||||
%v2 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %v2, %false : i32, i1
|
||||
}
|
||||
return %res#0, %res#1 : i32, i1
|
||||
}
|
||||
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
|
||||
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[sv1]] : i32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[sv2]] : i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[if]], %arg0 : i32, i1
|
||||
|
||||
// CHECK-LABEL: @replace_if_with_cond2
|
||||
func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
|
||||
%true = constant true
|
||||
%false = constant false
|
||||
%res:2 = scf.if %arg0 -> (i32, i1) {
|
||||
%v = "test.get_some_value"() : () -> i32
|
||||
scf.yield %v, %false : i32, i1
|
||||
} else {
|
||||
%v2 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %v2, %true : i32, i1
|
||||
}
|
||||
return %res#0, %res#1 : i32, i1
|
||||
}
|
||||
// CHECK-NEXT: %true = constant true
|
||||
// CHECK-NEXT: %[[toret:.+]] = xor %arg0, %true : i1
|
||||
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
|
||||
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[sv1]] : i32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[sv2]] : i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
|
||||
|
||||
|
||||
// CHECK-LABEL: @replace_if_with_cond3
|
||||
func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
|
||||
%res:2 = scf.if %arg0 -> (i32, i64) {
|
||||
%v = "test.get_some_value"() : () -> i32
|
||||
scf.yield %v, %arg2 : i32, i64
|
||||
} else {
|
||||
%v2 = "test.get_some_value"() : () -> i32
|
||||
scf.yield %v2, %arg2 : i32, i64
|
||||
}
|
||||
return %res#0, %res#1 : i32, i64
|
||||
}
|
||||
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
|
||||
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[sv1]] : i32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
|
||||
// CHECK-NEXT: scf.yield %[[sv2]] : i32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
|
||||
|
||||
@@ -1198,11 +1198,12 @@ func @clone_loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @clone_nested_region
|
||||
func @clone_nested_region(%arg0: index, %arg1: index) -> memref<?x?xf32> {
|
||||
func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memref<?x?xf32> {
|
||||
%cmp = cmpi eq, %arg0, %arg1 : index
|
||||
%0 = cmpi eq, %arg0, %arg1 : index
|
||||
%1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
|
||||
%2 = scf.if %0 -> (memref<?x?xf32>) {
|
||||
%3 = scf.if %0 -> (memref<?x?xf32>) {
|
||||
%3 = scf.if %cmp -> (memref<?x?xf32>) {
|
||||
%9 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
|
||||
scf.yield %9 : memref<?x?xf32>
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user