[MLIR][Math] Add lowering for isnan and isfinite (#128125)

Co-authored-by: Ivan R. Ivanov <ivanov.i.aa@m.titech.ac.jp>
This commit is contained in:
William Moses
2025-02-20 23:40:35 -06:00
committed by GitHub
parent cc675c635b
commit f0134e6d31
2 changed files with 58 additions and 0 deletions

View File

@@ -18,6 +18,8 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/FloatingPointMode.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -286,6 +288,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
}
};
struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operandType = adaptor.getOperand().getType();
if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
op, op.getType(), adaptor.getOperand(), llvm::fcNan);
return success();
}
};
struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operandType = adaptor.getOperand().getType();
if (!operandType || !LLVM::isCompatibleType(operandType))
return failure();
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
return success();
}
};
struct ConvertMathToLLVMPass
: public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
using Base::Base;
@@ -309,6 +345,8 @@ void mlir::populateMathToLLVMConversionPatterns(
patterns.add<Log1pOpLowering>(converter, benefit);
// clang-format off
patterns.add<
IsNaNOpLowering,
IsFiniteOpLowering,
AbsFOpLowering,
AbsIOpLowering,
CeilOpLowering,

View File

@@ -263,6 +263,26 @@ func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
// -----
// CHECK-LABEL: func @isnan_double(
// CHECK-SAME: f64
func.func @isnan_double(%arg0 : f64) {
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 3 : i32}> : (f64) -> i1
%0 = math.isnan %arg0 : f64
func.return
}
// -----
// CHECK-LABEL: func @isfinite_double(
// CHECK-SAME: f64
func.func @isfinite_double(%arg0 : f64) {
// CHECK: "llvm.intr.is.fpclass"(%arg0) <{bit = 504 : i32}> : (f64) -> i1
%0 = math.isfinite %arg0 : f64
func.return
}
// -----
// CHECK-LABEL: func @rsqrt_double(
// CHECK-SAME: f64
func.func @rsqrt_double(%arg0 : f64) {