mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 08:56:15 +08:00
[MLIR][Math] Add lowering for isnan and isfinite (#128125)
Co-authored-by: Ivan R. Ivanov <ivanov.i.aa@m.titech.ac.jp>
This commit is contained in:
@@ -18,6 +18,8 @@
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/FloatingPointMode.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
@@ -286,6 +288,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
|
||||
using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto operandType = adaptor.getOperand().getType();
|
||||
|
||||
if (!operandType || !LLVM::isCompatibleType(operandType))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
|
||||
op, op.getType(), adaptor.getOperand(), llvm::fcNan);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
|
||||
using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto operandType = adaptor.getOperand().getType();
|
||||
|
||||
if (!operandType || !LLVM::isCompatibleType(operandType))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
|
||||
op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertMathToLLVMPass
|
||||
: public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
|
||||
using Base::Base;
|
||||
@@ -309,6 +345,8 @@ void mlir::populateMathToLLVMConversionPatterns(
|
||||
patterns.add<Log1pOpLowering>(converter, benefit);
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
IsNaNOpLowering,
|
||||
IsFiniteOpLowering,
|
||||
AbsFOpLowering,
|
||||
AbsIOpLowering,
|
||||
CeilOpLowering,
|
||||
|
||||
@@ -263,6 +263,26 @@ func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @isnan_double(
|
||||
// CHECK-SAME: f64
|
||||
func.func @isnan_double(%arg0 : f64) {
|
||||
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 3 : i32}> : (f64) -> i1
|
||||
%0 = math.isnan %arg0 : f64
|
||||
func.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @isfinite_double(
|
||||
// CHECK-SAME: f64
|
||||
func.func @isfinite_double(%arg0 : f64) {
|
||||
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 504 : i32}> : (f64) -> i1
|
||||
%0 = math.isfinite %arg0 : f64
|
||||
func.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rsqrt_double(
|
||||
// CHECK-SAME: f64
|
||||
func.func @rsqrt_double(%arg0 : f64) {
|
||||
|
||||
Reference in New Issue
Block a user