[MLIR] Add new complex.powi op (#158722)

This PR adds a new complex.powi operation to MLIR's complex dialect for
computing complex numbers raised to integer powers.

Key changes include:

- Addition of the new `PowiOp` operation definition in the Complex
dialect
- Integration with algebraic simplification passes for optimization
- Support for conversion to ROCDL library calls
- Updates to Flang frontend to generate the new operation

This depends on #158642.
This commit is contained in:
Akash Banerjee
2025-09-19 02:36:13 +01:00
committed by GitHub
parent 1ad5d63e5e
commit fdb1f48638
16 changed files with 270 additions and 120 deletions

View File

@@ -1323,26 +1323,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
mlir::Value exp = args[1];
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
exp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
}
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1668,11 +1648,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
genComplexPow},
genMathOp<mlir::complex::PowOp>},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
genComplexPow},
genMathOp<mlir::complex::PowOp>},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
genComplexPow},
genMathOp<mlir::complex::PowOp>},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
@@ -1693,20 +1673,20 @@ static constexpr MathOperation mathOperations[] = {
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(zpowi),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(zpowk),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),

View File

@@ -47,39 +47,19 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
return func;
}
static bool isZero(Value v) {
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
return attr.getValue().isZero();
return false;
}
void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
mod.walk([&](complex::PowOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();
Value base = op.getLhs();
Value rhs = op.getRhs();
Value intExp;
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
if (isZero(create.getImaginary())) {
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
intExp = conv.getValue();
}
}
}
func::FuncOp callee;
SmallVector<Value> args;
if (intExp) {
mod.walk([&](Operation *op) {
if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
builder.setInsertionPoint(powIop);
Location loc = powIop.getLoc();
auto complexTy = cast<ComplexType>(powIop.getType());
auto elemTy = complexTy.getElementType();
Value base = powIop.getLhs();
Value intExp = powIop.getRhs();
func::FuncOp callee;
unsigned realBits = cast<FloatType>(elemTy).getWidth();
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
auto funcTy = builder.getFunctionType(
@@ -98,9 +78,20 @@ void ConvertComplexPowPass::runOnOperation() {
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
else
return;
args = {base, intExp};
} else {
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
if (auto fmf = powIop.getFastmathAttr())
call.setFastmathAttr(fmf);
powIop.replaceAllUsesWith(call.getResult(0));
powIop.erase();
}
if (auto powOp = dyn_cast<complex::PowOp>(op)) {
builder.setInsertionPoint(powOp);
Location loc = powOp.getLoc();
auto complexTy = cast<ComplexType>(powOp.getType());
auto elemTy = complexTy.getElementType();
unsigned realBits = cast<FloatType>(elemTy).getWidth();
func::FuncOp callee;
auto funcTy =
builder.getFunctionType({complexTy, complexTy}, {complexTy});
if (realBits == 32)
@@ -111,13 +102,12 @@ void ConvertComplexPowPass::runOnOperation() {
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
else
return;
args = {base, rhs};
auto call = fir::CallOp::create(builder, loc, callee,
{powOp.getLhs(), powOp.getRhs()});
if (auto fmf = powOp.getFastmathAttr())
call.setFastmathAttr(fmf);
powOp.replaceAllUsesWith(call.getResult(0));
powOp.erase();
}
auto call = fir::CallOp::create(builder, loc, callee, args);
if (auto fmf = op.getFastmathAttr())
call.setFastmathAttr(fmf);
op.replaceAllUsesWith(call.getResult(0));
op.erase();
});
}

View File

@@ -193,7 +193,7 @@ end subroutine
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
! CHECK: %[[VAL_8:.*]] = complex.pow
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32
subroutine extremum(c, n, l)
integer(8), intent(in) :: l

View File

@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b

View File

@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b

View File

@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
end subroutine pow_test
! CHECK-LABEL: func @_QPpowi_test(
! CHECK: complex.powi
! CHECK-NOT: fir.call @_FortranAcpowi
subroutine powi_test(a, b, c)
complex :: a, b
integer :: i
b = a ** i
end subroutine powi_test

View File

@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
! PRECISE: fir.call @_FortranAcpowi
end subroutine
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
! PRECISE: fir.call @_FortranAcpowk
end subroutine
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
! PRECISE: fir.call @_FortranAzpowi
end subroutine
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
! PRECISE: fir.call @_FortranAzpowk
end subroutine
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: fir.call @cpow
end subroutine

View File

@@ -2,51 +2,38 @@
module {
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
%c0 = arith.constant 0.0 : f32
%0 = fir.convert %arg1 : (i32) -> f32
%1 = complex.create %0, %c0 : complex<f32>
%2 = complex.pow %arg0, %1 : complex<f32>
return %2 : complex<f32>
%0 = complex.powi %arg0, %arg1 : complex<f32>, i32
return %0 : complex<f32>
}
func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
%0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
return %0 : complex<f32>
}
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
%c0 = arith.constant 0.0 : f32
%0 = fir.convert %arg1 : (i64) -> f32
%1 = complex.create %0, %c0 : complex<f32>
%2 = complex.pow %arg0, %1 : complex<f32>
return %2 : complex<f32>
%0 = complex.powi %arg0, %arg1 : complex<f32>, i64
return %0 : complex<f32>
}
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
%c0 = arith.constant 0.0 : f64
%0 = fir.convert %arg1 : (i32) -> f64
%1 = complex.create %0, %c0 : complex<f64>
%2 = complex.pow %arg0, %1 : complex<f64>
return %2 : complex<f64>
%0 = complex.powi %arg0, %arg1 : complex<f64>, i32
return %0 : complex<f64>
}
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
%c0 = arith.constant 0.0 : f64
%0 = fir.convert %arg1 : (i64) -> f64
%1 = complex.create %0, %c0 : complex<f64>
%2 = complex.pow %arg0, %1 : complex<f64>
return %2 : complex<f64>
%0 = complex.powi %arg0, %arg1 : complex<f64>, i64
return %0 : complex<f64>
}
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
%c0 = arith.constant 0.0 : f128
%0 = fir.convert %arg1 : (i32) -> f128
%1 = complex.create %0, %c0 : complex<f128>
%2 = complex.pow %arg0, %1 : complex<f128>
return %2 : complex<f128>
%0 = complex.powi %arg0, %arg1 : complex<f128>, i32
return %0 : complex<f128>
}
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
%c0 = arith.constant 0.0 : f128
%0 = fir.convert %arg1 : (i64) -> f128
%1 = complex.create %0, %c0 : complex<f128>
%2 = complex.pow %arg0, %1 : complex<f128>
return %2 : complex<f128>
%0 = complex.powi %arg0, %arg1 : complex<f128>, i64
return %0 : complex<f128>
}
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
@@ -74,26 +61,37 @@ module {
// CHECK-LABEL: func.func @pow_c4_i4(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_i4_fast(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_i8(
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_i4(
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_i8(
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_i4(
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_i8(
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_fast(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
@@ -108,4 +106,4 @@ module {
// CHECK-LABEL: func.func @pow_c16_complex(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.pow

View File

@@ -443,6 +443,36 @@ def PowOp : ComplexArithmeticOp<"pow"> {
}];
}
//===----------------------------------------------------------------------===//
// PowiOp
//===----------------------------------------------------------------------===//
def PowiOp : Complex_Op<"powi",
[Pure, Elementwise, SameOperandsAndResultShape,
AllTypesMatch<["lhs", "result"]>,
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "complex number raised to signed integer power";
let description = [{
The `powi` operation takes a `base` operand of complex type and a `power`
operand of signed integer type and returns one result of the same type
as `base`. The result is `base` raised to the power of `power`.
Example:
```mlir
%a = complex.powi %b, %c : complex<f32>, i32
```
}];
let arguments = (ins Complex<AnyFloat>:$lhs,
AnySignlessInteger:$rhs,
OptionalAttr<Arith_FastMathAttr>:$fastmath);
let results = (outs Complex<AnyFloat>:$result);
let assemblyFormat =
"$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
}
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//

View File

@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -74,10 +76,39 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
return success();
}
};
// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
LogicalResult matchAndRewrite(complex::PowiOp op,
PatternRewriter &rewriter) const final {
auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
Type elementType = complexType.getElementType();
Type exponentType = op.getRhs().getType();
Type exponentFloatType = elementType;
if (auto shapedType = dyn_cast<ShapedType>(exponentType))
exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
Location loc = op.getLoc();
Value exponentReal =
rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
Value zeroImag = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(exponentFloatType));
Value exponent = rewriter.create<complex::CreateOp>(
loc, op.getLhs().getType(), exponentReal, zeroImag);
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
exponent, op.getFastmathAttr());
return success();
}
};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
@@ -128,11 +159,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
target.addLegalOp<complex::MulOp>();
target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
target.addLegalOp<complex::CreateOp, complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
complex::LogOp, complex::PowOp, complex::SinOp,
complex::SqrtOp, complex::TanOp, complex::TanhOp>();
complex::LogOp, complex::PowOp, complex::PowiOp,
complex::SinOp, complex::SqrtOp, complex::TanOp,
complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}

View File

@@ -926,6 +926,30 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
return cutoff4;
}
struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
auto type = cast<ComplexType>(op.getType());
auto elementType = cast<FloatType>(type.getElementType());
Value floatExponent =
builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
Value zero = arith::ConstantOp::create(
builder, elementType, builder.getFloatAttr(elementType, 0.0));
Value complexExponent =
complex::CreateOp::create(builder, type, floatExponent, zero);
auto pow = builder.create<complex::PowOp>(
type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
rewriter.replaceOp(op, pow.getResult());
return success();
}
};
struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
using OpConversionPattern<complex::PowOp>::OpConversionPattern;
@@ -1070,6 +1094,7 @@ void mlir::populateComplexToStandardConversionPatterns(
SqrtOpConversion,
TanTanhOpConversion<complex::TanOp>,
TanTanhOpConversion<complex::TanhOp>,
PowiOpConversion,
PowOpConversion,
RsqrtOpConversion
>(patterns.getContext());

View File

@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
Value one;
Type opType = getElementTypeOrSelf(op.getType());
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
one = arith::ConstantOp::create(rewriter, loc,
rewriter.getFloatAttr(opType, 1.0));
else
} else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
auto complexTy = cast<ComplexType>(opType);
Type elementType = complexTy.getElementType();
auto realPart = rewriter.getFloatAttr(elementType, 1.0);
auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
one = complex::ConstantOp::create(
rewriter, loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
} else {
one = arith::ConstantOp::create(rewriter, loc,
rewriter.getIntegerAttr(opType, 1));
}
// Replace `[fi]powi(x, 0)` with `1`.
if (exponentValue == 0) {
@@ -208,13 +217,25 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// `[fi]powi(x, negative_exponent)`
// with:
// (1 / x) * (1 / x) * (1 / x) * ...
auto buildMul = [&](Value lhs, Value rhs) {
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
op.getFastmathAttr());
else
return MulOpTy::create(rewriter, loc, lhs, rhs);
};
for (unsigned i = 1; i < exponentValue; ++i)
result = MulOpTy::create(rewriter, loc, result, base);
result = buildMul(result, base);
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
if (exponentIsNegative)
result = DivOpTy::create(rewriter, loc, bcast(one), result);
if (exponentIsNegative) {
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result,
op.getFastmathAttr());
else
result = DivOpTy::create(rewriter, loc, bcast(one), result);
}
rewriter.replaceOp(op, result);
return success();
@@ -224,9 +245,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
void mlir::populateMathAlgebraicSimplificationPatterns(
RewritePatternSet &patterns) {
patterns
.add<PowFStrengthReduction,
PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
patterns.getContext());
patterns.add<
PowFStrengthReduction,
PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
patterns.getContext(), /*exponentThreshold=*/8);
}

View File

@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRComplexDialect
MLIRDialectUtils
MLIRIR
MLIRMathDialect

View File

@@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
return %r : complex<f32>
}
//CHECK-LABEL: @powi_caller
//CHECK: (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32)
func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> {
// CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32>
// CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32>
// CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32>
// CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32>
// CHECK: return %[[EXP]] : complex<f32>
%r = complex.powi %z, %n : complex<f32>, i32
return %r : complex<f32>
}
//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})

View File

@@ -700,6 +700,36 @@ func.func @complex_pow_with_fmf(%lhs: complex<f32>,
// -----
// CHECK-LABEL: func.func @complex_powi
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
func.func @complex_powi(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
%pow = complex.powi %lhs, %rhs : complex<f32>, i32
return %pow : complex<f32>
}
// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
// CHECK: math.atan2
// CHECK-NOT: complex.powi
// -----
// CHECK-LABEL: func.func @complex_powi_with_fmf
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
func.func @complex_powi_with_fmf(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
%pow = complex.powi %lhs, %rhs fastmath<nnan,contract> : complex<f32>, i32
return %pow : complex<f32>
}
// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
// CHECK: math.atan2 {{.*}} fastmath<nnan,contract> : f32
// CHECK-NOT: complex.powi
// -----
// CHECK-LABEL: func.func @complex_rsqrt
func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
%rsqrt = complex.rsqrt %arg : complex<f32>

View File

@@ -0,0 +1,20 @@
// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s
func.func @pow3(%arg0: complex<f32>) -> complex<f32> {
%c3 = arith.constant 3 : i32
%0 = complex.powi %arg0, %c3 : complex<f32>, i32
return %0 : complex<f32>
}
// CHECK-LABEL: func.func @pow3(
// CHECK-NOT: complex.powi
// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32>
// CHECK: return %[[M1]] : complex<f32>
func.func @pow9(%arg0: complex<f32>) -> complex<f32> {
%c9 = arith.constant 9 : i32
%0 = complex.powi %arg0, %c9 : complex<f32>, i32
return %0 : complex<f32>
}
// CHECK-LABEL: func.func @pow9(
// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32