[mlir] Added ctlz and cttz to math dialect and LLVM dialect

Count leading/trailing zeros are an existing LLVM intrinsic. Added LLVM
support for the intrinsics with lowerings from the math dialect to LLVM
dialect.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D115206
This commit is contained in:
Rob Suderman
2021-12-08 14:32:09 -08:00
parent 7d62b68abc
commit 23149d522b
5 changed files with 159 additions and 0 deletions

View File

@@ -1401,6 +1401,12 @@ class LLVM_TernarySameArgsIntrinsicOp<string func, list<OpTrait> traits = []> :
let arguments = (ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c);
}
class LLVM_CountZerosIntrinsicOp<string func, list<OpTrait> traits = []> :
LLVM_OneResultIntrOp<func, [], [0],
!listconcat([NoSideEffect], traits)> {
let arguments = (ins LLVM_Type:$in, I<1>:$zero_undefined);
}
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
@@ -1421,6 +1427,8 @@ def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrinsicOp<"ctlz">;
def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrinsicOp<"cttz">;
def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">;
def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">;

View File

@@ -297,6 +297,54 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
}];
}
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp
//===----------------------------------------------------------------------===//
def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
let summary = "counts the leading zeros an integer value";
let description = [{
The `ctlz` operation computes the number of leading zeros of an integer value.
Example:
```mlir
// Scalar ctlz function value.
%a = math.ctlz %b : i32
// SIMD vector element-wise ctlz function value.
%f = math.ctlz %g : vector<4xi16>
// Tensor element-wise ctlz function value.
%x = math.ctlz %y : tensor<4x?xi8>
```
}];
}
//===----------------------------------------------------------------------===//
// CountTrailingZerosOp
//===----------------------------------------------------------------------===//
def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
let summary = "counts the trailing zeros an integer value";
let description = [{
The `cttz` operation computes the number of trailing zeros of an integer value.
Example:
```mlir
// Scalar cttz function value.
%a = math.cttz %b : i32
// SIMD vector element-wise cttz function value.
%f = math.cttz %g : vector<4xi16>
// Tensor element-wise cttz function value.
%x = math.cttz %y : tensor<4x?xi8>
```
}];
}
//===----------------------------------------------------------------------===//
// CtPopOp
//===----------------------------------------------------------------------===//

View File

@@ -38,6 +38,54 @@ using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
// A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
template <typename MathOp, typename LLVMOp>
struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
using Super = CountOpLowering<MathOp, LLVMOp>;
LogicalResult
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operandType = adaptor.getOperand().getType();
if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
auto boolType = rewriter.getIntegerType(1);
auto boolZero = rewriter.getIntegerAttr(boolType, 0);
if (!operandType.template isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp zero =
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
zero);
return success();
}
auto vectorType = resultType.template dyn_cast<VectorType>();
if (!vectorType)
return failure();
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
LLVM::ConstantOp zero =
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy,
operands[0], zero);
},
rewriter);
}
};
using CountLeadingZerosOpLowering =
CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
using CountTrailingZerosOpLowering =
CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
// A `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -222,6 +270,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
CeilOpLowering,
CopySignOpLowering,
CosOpLowering,
CountLeadingZerosOpLowering,
CountTrailingZerosOpLowering,
CtPopFOpLowering,
ExpOpLowering,
Exp2OpLowering,

View File

@@ -74,6 +74,39 @@ func @sine(%arg0 : f32) {
// -----
// CHECK-LABEL: func @ctlz(
// CHECK-SAME: i32
func @ctlz(%arg0 : i32) {
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
// CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
%0 = math.ctlz %arg0 : i32
std.return
}
// -----
// CHECK-LABEL: func @cttz(
// CHECK-SAME: i32
func @cttz(%arg0 : i32) {
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
// CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
%0 = math.cttz %arg0 : i32
std.return
}
// -----
// CHECK-LABEL: func @cttz_vec(
// CHECK-SAME: i32
func @cttz_vec(%arg0 : vector<4xi32>) {
// CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
// CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
%0 = math.cttz %arg0 : vector<4xi32>
std.return
}
// -----
// CHECK-LABEL: func @ctpop(
// CHECK-SAME: i32
func @ctpop(%arg0 : i32) {

View File

@@ -135,6 +135,26 @@ llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) {
llvm.return
}
// CHECK-LABEL: @ctlz_test
llvm.func @ctlz_test(%arg0: i32, %arg1: vector<8xi32>) {
%i1 = llvm.mlir.constant(false) : i1
// CHECK: call i32 @llvm.ctlz.i32
"llvm.intr.ctlz"(%arg0, %i1) : (i32, i1) -> i32
// CHECK: call <8 x i32> @llvm.ctlz.v8i32
"llvm.intr.ctlz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
llvm.return
}
// CHECK-LABEL: @cttz_test
llvm.func @cttz_test(%arg0: i32, %arg1: vector<8xi32>) {
%i1 = llvm.mlir.constant(false) : i1
// CHECK: call i32 @llvm.cttz.i32
"llvm.intr.cttz"(%arg0, %i1) : (i32, i1) -> i32
// CHECK: call <8 x i32> @llvm.cttz.v8i32
"llvm.intr.cttz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
llvm.return
}
// CHECK-LABEL: @ctpop_test
llvm.func @ctpop_test(%arg0: i32, %arg1: vector<8xi32>) {
// CHECK: call i32 @llvm.ctpop.i32