mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 14:48:27 +08:00
[MLIR] Add new complex.powi op (#158722)
This PR adds a new complex.powi operation to MLIR's complex dialect for computing complex numbers raised to integer powers. Key changes include: - Addition of the new `PowiOp` operation definition in the Complex dialect - Integration with algebraic simplification passes for optimization - Support for conversion to ROCDL library calls - Updates to Flang frontend to generate the new operation This depends on #158642.
This commit is contained in:
@@ -1323,26 +1323,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
return result;
|
||||
}
|
||||
|
||||
mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
|
||||
const MathOperation &mathOp,
|
||||
mlir::FunctionType mathLibFuncType,
|
||||
llvm::ArrayRef<mlir::Value> args) {
|
||||
if (mathRuntimeVersion == preciseVersion)
|
||||
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
|
||||
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
|
||||
mlir::Value exp = args[1];
|
||||
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
|
||||
auto realTy = complexTy.getElementType();
|
||||
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
|
||||
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
|
||||
exp =
|
||||
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
|
||||
}
|
||||
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
|
||||
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Mapping between mathematical intrinsic operations and MLIR operations
|
||||
/// of some appropriate dialect (math, complex, etc.) or libm calls.
|
||||
/// TODO: support remaining Fortran math intrinsics.
|
||||
@@ -1668,11 +1648,11 @@ static constexpr MathOperation mathOperations[] = {
|
||||
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
|
||||
{"pow", "cpowf",
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowOp>},
|
||||
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowOp>},
|
||||
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowOp>},
|
||||
{"pow", RTNAME_STRING(FPow4i),
|
||||
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
|
||||
genMathOp<mlir::math::FPowIOp>},
|
||||
@@ -1693,20 +1673,20 @@ static constexpr MathOperation mathOperations[] = {
|
||||
genMathOp<mlir::math::FPowIOp>},
|
||||
{"pow", RTNAME_STRING(cpowi),
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowiOp>},
|
||||
{"pow", RTNAME_STRING(zpowi),
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowiOp>},
|
||||
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowiOp>},
|
||||
{"pow", RTNAME_STRING(cpowk),
|
||||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowiOp>},
|
||||
{"pow", RTNAME_STRING(zpowk),
|
||||
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowiOp>},
|
||||
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
|
||||
genComplexPow},
|
||||
genMathOp<mlir::complex::PowiOp>},
|
||||
{"pow-unsigned", RTNAME_STRING(UPow1),
|
||||
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
|
||||
{"pow-unsigned", RTNAME_STRING(UPow2),
|
||||
|
||||
@@ -47,39 +47,19 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
|
||||
return func;
|
||||
}
|
||||
|
||||
static bool isZero(Value v) {
|
||||
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
|
||||
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
|
||||
return attr.getValue().isZero();
|
||||
return false;
|
||||
}
|
||||
|
||||
void ConvertComplexPowPass::runOnOperation() {
|
||||
ModuleOp mod = getOperation();
|
||||
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
|
||||
|
||||
mod.walk([&](complex::PowOp op) {
|
||||
builder.setInsertionPoint(op);
|
||||
Location loc = op.getLoc();
|
||||
auto complexTy = cast<ComplexType>(op.getType());
|
||||
auto elemTy = complexTy.getElementType();
|
||||
|
||||
Value base = op.getLhs();
|
||||
Value rhs = op.getRhs();
|
||||
|
||||
Value intExp;
|
||||
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
|
||||
if (isZero(create.getImaginary())) {
|
||||
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
|
||||
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
|
||||
intExp = conv.getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func::FuncOp callee;
|
||||
SmallVector<Value> args;
|
||||
if (intExp) {
|
||||
mod.walk([&](Operation *op) {
|
||||
if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
|
||||
builder.setInsertionPoint(powIop);
|
||||
Location loc = powIop.getLoc();
|
||||
auto complexTy = cast<ComplexType>(powIop.getType());
|
||||
auto elemTy = complexTy.getElementType();
|
||||
Value base = powIop.getLhs();
|
||||
Value intExp = powIop.getRhs();
|
||||
func::FuncOp callee;
|
||||
unsigned realBits = cast<FloatType>(elemTy).getWidth();
|
||||
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
|
||||
auto funcTy = builder.getFunctionType(
|
||||
@@ -98,9 +78,20 @@ void ConvertComplexPowPass::runOnOperation() {
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
|
||||
else
|
||||
return;
|
||||
args = {base, intExp};
|
||||
} else {
|
||||
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
|
||||
if (auto fmf = powIop.getFastmathAttr())
|
||||
call.setFastmathAttr(fmf);
|
||||
powIop.replaceAllUsesWith(call.getResult(0));
|
||||
powIop.erase();
|
||||
}
|
||||
|
||||
if (auto powOp = dyn_cast<complex::PowOp>(op)) {
|
||||
builder.setInsertionPoint(powOp);
|
||||
Location loc = powOp.getLoc();
|
||||
auto complexTy = cast<ComplexType>(powOp.getType());
|
||||
auto elemTy = complexTy.getElementType();
|
||||
unsigned realBits = cast<FloatType>(elemTy).getWidth();
|
||||
func::FuncOp callee;
|
||||
auto funcTy =
|
||||
builder.getFunctionType({complexTy, complexTy}, {complexTy});
|
||||
if (realBits == 32)
|
||||
@@ -111,13 +102,12 @@ void ConvertComplexPowPass::runOnOperation() {
|
||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
|
||||
else
|
||||
return;
|
||||
args = {base, rhs};
|
||||
auto call = fir::CallOp::create(builder, loc, callee,
|
||||
{powOp.getLhs(), powOp.getRhs()});
|
||||
if (auto fmf = powOp.getFastmathAttr())
|
||||
call.setFastmathAttr(fmf);
|
||||
powOp.replaceAllUsesWith(call.getResult(0));
|
||||
powOp.erase();
|
||||
}
|
||||
|
||||
auto call = fir::CallOp::create(builder, loc, callee, args);
|
||||
if (auto fmf = op.getFastmathAttr())
|
||||
call.setFastmathAttr(fmf);
|
||||
op.replaceAllUsesWith(call.getResult(0));
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -193,7 +193,7 @@ end subroutine
|
||||
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
|
||||
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
|
||||
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
|
||||
! CHECK: %[[VAL_8:.*]] = complex.pow
|
||||
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32
|
||||
|
||||
subroutine extremum(c, n, l)
|
||||
integer(8), intent(in) :: l
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
complex(16) :: a
|
||||
integer(4) :: b
|
||||
b = a ** b
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
|
||||
complex(16) :: a
|
||||
integer(8) :: b
|
||||
b = a ** b
|
||||
|
||||
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
|
||||
complex :: a, b, c
|
||||
a = b**c
|
||||
end subroutine pow_test
|
||||
|
||||
! CHECK-LABEL: func @_QPpowi_test(
|
||||
! CHECK: complex.powi
|
||||
! CHECK-NOT: fir.call @_FortranAcpowi
|
||||
subroutine powi_test(a, b, c)
|
||||
complex :: a, b
|
||||
integer :: i
|
||||
b = a ** i
|
||||
end subroutine powi_test
|
||||
|
||||
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
|
||||
complex :: x, z
|
||||
integer :: y
|
||||
z = x ** y
|
||||
! CHECK: complex.pow
|
||||
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
|
||||
! PRECISE: fir.call @_FortranAcpowi
|
||||
end subroutine
|
||||
|
||||
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
|
||||
complex :: x, z
|
||||
integer(8) :: y
|
||||
z = x ** y
|
||||
! CHECK: complex.pow
|
||||
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
|
||||
! PRECISE: fir.call @_FortranAcpowk
|
||||
end subroutine
|
||||
|
||||
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
|
||||
complex(8) :: x, z
|
||||
integer :: y
|
||||
z = x ** y
|
||||
! CHECK: complex.pow
|
||||
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
|
||||
! PRECISE: fir.call @_FortranAzpowi
|
||||
end subroutine
|
||||
|
||||
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
|
||||
complex(8) :: x, z
|
||||
integer(8) :: y
|
||||
z = x ** y
|
||||
! CHECK: complex.pow
|
||||
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
|
||||
! PRECISE: fir.call @_FortranAzpowk
|
||||
end subroutine
|
||||
|
||||
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
|
||||
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
|
||||
! PRECISE: fir.call @cpow
|
||||
end subroutine
|
||||
|
||||
|
||||
@@ -2,51 +2,38 @@
|
||||
|
||||
module {
|
||||
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%0 = fir.convert %arg1 : (i32) -> f32
|
||||
%1 = complex.create %0, %c0 : complex<f32>
|
||||
%2 = complex.pow %arg0, %1 : complex<f32>
|
||||
return %2 : complex<f32>
|
||||
%0 = complex.powi %arg0, %arg1 : complex<f32>, i32
|
||||
return %0 : complex<f32>
|
||||
}
|
||||
|
||||
func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
|
||||
%0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
|
||||
return %0 : complex<f32>
|
||||
}
|
||||
|
||||
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%0 = fir.convert %arg1 : (i64) -> f32
|
||||
%1 = complex.create %0, %c0 : complex<f32>
|
||||
%2 = complex.pow %arg0, %1 : complex<f32>
|
||||
return %2 : complex<f32>
|
||||
%0 = complex.powi %arg0, %arg1 : complex<f32>, i64
|
||||
return %0 : complex<f32>
|
||||
}
|
||||
|
||||
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
|
||||
%c0 = arith.constant 0.0 : f64
|
||||
%0 = fir.convert %arg1 : (i32) -> f64
|
||||
%1 = complex.create %0, %c0 : complex<f64>
|
||||
%2 = complex.pow %arg0, %1 : complex<f64>
|
||||
return %2 : complex<f64>
|
||||
%0 = complex.powi %arg0, %arg1 : complex<f64>, i32
|
||||
return %0 : complex<f64>
|
||||
}
|
||||
|
||||
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
|
||||
%c0 = arith.constant 0.0 : f64
|
||||
%0 = fir.convert %arg1 : (i64) -> f64
|
||||
%1 = complex.create %0, %c0 : complex<f64>
|
||||
%2 = complex.pow %arg0, %1 : complex<f64>
|
||||
return %2 : complex<f64>
|
||||
%0 = complex.powi %arg0, %arg1 : complex<f64>, i64
|
||||
return %0 : complex<f64>
|
||||
}
|
||||
|
||||
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
|
||||
%c0 = arith.constant 0.0 : f128
|
||||
%0 = fir.convert %arg1 : (i32) -> f128
|
||||
%1 = complex.create %0, %c0 : complex<f128>
|
||||
%2 = complex.pow %arg0, %1 : complex<f128>
|
||||
return %2 : complex<f128>
|
||||
%0 = complex.powi %arg0, %arg1 : complex<f128>, i32
|
||||
return %0 : complex<f128>
|
||||
}
|
||||
|
||||
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
|
||||
%c0 = arith.constant 0.0 : f128
|
||||
%0 = fir.convert %arg1 : (i64) -> f128
|
||||
%1 = complex.create %0, %c0 : complex<f128>
|
||||
%2 = complex.pow %arg0, %1 : complex<f128>
|
||||
return %2 : complex<f128>
|
||||
%0 = complex.powi %arg0, %arg1 : complex<f128>, i64
|
||||
return %0 : complex<f128>
|
||||
}
|
||||
|
||||
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
|
||||
@@ -74,26 +61,37 @@ module {
|
||||
// CHECK-LABEL: func.func @pow_c4_i4(
|
||||
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c4_i4_fast(
|
||||
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c4_i8(
|
||||
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c8_i4(
|
||||
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c8_i8(
|
||||
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c16_i4(
|
||||
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c16_i8(
|
||||
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// CHECK-LABEL: func.func @pow_c4_fast(
|
||||
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
|
||||
@@ -108,4 +106,4 @@ module {
|
||||
// CHECK-LABEL: func.func @pow_c16_complex(
|
||||
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
|
||||
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
|
||||
// CHECK-NOT: complex.pow
|
||||
// CHECK-NOT: complex.pow
|
||||
|
||||
@@ -443,6 +443,36 @@ def PowOp : ComplexArithmeticOp<"pow"> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PowiOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PowiOp : Complex_Op<"powi",
|
||||
[Pure, Elementwise, SameOperandsAndResultShape,
|
||||
AllTypesMatch<["lhs", "result"]>,
|
||||
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
|
||||
let summary = "complex number raised to signed integer power";
|
||||
let description = [{
|
||||
The `powi` operation takes a `base` operand of complex type and a `power`
|
||||
operand of signed integer type and returns one result of the same type
|
||||
as `base`. The result is `base` raised to the power of `power`.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%a = complex.powi %b, %c : complex<f32>, i32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Complex<AnyFloat>:$lhs,
|
||||
AnySignlessInteger:$rhs,
|
||||
OptionalAttr<Arith_FastMathAttr>:$fastmath);
|
||||
let results = (outs Complex<AnyFloat>:$result);
|
||||
|
||||
let assemblyFormat =
|
||||
"$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -7,9 +7,11 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -74,10 +76,39 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
|
||||
struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
|
||||
using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(complex::PowiOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
|
||||
Type elementType = complexType.getElementType();
|
||||
|
||||
Type exponentType = op.getRhs().getType();
|
||||
Type exponentFloatType = elementType;
|
||||
if (auto shapedType = dyn_cast<ShapedType>(exponentType))
|
||||
exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value exponentReal =
|
||||
rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
|
||||
Value zeroImag = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(exponentFloatType));
|
||||
Value exponent = rewriter.create<complex::CreateOp>(
|
||||
loc, op.getLhs().getType(), exponentReal, zeroImag);
|
||||
|
||||
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
|
||||
exponent, op.getFastmathAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
|
||||
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
|
||||
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
|
||||
patterns.getContext(), "__ocml_cabs_f32");
|
||||
@@ -128,11 +159,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
|
||||
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addLegalOp<complex::MulOp>();
|
||||
target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
|
||||
target.addLegalOp<complex::CreateOp, complex::MulOp>();
|
||||
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
|
||||
complex::LogOp, complex::PowOp, complex::SinOp,
|
||||
complex::SqrtOp, complex::TanOp, complex::TanhOp>();
|
||||
complex::LogOp, complex::PowOp, complex::PowiOp,
|
||||
complex::SinOp, complex::SqrtOp, complex::TanOp,
|
||||
complex::TanhOp>();
|
||||
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
@@ -926,6 +926,30 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
|
||||
return cutoff4;
|
||||
}
|
||||
|
||||
struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
|
||||
using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
|
||||
auto type = cast<ComplexType>(op.getType());
|
||||
auto elementType = cast<FloatType>(type.getElementType());
|
||||
|
||||
Value floatExponent =
|
||||
builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
|
||||
Value zero = arith::ConstantOp::create(
|
||||
builder, elementType, builder.getFloatAttr(elementType, 0.0));
|
||||
Value complexExponent =
|
||||
complex::CreateOp::create(builder, type, floatExponent, zero);
|
||||
|
||||
auto pow = builder.create<complex::PowOp>(
|
||||
type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
|
||||
rewriter.replaceOp(op, pow.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
|
||||
using OpConversionPattern<complex::PowOp>::OpConversionPattern;
|
||||
|
||||
@@ -1070,6 +1094,7 @@ void mlir::populateComplexToStandardConversionPatterns(
|
||||
SqrtOpConversion,
|
||||
TanTanhOpConversion<complex::TanOp>,
|
||||
TanTanhOpConversion<complex::TanhOp>,
|
||||
PowiOpConversion,
|
||||
PowOpConversion,
|
||||
RsqrtOpConversion
|
||||
>(patterns.getContext());
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
|
||||
|
||||
Value one;
|
||||
Type opType = getElementTypeOrSelf(op.getType());
|
||||
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
|
||||
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
|
||||
one = arith::ConstantOp::create(rewriter, loc,
|
||||
rewriter.getFloatAttr(opType, 1.0));
|
||||
else
|
||||
} else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
|
||||
auto complexTy = cast<ComplexType>(opType);
|
||||
Type elementType = complexTy.getElementType();
|
||||
auto realPart = rewriter.getFloatAttr(elementType, 1.0);
|
||||
auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
|
||||
one = complex::ConstantOp::create(
|
||||
rewriter, loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
|
||||
} else {
|
||||
one = arith::ConstantOp::create(rewriter, loc,
|
||||
rewriter.getIntegerAttr(opType, 1));
|
||||
}
|
||||
|
||||
// Replace `[fi]powi(x, 0)` with `1`.
|
||||
if (exponentValue == 0) {
|
||||
@@ -208,13 +217,25 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
|
||||
// `[fi]powi(x, negative_exponent)`
|
||||
// with:
|
||||
// (1 / x) * (1 / x) * (1 / x) * ...
|
||||
auto buildMul = [&](Value lhs, Value rhs) {
|
||||
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
|
||||
return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
|
||||
op.getFastmathAttr());
|
||||
else
|
||||
return MulOpTy::create(rewriter, loc, lhs, rhs);
|
||||
};
|
||||
for (unsigned i = 1; i < exponentValue; ++i)
|
||||
result = MulOpTy::create(rewriter, loc, result, base);
|
||||
result = buildMul(result, base);
|
||||
|
||||
// Inverse the base for negative exponent, i.e. for
|
||||
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
|
||||
if (exponentIsNegative)
|
||||
result = DivOpTy::create(rewriter, loc, bcast(one), result);
|
||||
if (exponentIsNegative) {
|
||||
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
|
||||
result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result,
|
||||
op.getFastmathAttr());
|
||||
else
|
||||
result = DivOpTy::create(rewriter, loc, bcast(one), result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
@@ -224,9 +245,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
|
||||
|
||||
void mlir::populateMathAlgebraicSimplificationPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns
|
||||
.add<PowFStrengthReduction,
|
||||
PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
|
||||
PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
|
||||
patterns.getContext());
|
||||
patterns.add<
|
||||
PowFStrengthReduction,
|
||||
PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
|
||||
PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
|
||||
PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
|
||||
patterns.getContext(), /*exponentThreshold=*/8);
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithDialect
|
||||
MLIRComplexDialect
|
||||
MLIRDialectUtils
|
||||
MLIRIR
|
||||
MLIRMathDialect
|
||||
|
||||
@@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
|
||||
return %r : complex<f32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: @powi_caller
|
||||
//CHECK: (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32)
|
||||
func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> {
|
||||
// CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32>
|
||||
// CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32>
|
||||
// CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32>
|
||||
// CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32>
|
||||
// CHECK: return %[[EXP]] : complex<f32>
|
||||
%r = complex.powi %z, %n : complex<f32>, i32
|
||||
return %r : complex<f32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: @sin_caller
|
||||
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
|
||||
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
|
||||
|
||||
@@ -700,6 +700,36 @@ func.func @complex_pow_with_fmf(%lhs: complex<f32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @complex_powi
|
||||
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
|
||||
func.func @complex_powi(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
|
||||
%pow = complex.powi %lhs, %rhs : complex<f32>, i32
|
||||
return %pow : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
|
||||
// CHECK: math.atan2
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @complex_powi_with_fmf
|
||||
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32
|
||||
func.func @complex_powi_with_fmf(%lhs: complex<f32>, %rhs: i32) -> complex<f32> {
|
||||
%pow = complex.powi %lhs, %rhs fastmath<nnan,contract> : complex<f32>, i32
|
||||
return %pow : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32>
|
||||
// CHECK: math.atan2 {{.*}} fastmath<nnan,contract> : f32
|
||||
// CHECK-NOT: complex.powi
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @complex_rsqrt
|
||||
func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
|
||||
%rsqrt = complex.rsqrt %arg : complex<f32>
|
||||
|
||||
20
mlir/test/Dialect/Complex/powi-simplify.mlir
Normal file
20
mlir/test/Dialect/Complex/powi-simplify.mlir
Normal file
@@ -0,0 +1,20 @@
|
||||
// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s
|
||||
|
||||
func.func @pow3(%arg0: complex<f32>) -> complex<f32> {
|
||||
%c3 = arith.constant 3 : i32
|
||||
%0 = complex.powi %arg0, %c3 : complex<f32>, i32
|
||||
return %0 : complex<f32>
|
||||
}
|
||||
// CHECK-LABEL: func.func @pow3(
|
||||
// CHECK-NOT: complex.powi
|
||||
// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
|
||||
// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32>
|
||||
// CHECK: return %[[M1]] : complex<f32>
|
||||
|
||||
func.func @pow9(%arg0: complex<f32>) -> complex<f32> {
|
||||
%c9 = arith.constant 9 : i32
|
||||
%0 = complex.powi %arg0, %c9 : complex<f32>, i32
|
||||
return %0 : complex<f32>
|
||||
}
|
||||
// CHECK-LABEL: func.func @pow9(
|
||||
// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
|
||||
Reference in New Issue
Block a user