[mlir][CSE] Add ability to remove commutative operations

This patch takes advantage of the Commutative trait on operation
to remove identical commutative operations where the operands are swapped.

The second operation below can be removed since `arith.addi` is commutative.
```
%1 = arith.addi %a, %b : i32
%2 = arith.addi %b, %a : i32
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123492
This commit is contained in:
Valentin Clement
2022-04-16 21:08:16 +02:00
parent bf59cd7244
commit bd514967aa
2 changed files with 39 additions and 3 deletions

View File

@@ -633,8 +633,18 @@ llvm::hash_code OperationEquivalence::computeHash(
op->getName(), op->getAttrDictionary(), op->getResultTypes());
// - Operands
for (Value operand : op->getOperands())
ValueRange operands = op->getOperands();
SmallVector<Value> operandStorage;
if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
operandStorage.append(operands.begin(), operands.end());
llvm::sort(operandStorage, [](Value a, Value b) -> bool {
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
});
operands = operandStorage;
}
for (Value operand : operands)
hash = llvm::hash_combine(hash, hashOperands(operand));
// - Operands
for (Value result : op->getResults())
hash = llvm::hash_combine(hash, hashResults(result));
@@ -710,6 +720,21 @@ bool OperationEquivalence::isEquivalentTo(
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
});
lhsOperands = lhsOperandStorage;
rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
});
rhsOperands = rhsOperandStorage;
}
auto checkValueRangeMapping =
[](ValueRange lhs, ValueRange rhs,
function_ref<LogicalResult(Value, Value)> mapValues) {
@@ -724,8 +749,7 @@ bool OperationEquivalence::isEquivalentTo(
return true;
};
// Check mapping of operands and results.
if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
mapOperands))
if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
return false;
if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
return false;

View File

@@ -310,3 +310,15 @@ func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
%2 = arith.addi %0, %1 : i32
return %2 : i32
}
/// This test is checking that identical commutative operation are gracefully
/// handled but the CSE pass.
// CHECK-LABEL: func @check_cummutative_cse
func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
// CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
%1 = arith.addi %a, %b : i32
%2 = arith.addi %b, %a : i32
// CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
%3 = arith.muli %1, %2 : i32
return %3 : i32
}