diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 174b7ce9fed2..bf753c7062f3 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -499,13 +499,16 @@ struct LogOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value abs = b.create(elementType, adaptor.getComplex()); - Value resultReal = b.create(elementType, abs); + Value abs = b.create(elementType, adaptor.getComplex(), + fmf.getValue()); + Value resultReal = b.create(elementType, abs, fmf.getValue()); Value real = b.create(elementType, adaptor.getComplex()); Value imag = b.create(elementType, adaptor.getComplex()); - Value resultImag = b.create(elementType, imag, real); + Value resultImag = + b.create(elementType, imag, real, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); @@ -520,6 +523,7 @@ struct Log1pOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value real = b.create(elementType, adaptor.getComplex()); @@ -535,15 +539,21 @@ struct Log1pOpConversion : public OpConversionPattern { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1) // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1) - Value sumSq = b.create(real, real); - sumSq = b.create(sumSq, b.create(real, two)); - sumSq = b.create(sumSq, b.create(imag, imag)); - Value logSumSq = b.create(elementType, sumSq); - Value resultReal = b.create(logSumSq, half); + Value sumSq = b.create(real, real, fmf.getValue()); + sumSq = b.create( + sumSq, b.create(real, two, fmf.getValue()), + fmf.getValue()); + sumSq = b.create( + sumSq, b.create(imag, imag, fmf.getValue()), + fmf.getValue()); + Value logSumSq = + b.create(elementType, sumSq, fmf.getValue()); + Value resultReal = b.create(logSumSq, half, fmf.getValue()); - Value realPlusOne = b.create(real, one); + Value realPlusOne = b.create(real, one, fmf.getValue()); - Value resultImag = b.create(elementType, imag, realPlusOne); + Value resultImag = + b.create(elementType, imag, realPlusOne, fmf.getValue()); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 8264382a0265..3af28150fd5c 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -797,4 +797,51 @@ func.func @complex_expm1_with_fmf(%arg: complex) -> complex { // CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath : f32 // CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex -// CHECK: return %[[RES]] : complex \ No newline at end of file +// CHECK: return %[[RES]] : complex + +// ----- + +// CHECK-LABEL: func @complex_log_with_fmf +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_log_with_fmf(%arg: complex) -> complex { + %log = complex.log %arg fastmath : complex + return %log : complex +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 +// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath : f32 +// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] fastmath : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @complex_log1p_with_fmf +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_log1p_with_fmf(%arg: complex) -> complex { + %log1p = complex.log1p %arg fastmath : complex + return %log1p : complex +} + +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 +// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath : f32 +// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath : f32 +// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath : f32 +// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath : f32 +// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath : f32 +// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath : f32 +// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex