[mlir][spirv] Add folder for LogicalNotEqual

Add a folder for LogicalNotEqual when rhs is false. This pattern shows
up after lowering to SPIRV.

Differential Revision: https://reviews.llvm.org/D141163
This commit is contained in:
Thomas Raoux
2023-01-06 23:03:12 +00:00
parent a344c9073c
commit 493459b6dd
3 changed files with 34 additions and 0 deletions

View File

@@ -723,6 +723,7 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
```
}];
let hasFolder = true;
}
// -----

View File

@@ -251,6 +251,23 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
return Attribute();
}
//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//
OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 &&
"spirv.LogicalNotEqual should take two operands");
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
// x && false = x
if (!rhs.value())
return getOperand1();
}
return Attribute();
}
//===----------------------------------------------------------------------===//
// spirv.LogicalNot
//===----------------------------------------------------------------------===//

View File

@@ -470,6 +470,22 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
spirv.ReturnValue %3 : vector<3xi1>
}
// -----
//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqual
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @convert_logical_not_equal_false
// CHECK-SAME: %[[ARG:.+]]: vector<4xi1>
func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
%cst = spirv.Constant dense<false> : vector<4xi1>
// CHECK: spirv.ReturnValue %[[ARG]] : vector<4xi1>
%0 = spirv.LogicalNotEqual %arg, %cst : vector<4xi1>
spirv.ReturnValue %0 : vector<4xi1>
}
// -----
func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {