mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 02:00:03 +08:00
[mlir][CSE] Add ability to remove commutative operations
This patch takes advantage of the Commutative trait on operation to remove identical commutative operations where the operands are swapped. The second operation below can be removed since `arith.addi` is commutative. ``` %1 = arith.addi %a, %b : i32 %2 = arith.addi %b, %a : i32 ``` Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D123492
This commit is contained in:
@@ -633,8 +633,18 @@ llvm::hash_code OperationEquivalence::computeHash(
|
||||
op->getName(), op->getAttrDictionary(), op->getResultTypes());
|
||||
|
||||
// - Operands
|
||||
for (Value operand : op->getOperands())
|
||||
ValueRange operands = op->getOperands();
|
||||
SmallVector<Value> operandStorage;
|
||||
if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
|
||||
operandStorage.append(operands.begin(), operands.end());
|
||||
llvm::sort(operandStorage, [](Value a, Value b) -> bool {
|
||||
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
|
||||
});
|
||||
operands = operandStorage;
|
||||
}
|
||||
for (Value operand : operands)
|
||||
hash = llvm::hash_combine(hash, hashOperands(operand));
|
||||
|
||||
// - Operands
|
||||
for (Value result : op->getResults())
|
||||
hash = llvm::hash_combine(hash, hashResults(result));
|
||||
@@ -710,6 +720,21 @@ bool OperationEquivalence::isEquivalentTo(
|
||||
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
|
||||
return false;
|
||||
|
||||
ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
|
||||
SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
|
||||
if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
|
||||
lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
|
||||
llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
|
||||
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
|
||||
});
|
||||
lhsOperands = lhsOperandStorage;
|
||||
|
||||
rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
|
||||
llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
|
||||
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
|
||||
});
|
||||
rhsOperands = rhsOperandStorage;
|
||||
}
|
||||
auto checkValueRangeMapping =
|
||||
[](ValueRange lhs, ValueRange rhs,
|
||||
function_ref<LogicalResult(Value, Value)> mapValues) {
|
||||
@@ -724,8 +749,7 @@ bool OperationEquivalence::isEquivalentTo(
|
||||
return true;
|
||||
};
|
||||
// Check mapping of operands and results.
|
||||
if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
|
||||
mapOperands))
|
||||
if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
|
||||
return false;
|
||||
if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
|
||||
return false;
|
||||
|
||||
@@ -310,3 +310,15 @@ func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
|
||||
%2 = arith.addi %0, %1 : i32
|
||||
return %2 : i32
|
||||
}
|
||||
|
||||
/// This test is checking that identical commutative operation are gracefully
|
||||
/// handled but the CSE pass.
|
||||
// CHECK-LABEL: func @check_cummutative_cse
|
||||
func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
|
||||
// CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
|
||||
%1 = arith.addi %a, %b : i32
|
||||
%2 = arith.addi %b, %a : i32
|
||||
// CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
|
||||
%3 = arith.muli %1, %2 : i32
|
||||
return %3 : i32
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user