mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 05:32:28 +08:00
[mlir][NFC] update Conversion create APIs (5/n) (#149887)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
This commit is contained in:
@@ -23,41 +23,43 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
|
||||
Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe,
|
||||
Value *resultIm) {
|
||||
Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
|
||||
rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
|
||||
Value rhsSqNorm = LLVM::FAddOp::create(
|
||||
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
|
||||
|
||||
Value realNumerator = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
|
||||
Value realNumerator = LLVM::FAddOp::create(
|
||||
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
|
||||
|
||||
Value imagNumerator = rewriter.create<LLVM::FSubOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
|
||||
Value imagNumerator = LLVM::FSubOp::create(
|
||||
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
|
||||
|
||||
*resultRe = rewriter.create<LLVM::FDivOp>(loc, realNumerator, rhsSqNorm, fmf);
|
||||
*resultIm = rewriter.create<LLVM::FDivOp>(loc, imagNumerator, rhsSqNorm, fmf);
|
||||
*resultRe =
|
||||
LLVM::FDivOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
|
||||
*resultIm =
|
||||
LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
|
||||
}
|
||||
|
||||
void mlir::complex::convertDivToStandardUsingAlgebraic(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
|
||||
Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
|
||||
Value *resultIm) {
|
||||
Value rhsSqNorm = rewriter.create<arith::AddFOp>(
|
||||
loc, rewriter.create<arith::MulFOp>(loc, rhsRe, rhsRe, fmf),
|
||||
rewriter.create<arith::MulFOp>(loc, rhsIm, rhsIm, fmf), fmf);
|
||||
Value rhsSqNorm = arith::AddFOp::create(
|
||||
rewriter, loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
|
||||
|
||||
Value realNumerator = rewriter.create<arith::AddFOp>(
|
||||
loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRe, fmf),
|
||||
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsIm, fmf), fmf);
|
||||
Value imagNumerator = rewriter.create<arith::SubFOp>(
|
||||
loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRe, fmf),
|
||||
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsIm, fmf), fmf);
|
||||
Value realNumerator = arith::AddFOp::create(
|
||||
rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
|
||||
Value imagNumerator = arith::SubFOp::create(
|
||||
rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
|
||||
|
||||
*resultRe =
|
||||
rewriter.create<arith::DivFOp>(loc, realNumerator, rhsSqNorm, fmf);
|
||||
arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
|
||||
*resultIm =
|
||||
rewriter.create<arith::DivFOp>(loc, imagNumerator, rhsSqNorm, fmf);
|
||||
arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
|
||||
}
|
||||
|
||||
// Smith's algorithm to divide complex numbers. It is just a bit smarter
|
||||
@@ -94,181 +96,185 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction(
|
||||
auto elementType = cast<FloatType>(rhsRe.getType());
|
||||
|
||||
Value rhsRealImagRatio =
|
||||
rewriter.create<LLVM::FDivOp>(loc, rhsRe, rhsIm, fmf);
|
||||
Value rhsRealImagDenom = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rhsIm,
|
||||
rewriter.create<LLVM::FMulOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
|
||||
Value realNumerator1 = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealImagRatio, fmf),
|
||||
lhsIm, fmf);
|
||||
Value resultReal1 =
|
||||
rewriter.create<LLVM::FDivOp>(loc, realNumerator1, rhsRealImagDenom, fmf);
|
||||
Value imagNumerator1 = rewriter.create<LLVM::FSubOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealImagRatio, fmf),
|
||||
lhsRe, fmf);
|
||||
Value resultImag1 =
|
||||
rewriter.create<LLVM::FDivOp>(loc, imagNumerator1, rhsRealImagDenom, fmf);
|
||||
LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
|
||||
Value rhsRealImagDenom = LLVM::FAddOp::create(
|
||||
rewriter, loc, rhsIm,
|
||||
LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
|
||||
Value realNumerator1 = LLVM::FAddOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
|
||||
fmf);
|
||||
Value resultReal1 = LLVM::FDivOp::create(rewriter, loc, realNumerator1,
|
||||
rhsRealImagDenom, fmf);
|
||||
Value imagNumerator1 = LLVM::FSubOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
|
||||
fmf);
|
||||
Value resultImag1 = LLVM::FDivOp::create(rewriter, loc, imagNumerator1,
|
||||
rhsRealImagDenom, fmf);
|
||||
|
||||
Value rhsImagRealRatio =
|
||||
rewriter.create<LLVM::FDivOp>(loc, rhsIm, rhsRe, fmf);
|
||||
Value rhsImagRealDenom = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rhsRe,
|
||||
rewriter.create<LLVM::FMulOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
|
||||
Value realNumerator2 = rewriter.create<LLVM::FAddOp>(
|
||||
loc, lhsRe,
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultReal2 =
|
||||
rewriter.create<LLVM::FDivOp>(loc, realNumerator2, rhsImagRealDenom, fmf);
|
||||
Value imagNumerator2 = rewriter.create<LLVM::FSubOp>(
|
||||
loc, lhsIm,
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultImag2 =
|
||||
rewriter.create<LLVM::FDivOp>(loc, imagNumerator2, rhsImagRealDenom, fmf);
|
||||
LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
|
||||
Value rhsImagRealDenom = LLVM::FAddOp::create(
|
||||
rewriter, loc, rhsRe,
|
||||
LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
|
||||
Value realNumerator2 = LLVM::FAddOp::create(
|
||||
rewriter, loc, lhsRe,
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultReal2 = LLVM::FDivOp::create(rewriter, loc, realNumerator2,
|
||||
rhsImagRealDenom, fmf);
|
||||
Value imagNumerator2 = LLVM::FSubOp::create(
|
||||
rewriter, loc, lhsIm,
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultImag2 = LLVM::FDivOp::create(rewriter, loc, imagNumerator2,
|
||||
rhsImagRealDenom, fmf);
|
||||
|
||||
// Consider corner cases.
|
||||
// Case 1. Zero denominator, numerator contains at most one NaN value.
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
Value rhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsRe, fmf);
|
||||
Value rhsRealIsZero = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
|
||||
Value rhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsIm, fmf);
|
||||
Value rhsImagIsZero = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
|
||||
Value lhsRealIsNotNaN =
|
||||
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
|
||||
Value lhsImagIsNotNaN =
|
||||
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
|
||||
Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType,
|
||||
rewriter.getZeroAttr(elementType));
|
||||
Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf);
|
||||
Value rhsRealIsZero = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
|
||||
Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf);
|
||||
Value rhsImagIsZero = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
|
||||
Value lhsRealIsNotNaN = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
|
||||
Value lhsImagIsNotNaN = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
|
||||
Value lhsContainsNotNaNValue =
|
||||
rewriter.create<LLVM::OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
|
||||
Value resultIsInfinity = rewriter.create<LLVM::AndOp>(
|
||||
loc, lhsContainsNotNaNValue,
|
||||
rewriter.create<LLVM::AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
|
||||
Value inf = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, elementType,
|
||||
LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
|
||||
Value resultIsInfinity = LLVM::AndOp::create(
|
||||
rewriter, loc, lhsContainsNotNaNValue,
|
||||
LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
|
||||
Value inf = LLVM::ConstantOp::create(
|
||||
rewriter, loc, elementType,
|
||||
rewriter.getFloatAttr(elementType,
|
||||
APFloat::getInf(elementType.getFloatSemantics())));
|
||||
Value infWithSignOfrhsReal =
|
||||
rewriter.create<LLVM::CopySignOp>(loc, inf, rhsRe);
|
||||
LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe);
|
||||
Value infinityResultReal =
|
||||
rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsRe, fmf);
|
||||
LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf);
|
||||
Value infinityResultImag =
|
||||
rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsIm, fmf);
|
||||
LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf);
|
||||
|
||||
// Case 2. Infinite numerator, finite denominator.
|
||||
Value rhsRealFinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
|
||||
Value rhsImagFinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
|
||||
Value rhsRealFinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
|
||||
Value rhsImagFinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
|
||||
Value rhsFinite =
|
||||
rewriter.create<LLVM::AndOp>(loc, rhsRealFinite, rhsImagFinite);
|
||||
Value lhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsRe, fmf);
|
||||
Value lhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
|
||||
Value lhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsIm, fmf);
|
||||
Value lhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
|
||||
LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
|
||||
Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf);
|
||||
Value lhsRealInfinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
|
||||
Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf);
|
||||
Value lhsImagInfinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
|
||||
Value lhsInfinite =
|
||||
rewriter.create<LLVM::OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
|
||||
LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
|
||||
Value infNumFiniteDenom =
|
||||
rewriter.create<LLVM::AndOp>(loc, lhsInfinite, rhsFinite);
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, elementType, rewriter.getFloatAttr(elementType, 1));
|
||||
Value lhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
|
||||
loc, rewriter.create<LLVM::SelectOp>(loc, lhsRealInfinite, one, zero),
|
||||
lhsRe);
|
||||
Value lhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
|
||||
loc, rewriter.create<LLVM::SelectOp>(loc, lhsImagInfinite, one, zero),
|
||||
lhsIm);
|
||||
LLVM::AndOp::create(rewriter, loc, lhsInfinite, rhsFinite);
|
||||
Value one = LLVM::ConstantOp::create(rewriter, loc, elementType,
|
||||
rewriter.getFloatAttr(elementType, 1));
|
||||
Value lhsRealIsInfWithSign = LLVM::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe);
|
||||
Value lhsImagIsInfWithSign = LLVM::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm);
|
||||
Value lhsRealIsInfWithSignTimesrhsReal =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
|
||||
Value lhsImagIsInfWithSignTimesrhsImag =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
|
||||
Value resultReal3 = rewriter.create<LLVM::FMulOp>(
|
||||
loc, inf,
|
||||
rewriter.create<LLVM::FAddOp>(loc, lhsRealIsInfWithSignTimesrhsReal,
|
||||
lhsImagIsInfWithSignTimesrhsImag, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
|
||||
Value resultReal3 = LLVM::FMulOp::create(
|
||||
rewriter, loc, inf,
|
||||
LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal,
|
||||
lhsImagIsInfWithSignTimesrhsImag, fmf),
|
||||
fmf);
|
||||
Value lhsRealIsInfWithSignTimesrhsImag =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
|
||||
Value lhsImagIsInfWithSignTimesrhsReal =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
|
||||
Value resultImag3 = rewriter.create<LLVM::FMulOp>(
|
||||
loc, inf,
|
||||
rewriter.create<LLVM::FSubOp>(loc, lhsImagIsInfWithSignTimesrhsReal,
|
||||
lhsRealIsInfWithSignTimesrhsImag, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
|
||||
Value resultImag3 = LLVM::FMulOp::create(
|
||||
rewriter, loc, inf,
|
||||
LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal,
|
||||
lhsRealIsInfWithSignTimesrhsImag, fmf),
|
||||
fmf);
|
||||
|
||||
// Case 3: Finite numerator, infinite denominator.
|
||||
Value lhsRealFinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
|
||||
Value lhsImagFinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
|
||||
Value lhsRealFinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
|
||||
Value lhsImagFinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
|
||||
Value lhsFinite =
|
||||
rewriter.create<LLVM::AndOp>(loc, lhsRealFinite, lhsImagFinite);
|
||||
Value rhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
|
||||
Value rhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
|
||||
LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
|
||||
Value rhsRealInfinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
|
||||
Value rhsImagInfinite = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
|
||||
Value rhsInfinite =
|
||||
rewriter.create<LLVM::OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
|
||||
LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
|
||||
Value finiteNumInfiniteDenom =
|
||||
rewriter.create<LLVM::AndOp>(loc, lhsFinite, rhsInfinite);
|
||||
Value rhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
|
||||
loc, rewriter.create<LLVM::SelectOp>(loc, rhsRealInfinite, one, zero),
|
||||
rhsRe);
|
||||
Value rhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
|
||||
loc, rewriter.create<LLVM::SelectOp>(loc, rhsImagInfinite, one, zero),
|
||||
rhsIm);
|
||||
LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite);
|
||||
Value rhsRealIsInfWithSign = LLVM::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe);
|
||||
Value rhsImagIsInfWithSign = LLVM::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm);
|
||||
Value rhsRealIsInfWithSignTimeslhsReal =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
|
||||
Value rhsImagIsInfWithSignTimeslhsImag =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
|
||||
Value resultReal4 = rewriter.create<LLVM::FMulOp>(
|
||||
loc, zero,
|
||||
rewriter.create<LLVM::FAddOp>(loc, rhsRealIsInfWithSignTimeslhsReal,
|
||||
rhsImagIsInfWithSignTimeslhsImag, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
|
||||
Value resultReal4 = LLVM::FMulOp::create(
|
||||
rewriter, loc, zero,
|
||||
LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal,
|
||||
rhsImagIsInfWithSignTimeslhsImag, fmf),
|
||||
fmf);
|
||||
Value rhsRealIsInfWithSignTimeslhsImag =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
|
||||
Value rhsImagIsInfWithSignTimeslhsReal =
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
|
||||
Value resultImag4 = rewriter.create<LLVM::FMulOp>(
|
||||
loc, zero,
|
||||
rewriter.create<LLVM::FSubOp>(loc, rhsRealIsInfWithSignTimeslhsImag,
|
||||
rhsImagIsInfWithSignTimeslhsReal, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
|
||||
Value resultImag4 = LLVM::FMulOp::create(
|
||||
rewriter, loc, zero,
|
||||
LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag,
|
||||
rhsImagIsInfWithSignTimeslhsReal, fmf),
|
||||
fmf);
|
||||
|
||||
Value realAbsSmallerThanImagAbs = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
|
||||
Value resultReal5 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
|
||||
Value resultImag5 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
|
||||
Value resultRealSpecialCase3 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
|
||||
Value resultImagSpecialCase3 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
|
||||
Value resultRealSpecialCase2 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
|
||||
Value resultImagSpecialCase2 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
|
||||
Value resultRealSpecialCase1 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
|
||||
Value resultImagSpecialCase1 = rewriter.create<LLVM::SelectOp>(
|
||||
loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
|
||||
Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
|
||||
Value resultReal5 = LLVM::SelectOp::create(
|
||||
rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
|
||||
Value resultImag5 = LLVM::SelectOp::create(
|
||||
rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
|
||||
Value resultRealSpecialCase3 = LLVM::SelectOp::create(
|
||||
rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
|
||||
Value resultImagSpecialCase3 = LLVM::SelectOp::create(
|
||||
rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
|
||||
Value resultRealSpecialCase2 = LLVM::SelectOp::create(
|
||||
rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
|
||||
Value resultImagSpecialCase2 = LLVM::SelectOp::create(
|
||||
rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
|
||||
Value resultRealSpecialCase1 =
|
||||
LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
|
||||
infinityResultReal, resultRealSpecialCase2);
|
||||
Value resultImagSpecialCase1 =
|
||||
LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
|
||||
infinityResultImag, resultImagSpecialCase2);
|
||||
|
||||
Value resultRealIsNaN = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
|
||||
Value resultImagIsNaN = rewriter.create<LLVM::FCmpOp>(
|
||||
loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
|
||||
Value resultRealIsNaN = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
|
||||
Value resultImagIsNaN = LLVM::FCmpOp::create(
|
||||
rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
|
||||
Value resultIsNaN =
|
||||
rewriter.create<LLVM::AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
|
||||
LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
|
||||
|
||||
*resultRe = rewriter.create<LLVM::SelectOp>(
|
||||
loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
|
||||
*resultIm = rewriter.create<LLVM::SelectOp>(
|
||||
loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
|
||||
*resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
|
||||
resultRealSpecialCase1, resultReal5);
|
||||
*resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
|
||||
resultImagSpecialCase1, resultImag5);
|
||||
}
|
||||
|
||||
void mlir::complex::convertDivToStandardUsingRangeReduction(
|
||||
@@ -278,179 +284,187 @@ void mlir::complex::convertDivToStandardUsingRangeReduction(
|
||||
auto elementType = cast<FloatType>(rhsRe.getType());
|
||||
|
||||
Value rhsRealImagRatio =
|
||||
rewriter.create<arith::DivFOp>(loc, rhsRe, rhsIm, fmf);
|
||||
Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
|
||||
loc, rhsIm,
|
||||
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
|
||||
Value realNumerator1 = rewriter.create<arith::AddFOp>(
|
||||
loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealImagRatio, fmf),
|
||||
lhsIm, fmf);
|
||||
Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
|
||||
rhsRealImagDenom, fmf);
|
||||
Value imagNumerator1 = rewriter.create<arith::SubFOp>(
|
||||
loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealImagRatio, fmf),
|
||||
lhsRe, fmf);
|
||||
Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
|
||||
rhsRealImagDenom, fmf);
|
||||
arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
|
||||
Value rhsRealImagDenom = arith::AddFOp::create(
|
||||
rewriter, loc, rhsIm,
|
||||
arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
|
||||
Value realNumerator1 = arith::AddFOp::create(
|
||||
rewriter, loc,
|
||||
arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
|
||||
fmf);
|
||||
Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1,
|
||||
rhsRealImagDenom, fmf);
|
||||
Value imagNumerator1 = arith::SubFOp::create(
|
||||
rewriter, loc,
|
||||
arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
|
||||
fmf);
|
||||
Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1,
|
||||
rhsRealImagDenom, fmf);
|
||||
|
||||
Value rhsImagRealRatio =
|
||||
rewriter.create<arith::DivFOp>(loc, rhsIm, rhsRe, fmf);
|
||||
Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
|
||||
loc, rhsRe,
|
||||
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
|
||||
Value realNumerator2 = rewriter.create<arith::AddFOp>(
|
||||
loc, lhsRe,
|
||||
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
|
||||
rhsImagRealDenom, fmf);
|
||||
Value imagNumerator2 = rewriter.create<arith::SubFOp>(
|
||||
loc, lhsIm,
|
||||
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
|
||||
rhsImagRealDenom, fmf);
|
||||
arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
|
||||
Value rhsImagRealDenom = arith::AddFOp::create(
|
||||
rewriter, loc, rhsRe,
|
||||
arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
|
||||
Value realNumerator2 = arith::AddFOp::create(
|
||||
rewriter, loc, lhsRe,
|
||||
arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2,
|
||||
rhsImagRealDenom, fmf);
|
||||
Value imagNumerator2 = arith::SubFOp::create(
|
||||
rewriter, loc, lhsIm,
|
||||
arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
|
||||
Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2,
|
||||
rhsImagRealDenom, fmf);
|
||||
|
||||
// Consider corner cases.
|
||||
// Case 1. Zero denominator, numerator contains at most one NaN value.
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsRe, fmf);
|
||||
Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
|
||||
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsIm, fmf);
|
||||
Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
|
||||
Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ORD, lhsRe, zero);
|
||||
Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ORD, lhsIm, zero);
|
||||
Value zero = arith::ConstantOp::create(rewriter, loc, elementType,
|
||||
rewriter.getZeroAttr(elementType));
|
||||
Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf);
|
||||
Value rhsRealIsZero = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
|
||||
Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf);
|
||||
Value rhsImagIsZero = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
|
||||
Value lhsRealIsNotNaN = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero);
|
||||
Value lhsImagIsNotNaN = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero);
|
||||
Value lhsContainsNotNaNValue =
|
||||
rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
|
||||
Value resultIsInfinity = rewriter.create<arith::AndIOp>(
|
||||
loc, lhsContainsNotNaNValue,
|
||||
rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
|
||||
Value inf = rewriter.create<arith::ConstantOp>(
|
||||
loc, elementType,
|
||||
arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
|
||||
Value resultIsInfinity = arith::AndIOp::create(
|
||||
rewriter, loc, lhsContainsNotNaNValue,
|
||||
arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
|
||||
Value inf = arith::ConstantOp::create(
|
||||
rewriter, loc, elementType,
|
||||
rewriter.getFloatAttr(elementType,
|
||||
APFloat::getInf(elementType.getFloatSemantics())));
|
||||
Value infWithSignOfRhsReal =
|
||||
rewriter.create<math::CopySignOp>(loc, inf, rhsRe);
|
||||
math::CopySignOp::create(rewriter, loc, inf, rhsRe);
|
||||
Value infinityResultReal =
|
||||
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsRe, fmf);
|
||||
arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf);
|
||||
Value infinityResultImag =
|
||||
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsIm, fmf);
|
||||
arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsIm, fmf);
|
||||
|
||||
// Case 2. Infinite numerator, finite denominator.
|
||||
Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
|
||||
Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
|
||||
Value rhsRealFinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
|
||||
Value rhsImagFinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
|
||||
Value rhsFinite =
|
||||
rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
|
||||
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsRe, fmf);
|
||||
Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
|
||||
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsIm, fmf);
|
||||
Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
|
||||
arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
|
||||
Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf);
|
||||
Value lhsRealInfinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
|
||||
Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf);
|
||||
Value lhsImagInfinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
|
||||
Value lhsInfinite =
|
||||
rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
|
||||
arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
|
||||
Value infNumFiniteDenom =
|
||||
rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
loc, elementType, rewriter.getFloatAttr(elementType, 1));
|
||||
Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
|
||||
loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
|
||||
arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite);
|
||||
Value one = arith::ConstantOp::create(rewriter, loc, elementType,
|
||||
rewriter.getFloatAttr(elementType, 1));
|
||||
Value lhsRealIsInfWithSign = math::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero),
|
||||
lhsRe);
|
||||
Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
|
||||
loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
|
||||
Value lhsImagIsInfWithSign = math::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero),
|
||||
lhsIm);
|
||||
Value lhsRealIsInfWithSignTimesRhsReal =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
|
||||
arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
|
||||
Value lhsImagIsInfWithSignTimesRhsImag =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
|
||||
Value resultReal3 = rewriter.create<arith::MulFOp>(
|
||||
loc, inf,
|
||||
rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
|
||||
lhsImagIsInfWithSignTimesRhsImag, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
|
||||
Value resultReal3 = arith::MulFOp::create(
|
||||
rewriter, loc, inf,
|
||||
arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal,
|
||||
lhsImagIsInfWithSignTimesRhsImag, fmf),
|
||||
fmf);
|
||||
Value lhsRealIsInfWithSignTimesRhsImag =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
|
||||
arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
|
||||
Value lhsImagIsInfWithSignTimesRhsReal =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
|
||||
Value resultImag3 = rewriter.create<arith::MulFOp>(
|
||||
loc, inf,
|
||||
rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
|
||||
lhsRealIsInfWithSignTimesRhsImag, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
|
||||
Value resultImag3 = arith::MulFOp::create(
|
||||
rewriter, loc, inf,
|
||||
arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal,
|
||||
lhsRealIsInfWithSignTimesRhsImag, fmf),
|
||||
fmf);
|
||||
|
||||
// Case 3: Finite numerator, infinite denominator.
|
||||
Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
|
||||
Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
|
||||
Value lhsRealFinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
|
||||
Value lhsImagFinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
|
||||
Value lhsFinite =
|
||||
rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
|
||||
Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
|
||||
Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
|
||||
arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
|
||||
Value rhsRealInfinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
|
||||
Value rhsImagInfinite = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
|
||||
Value rhsInfinite =
|
||||
rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
|
||||
arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
|
||||
Value finiteNumInfiniteDenom =
|
||||
rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
|
||||
Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
|
||||
loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
|
||||
arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite);
|
||||
Value rhsRealIsInfWithSign = math::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero),
|
||||
rhsRe);
|
||||
Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
|
||||
loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
|
||||
Value rhsImagIsInfWithSign = math::CopySignOp::create(
|
||||
rewriter, loc,
|
||||
arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero),
|
||||
rhsIm);
|
||||
Value rhsRealIsInfWithSignTimesLhsReal =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
|
||||
arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
|
||||
Value rhsImagIsInfWithSignTimesLhsImag =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
|
||||
Value resultReal4 = rewriter.create<arith::MulFOp>(
|
||||
loc, zero,
|
||||
rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
|
||||
rhsImagIsInfWithSignTimesLhsImag, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
|
||||
Value resultReal4 = arith::MulFOp::create(
|
||||
rewriter, loc, zero,
|
||||
arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal,
|
||||
rhsImagIsInfWithSignTimesLhsImag, fmf),
|
||||
fmf);
|
||||
Value rhsRealIsInfWithSignTimesLhsImag =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
|
||||
arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
|
||||
Value rhsImagIsInfWithSignTimesLhsReal =
|
||||
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
|
||||
Value resultImag4 = rewriter.create<arith::MulFOp>(
|
||||
loc, zero,
|
||||
rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
|
||||
rhsImagIsInfWithSignTimesLhsReal, fmf),
|
||||
arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
|
||||
Value resultImag4 = arith::MulFOp::create(
|
||||
rewriter, loc, zero,
|
||||
arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag,
|
||||
rhsImagIsInfWithSignTimesLhsReal, fmf),
|
||||
fmf);
|
||||
|
||||
Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
|
||||
Value resultReal5 = rewriter.create<arith::SelectOp>(
|
||||
loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
|
||||
Value resultImag5 = rewriter.create<arith::SelectOp>(
|
||||
loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
|
||||
Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
|
||||
loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
|
||||
Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
|
||||
loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
|
||||
Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
|
||||
loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
|
||||
Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
|
||||
loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
|
||||
Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
|
||||
loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
|
||||
Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
|
||||
loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
|
||||
Value realAbsSmallerThanImagAbs = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
|
||||
Value resultReal5 = arith::SelectOp::create(
|
||||
rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
|
||||
Value resultImag5 = arith::SelectOp::create(
|
||||
rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
|
||||
Value resultRealSpecialCase3 = arith::SelectOp::create(
|
||||
rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
|
||||
Value resultImagSpecialCase3 = arith::SelectOp::create(
|
||||
rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
|
||||
Value resultRealSpecialCase2 = arith::SelectOp::create(
|
||||
rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
|
||||
Value resultImagSpecialCase2 = arith::SelectOp::create(
|
||||
rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
|
||||
Value resultRealSpecialCase1 =
|
||||
arith::SelectOp::create(rewriter, loc, resultIsInfinity,
|
||||
infinityResultReal, resultRealSpecialCase2);
|
||||
Value resultImagSpecialCase1 =
|
||||
arith::SelectOp::create(rewriter, loc, resultIsInfinity,
|
||||
infinityResultImag, resultImagSpecialCase2);
|
||||
|
||||
Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::UNO, resultReal5, zero);
|
||||
Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::UNO, resultImag5, zero);
|
||||
Value resultRealIsNaN = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero);
|
||||
Value resultImagIsNaN = arith::CmpFOp::create(
|
||||
rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero);
|
||||
Value resultIsNaN =
|
||||
rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
|
||||
arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
|
||||
|
||||
*resultRe = rewriter.create<arith::SelectOp>(
|
||||
loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
|
||||
*resultIm = rewriter.create<arith::SelectOp>(
|
||||
loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
|
||||
*resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN,
|
||||
resultRealSpecialCase1, resultReal5);
|
||||
*resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN,
|
||||
resultImagSpecialCase1, resultImag5);
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
|
||||
|
||||
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
|
||||
Location loc, Type type) {
|
||||
Value val = builder.create<LLVM::PoisonOp>(loc, type);
|
||||
Value val = LLVM::PoisonOp::create(builder, loc, type);
|
||||
return ComplexStructBuilder(val);
|
||||
}
|
||||
|
||||
@@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
|
||||
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
|
||||
op.getContext(),
|
||||
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
|
||||
Value sqNorm = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
|
||||
rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
|
||||
Value sqNorm = LLVM::FAddOp::create(
|
||||
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
|
||||
return success();
|
||||
@@ -191,10 +191,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
|
||||
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
|
||||
op.getContext(),
|
||||
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
|
||||
Value real =
|
||||
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
|
||||
Value imag =
|
||||
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
|
||||
Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
|
||||
arg.rhs.real(), fmf);
|
||||
Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
|
||||
arg.rhs.imag(), fmf);
|
||||
result.setReal(rewriter, loc, real);
|
||||
result.setImaginary(rewriter, loc, imag);
|
||||
|
||||
@@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
|
||||
Value lhsRe = arg.lhs.real();
|
||||
Value lhsIm = arg.lhs.imag();
|
||||
|
||||
Value real = rewriter.create<LLVM::FSubOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
|
||||
rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
|
||||
Value real = LLVM::FSubOp::create(
|
||||
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
|
||||
|
||||
Value imag = rewriter.create<LLVM::FAddOp>(
|
||||
loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
|
||||
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
|
||||
Value imag = LLVM::FAddOp::create(
|
||||
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
|
||||
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
|
||||
|
||||
result.setReal(rewriter, loc, real);
|
||||
result.setImaginary(rewriter, loc, imag);
|
||||
@@ -313,10 +313,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
|
||||
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
|
||||
op.getContext(),
|
||||
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
|
||||
Value real =
|
||||
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
|
||||
Value imag =
|
||||
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
|
||||
Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
|
||||
arg.rhs.real(), fmf);
|
||||
Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
|
||||
arg.rhs.imag(), fmf);
|
||||
result.setReal(rewriter, loc, real);
|
||||
result.setImaginary(rewriter, loc, imag);
|
||||
|
||||
|
||||
@@ -84,8 +84,8 @@ LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
|
||||
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
|
||||
auto opFunctionTy = FunctionType::get(
|
||||
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
|
||||
opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
|
||||
opFunctionTy);
|
||||
opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name,
|
||||
opFunctionTy);
|
||||
opFunc.setPrivate();
|
||||
}
|
||||
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
|
||||
|
||||
@@ -44,8 +44,8 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
|
||||
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
|
||||
auto funcTy = FunctionType::get(
|
||||
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
|
||||
opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
|
||||
funcTy);
|
||||
opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(),
|
||||
funcName, funcTy);
|
||||
opFunc.setPrivate();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
||||
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
||||
"abort", abortFuncTy);
|
||||
abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
|
||||
"abort", abortFuncTy);
|
||||
}
|
||||
rewriter.create<LLVM::CallOp>(loc, abortFunc, ValueRange());
|
||||
rewriter.create<LLVM::UnreachableOp>(loc);
|
||||
LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
|
||||
LLVM::UnreachableOp::create(rewriter, loc);
|
||||
} else {
|
||||
rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
|
||||
LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
|
||||
}
|
||||
|
||||
// Generate assertion test.
|
||||
|
||||
@@ -33,8 +33,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
|
||||
MutableArrayRef<Region> regions) {
|
||||
if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
|
||||
assert(regions.size() == 2);
|
||||
auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
|
||||
resultTypes, condBrOp.getCondition());
|
||||
auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(),
|
||||
resultTypes, condBrOp.getCondition());
|
||||
ifOp.getThenRegion().takeBody(regions[0]);
|
||||
ifOp.getElseRegion().takeBody(regions[1]);
|
||||
return ifOp.getOperation();
|
||||
@@ -43,8 +43,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
|
||||
if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
|
||||
// `getCFGSwitchValue` returns an i32 that we need to convert to index
|
||||
// fist.
|
||||
auto cast = builder.create<arith::IndexCastUIOp>(
|
||||
controlFlowCondOp->getLoc(), builder.getIndexType(),
|
||||
auto cast = arith::IndexCastUIOp::create(
|
||||
builder, controlFlowCondOp->getLoc(), builder.getIndexType(),
|
||||
switchOp.getFlag());
|
||||
SmallVector<int64_t> cases;
|
||||
if (auto caseValues = switchOp.getCaseValues())
|
||||
@@ -55,8 +55,9 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
|
||||
|
||||
assert(regions.size() == cases.size() + 1);
|
||||
|
||||
auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
|
||||
controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
|
||||
auto indexSwitchOp =
|
||||
scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(),
|
||||
resultTypes, cast, cases, cases.size());
|
||||
|
||||
indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
|
||||
for (auto &&[targetRegion, sourceRegion] :
|
||||
@@ -75,7 +76,7 @@ LogicalResult
|
||||
ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
|
||||
Location loc, OpBuilder &builder, Operation *branchRegionOp,
|
||||
Operation *replacedControlFlowOp, ValueRange results) {
|
||||
builder.create<scf::YieldOp>(loc, results);
|
||||
scf::YieldOp::create(builder, loc, results);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -84,23 +85,24 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
|
||||
OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
|
||||
Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
|
||||
Location loc = replacedOp->getLoc();
|
||||
auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
|
||||
loopVariablesInit);
|
||||
auto whileOp = scf::WhileOp::create(
|
||||
builder, loc, loopVariablesInit.getTypes(), loopVariablesInit);
|
||||
|
||||
whileOp.getBefore().takeBody(loopBody);
|
||||
|
||||
builder.setInsertionPointToEnd(&whileOp.getBefore().back());
|
||||
// `getCFGSwitchValue` returns a i32. We therefore need to truncate the
|
||||
// condition to i1 first. It is guaranteed to be either 0 or 1 already.
|
||||
builder.create<scf::ConditionOp>(
|
||||
loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
|
||||
scf::ConditionOp::create(
|
||||
builder, loc,
|
||||
arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition),
|
||||
loopVariablesNextIter);
|
||||
|
||||
Block *afterBlock = builder.createBlock(&whileOp.getAfter());
|
||||
afterBlock->addArguments(
|
||||
loopVariablesInit.getTypes(),
|
||||
SmallVector<Location>(loopVariablesInit.size(), loc));
|
||||
builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
|
||||
scf::YieldOp::create(builder, loc, afterBlock->getArguments());
|
||||
|
||||
return whileOp.getOperation();
|
||||
}
|
||||
@@ -108,8 +110,8 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
|
||||
Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
|
||||
OpBuilder &builder,
|
||||
unsigned int value) {
|
||||
return builder.create<arith::ConstantOp>(loc,
|
||||
builder.getI32IntegerAttr(value));
|
||||
return arith::ConstantOp::create(builder, loc,
|
||||
builder.getI32IntegerAttr(value));
|
||||
}
|
||||
|
||||
void ControlFlowToSCFTransformation::createCFGSwitchOp(
|
||||
@@ -117,15 +119,15 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp(
|
||||
ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
|
||||
ArrayRef<ValueRange> caseArguments, Block *defaultDest,
|
||||
ValueRange defaultArgs) {
|
||||
builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
|
||||
llvm::to_vector_of<int32_t>(caseValues),
|
||||
caseDestinations, caseArguments);
|
||||
cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs,
|
||||
llvm::to_vector_of<int32_t>(caseValues),
|
||||
caseDestinations, caseArguments);
|
||||
}
|
||||
|
||||
Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
|
||||
OpBuilder &builder,
|
||||
Type type) {
|
||||
return builder.create<ub::PoisonOp>(loc, type, nullptr);
|
||||
return ub::PoisonOp::create(builder, loc, type, nullptr);
|
||||
}
|
||||
|
||||
FailureOr<Operation *>
|
||||
|
||||
@@ -99,8 +99,8 @@ public:
|
||||
}
|
||||
|
||||
// Create the converted `emitc.func` op.
|
||||
emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(),
|
||||
emitc::FuncOp newFuncOp = emitc::FuncOp::create(
|
||||
rewriter, funcOp.getLoc(), funcOp.getName(),
|
||||
FunctionType::get(rewriter.getContext(),
|
||||
signatureConverter.getConvertedTypes(),
|
||||
resultType ? TypeRange(resultType) : TypeRange()));
|
||||
|
||||
@@ -115,8 +115,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
||||
SmallVector<NamedAttribute> attributes;
|
||||
filterFuncAttributes(funcOp, attributes);
|
||||
|
||||
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
||||
auto wrapperFuncOp = LLVM::LLVMFuncOp::create(
|
||||
rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
||||
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
|
||||
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
|
||||
propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
|
||||
@@ -129,14 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
||||
for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
|
||||
Value arg = wrapperFuncOp.getArgument(index + argOffset);
|
||||
if (auto memrefType = dyn_cast<MemRefType>(argType)) {
|
||||
Value loaded = rewriter.create<LLVM::LoadOp>(
|
||||
loc, typeConverter.convertType(memrefType), arg);
|
||||
Value loaded = LLVM::LoadOp::create(
|
||||
rewriter, loc, typeConverter.convertType(memrefType), arg);
|
||||
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
|
||||
continue;
|
||||
}
|
||||
if (isa<UnrankedMemRefType>(argType)) {
|
||||
Value loaded = rewriter.create<LLVM::LoadOp>(
|
||||
loc, typeConverter.convertType(argType), arg);
|
||||
Value loaded = LLVM::LoadOp::create(
|
||||
rewriter, loc, typeConverter.convertType(argType), arg);
|
||||
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
|
||||
continue;
|
||||
}
|
||||
@@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
||||
args.push_back(arg);
|
||||
}
|
||||
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
|
||||
auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args);
|
||||
|
||||
if (resultStructType) {
|
||||
rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
|
||||
wrapperFuncOp.getArgument(0));
|
||||
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
|
||||
LLVM::StoreOp::create(rewriter, loc, call.getResult(),
|
||||
wrapperFuncOp.getArgument(0));
|
||||
LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
|
||||
} else {
|
||||
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
|
||||
LLVM::ReturnOp::create(rewriter, loc, call.getResults());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,8 +182,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
||||
filterFuncAttributes(funcOp, attributes);
|
||||
|
||||
// Create the auxiliary function.
|
||||
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
|
||||
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
||||
auto wrapperFunc = LLVM::LLVMFuncOp::create(
|
||||
builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
||||
wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
|
||||
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
|
||||
propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
|
||||
@@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
||||
if (resultStructType) {
|
||||
// Allocate the struct on the stack and pass the pointer.
|
||||
Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
|
||||
Value one = builder.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.convertType(builder.getIndexType()),
|
||||
Value one = LLVM::ConstantOp::create(
|
||||
builder, loc, typeConverter.convertType(builder.getIndexType()),
|
||||
builder.getIntegerAttr(builder.getIndexType(), 1));
|
||||
Value result =
|
||||
builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
|
||||
LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one);
|
||||
args.push_back(result);
|
||||
}
|
||||
|
||||
@@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
||||
wrapperArgsRange.take_front(numToDrop));
|
||||
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
|
||||
Value one = builder.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.convertType(builder.getIndexType()),
|
||||
Value one = LLVM::ConstantOp::create(
|
||||
builder, loc, typeConverter.convertType(builder.getIndexType()),
|
||||
builder.getIntegerAttr(builder.getIndexType(), 1));
|
||||
Value allocated = builder.create<LLVM::AllocaOp>(
|
||||
loc, ptrTy, packed.getType(), one, /*alignment=*/0);
|
||||
builder.create<LLVM::StoreOp>(loc, packed, allocated);
|
||||
Value allocated = LLVM::AllocaOp::create(
|
||||
builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0);
|
||||
LLVM::StoreOp::create(builder, loc, packed, allocated);
|
||||
arg = allocated;
|
||||
} else {
|
||||
arg = wrapperArgsRange[0];
|
||||
@@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
||||
}
|
||||
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
|
||||
|
||||
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
|
||||
auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args);
|
||||
|
||||
if (resultStructType) {
|
||||
Value result =
|
||||
builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
|
||||
builder.create<LLVM::ReturnOp>(loc, result);
|
||||
LLVM::LoadOp::create(builder, loc, resultStructType, args.front());
|
||||
LLVM::ReturnOp::create(builder, loc, result);
|
||||
} else {
|
||||
builder.create<LLVM::ReturnOp>(loc, call.getResults());
|
||||
LLVM::ReturnOp::create(builder, loc, call.getResults());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,7 +283,7 @@ static void restoreByValRefArgumentType(
|
||||
Type resTy = typeConverter.convertType(
|
||||
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
|
||||
|
||||
Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
|
||||
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
|
||||
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
|
||||
}
|
||||
}
|
||||
@@ -357,8 +357,8 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
|
||||
symbolTable.remove(funcOp);
|
||||
}
|
||||
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
auto newFuncOp = LLVM::LLVMFuncOp::create(
|
||||
rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
|
||||
attributes);
|
||||
|
||||
@@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
|
||||
return rewriter.notifyMatchFailure(op, "failed to convert result type");
|
||||
|
||||
auto newOp =
|
||||
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
|
||||
LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue());
|
||||
for (const NamedAttribute &attr : op->getAttrs()) {
|
||||
if (attr.getName().strref() == "value")
|
||||
continue;
|
||||
@@ -556,9 +556,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
||||
auto promoted = this->getTypeConverter()->promoteOperands(
|
||||
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
|
||||
adaptor.getOperands(), rewriter, useBarePtrCallConv);
|
||||
auto newOp = rewriter.create<LLVM::CallOp>(
|
||||
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
|
||||
promoted, callOp->getAttrs());
|
||||
auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(),
|
||||
packedResult ? TypeRange(packedResult)
|
||||
: TypeRange(),
|
||||
promoted, callOp->getAttrs());
|
||||
|
||||
newOp.getProperties().operandSegmentSizes = {
|
||||
static_cast<int32_t>(promoted.size()), 0};
|
||||
@@ -573,8 +574,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
||||
// Extract individual results from the structure and return them as list.
|
||||
results.reserve(numResults);
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
callOp.getLoc(), newOp->getResult(0), i));
|
||||
results.push_back(LLVM::ExtractValueOp::create(
|
||||
rewriter, callOp.getLoc(), newOp->getResult(0), i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -726,9 +727,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
|
||||
return rewriter.notifyMatchFailure(op, "could not convert result types");
|
||||
}
|
||||
|
||||
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
|
||||
Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
|
||||
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
|
||||
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
|
||||
packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
|
||||
op->getAttrs());
|
||||
|
||||
@@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
|
||||
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointToStart(moduleOp.getBody());
|
||||
ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
|
||||
ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPointToStart(moduleOp.getBody());
|
||||
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
|
||||
return b.create<LLVM::GlobalOp>(loc, globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
name, attr, alignment, addrSpace);
|
||||
return LLVM::GlobalOp::create(b, loc, globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
name, attr, alignment, addrSpace);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
@@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
gpuFuncOp.getWorkgroupAttributionAttr(
|
||||
idx, LLVM::LLVMDialect::getAlignAttrName())))
|
||||
alignment = alignAttr.getInt();
|
||||
auto globalOp = rewriter.create<LLVM::GlobalOp>(
|
||||
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
|
||||
auto globalOp = LLVM::GlobalOp::create(
|
||||
rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
|
||||
LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
|
||||
workgroupAddrSpace);
|
||||
workgroupBuffers.push_back(globalOp);
|
||||
@@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
|
||||
? kernelCallingConvention
|
||||
: nonKernelCallingConvention;
|
||||
auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
|
||||
auto llvmFuncOp = LLVM::LLVMFuncOp::create(
|
||||
rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
|
||||
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
|
||||
/*comdat=*/nullptr, attributes);
|
||||
|
||||
@@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
|
||||
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
|
||||
global.getAddrSpace());
|
||||
Value address = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc, ptrType, global.getSymNameAttr());
|
||||
Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
|
||||
global.getSymNameAttr());
|
||||
Value memory =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
|
||||
address, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
|
||||
address, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
|
||||
// Build a memref descriptor pointing to the buffer to plug with the
|
||||
// existing memref infrastructure. This may use more registers than
|
||||
@@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
Type elementType = typeConverter->convertType(type.getElementType());
|
||||
auto ptrType =
|
||||
LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
|
||||
Value numElements = rewriter.create<LLVM::ConstantOp>(
|
||||
gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
|
||||
Value numElements = LLVM::ConstantOp::create(
|
||||
rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
|
||||
uint64_t alignment = 0;
|
||||
if (auto alignAttr =
|
||||
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
|
||||
idx, LLVM::LLVMDialect::getAlignAttrName())))
|
||||
alignment = alignAttr.getInt();
|
||||
Value allocated = rewriter.create<LLVM::AllocaOp>(
|
||||
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
|
||||
Value allocated =
|
||||
LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
|
||||
elementType, numElements, alignment);
|
||||
Value descr = MemRefDescriptor::fromStaticShape(
|
||||
rewriter, loc, *getTypeConverter(), type, allocated);
|
||||
signatureConversion.remapInput(
|
||||
@@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
{llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
|
||||
/// Start the printf hostcall
|
||||
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
|
||||
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
|
||||
Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
|
||||
auto printfBeginCall =
|
||||
LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
|
||||
Value printfDesc = printfBeginCall.getResult();
|
||||
|
||||
// Create the global op or find an existing one.
|
||||
@@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
|
||||
|
||||
// Get a pointer to the format string's first element and pass it to printf()
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc,
|
||||
Value globalPtr = LLVM::AddressOfOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
|
||||
global.getSymNameAttr());
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
Value stringLen = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
|
||||
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
Value stringLen = LLVM::ConstantOp::create(
|
||||
rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
|
||||
|
||||
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
|
||||
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
|
||||
Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
|
||||
Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
|
||||
|
||||
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
|
||||
loc, ocklAppendStringN,
|
||||
auto appendFormatCall = LLVM::CallOp::create(
|
||||
rewriter, loc, ocklAppendStringN,
|
||||
ValueRange{printfDesc, stringStart, stringLen,
|
||||
adaptor.getArgs().empty() ? oneI32 : zeroI32});
|
||||
printfDesc = appendFormatCall.getResult();
|
||||
@@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
|
||||
arguments.push_back(printfDesc);
|
||||
arguments.push_back(
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
|
||||
LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
|
||||
for (size_t i = group; i < bound; ++i) {
|
||||
Value arg = adaptor.getArgs()[i];
|
||||
if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
|
||||
if (!floatType.isF64())
|
||||
arg = rewriter.create<LLVM::FPExtOp>(
|
||||
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
|
||||
arg = LLVM::FPExtOp::create(
|
||||
rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
|
||||
arg);
|
||||
arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
|
||||
}
|
||||
if (arg.getType().getIntOrFloatBitWidth() != 64)
|
||||
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
|
||||
arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
|
||||
|
||||
arguments.push_back(arg);
|
||||
}
|
||||
@@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
|
||||
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
|
||||
arguments.push_back(isLast);
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
|
||||
auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
|
||||
printfDesc = call.getResult();
|
||||
}
|
||||
rewriter.eraseOp(gpuPrintfOp);
|
||||
@@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
|
||||
/*alignment=*/0, addressSpace);
|
||||
|
||||
// Get a pointer to the format string's first element
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc,
|
||||
Value globalPtr = LLVM::AddressOfOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
|
||||
global.getSymNameAttr());
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
|
||||
// Construct arguments and function call
|
||||
auto argsRange = adaptor.getArgs();
|
||||
@@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
|
||||
printfArgs.push_back(stringStart);
|
||||
printfArgs.append(argsRange.begin(), argsRange.end());
|
||||
|
||||
rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
|
||||
LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
|
||||
rewriter.eraseOp(gpuPrintfOp);
|
||||
return success();
|
||||
}
|
||||
@@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
|
||||
"printfFormat_", adaptor.getFormat());
|
||||
|
||||
// Get a pointer to the format string's first element
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
||||
Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
|
||||
Value stringStart =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
SmallVector<Type> types;
|
||||
SmallVector<Value> args;
|
||||
// Promote and pack the arguments into a stack allocation.
|
||||
@@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
|
||||
assert(type.isIntOrFloat());
|
||||
if (isa<FloatType>(type)) {
|
||||
type = rewriter.getF64Type();
|
||||
promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
|
||||
promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
|
||||
}
|
||||
types.push_back(type);
|
||||
args.push_back(promotedArg);
|
||||
}
|
||||
Type structType =
|
||||
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
|
||||
rewriter.getIndexAttr(1));
|
||||
Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
|
||||
rewriter.getIndexAttr(1));
|
||||
Value tempAlloc =
|
||||
rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
|
||||
/*alignment=*/0);
|
||||
LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
|
||||
/*alignment=*/0);
|
||||
for (auto [index, arg] : llvm::enumerate(args)) {
|
||||
Value ptr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, ptrType, structType, tempAlloc,
|
||||
Value ptr = LLVM::GEPOp::create(
|
||||
rewriter, loc, ptrType, structType, tempAlloc,
|
||||
ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
|
||||
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
|
||||
LLVM::StoreOp::create(rewriter, loc, arg, ptr);
|
||||
}
|
||||
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
|
||||
|
||||
rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
|
||||
LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
|
||||
rewriter.eraseOp(gpuPrintfOp);
|
||||
return success();
|
||||
}
|
||||
@@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
|
||||
TypeRange operandTypes(operands);
|
||||
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
|
||||
Location loc = op->getLoc();
|
||||
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
|
||||
Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
|
||||
Type indexType = converter.convertType(rewriter.getIndexType());
|
||||
StringAttr name = op->getName().getIdentifier();
|
||||
Type elementType = vectorType.getElementType();
|
||||
|
||||
for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
|
||||
Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
|
||||
Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
|
||||
auto extractElement = [&](Value operand) -> Value {
|
||||
if (!isa<VectorType>(operand.getType()))
|
||||
return operand;
|
||||
return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
|
||||
return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
|
||||
};
|
||||
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
|
||||
Operation *scalarOp =
|
||||
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
|
||||
result = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, result, scalarOp->getResult(0), index);
|
||||
result = LLVM::InsertElementOp::create(rewriter, loc, result,
|
||||
scalarOp->getResult(0), index);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
|
||||
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
|
||||
typeConverter->convertType(memrefType.getElementType()), 0);
|
||||
|
||||
return rewriter.create<LLVM::GlobalOp>(
|
||||
op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
|
||||
LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
|
||||
addressSpace.value());
|
||||
return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
|
||||
/*isConstant=*/false, LLVM::Linkage::Internal,
|
||||
symName, /*value=*/Attribute(), alignmentByte,
|
||||
addressSpace.value());
|
||||
}
|
||||
|
||||
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
|
||||
@@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
|
||||
// Step 3. Get address of the global symbol
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
|
||||
auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
|
||||
Type baseType = basePtr->getResultTypes().front();
|
||||
|
||||
// Step 4. Generate GEP using offsets
|
||||
SmallVector<LLVM::GEPArg> gepArgs = {0};
|
||||
Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
|
||||
basePtr, gepArgs);
|
||||
Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
|
||||
basePtr, gepArgs);
|
||||
// Step 5. Create a memref descriptor
|
||||
SmallVector<Value> shape, strides;
|
||||
Value sizeBytes;
|
||||
@@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
|
||||
return rewriter.notifyMatchFailure(op, "could not convert result types");
|
||||
}
|
||||
|
||||
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
|
||||
Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
|
||||
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
|
||||
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
|
||||
packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
|
||||
op->getAttrs());
|
||||
|
||||
@@ -79,8 +79,8 @@ protected:
|
||||
uint64_t rank = type.getRank();
|
||||
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
|
||||
for (unsigned i = 1; i < rank; i++)
|
||||
numElements = rewriter.create<LLVM::MulOp>(
|
||||
loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
|
||||
numElements = LLVM::MulOp::create(rewriter, loc, numElements,
|
||||
desc.size(rewriter, loc, /*pos=*/i));
|
||||
return numElements;
|
||||
}
|
||||
|
||||
@@ -582,7 +582,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
|
||||
return OpBuilder::atBlockEnd(module.getBody())
|
||||
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
|
||||
}();
|
||||
return builder.create<LLVM::CallOp>(loc, function, arguments);
|
||||
return LLVM::CallOp::create(builder, loc, function, arguments);
|
||||
}
|
||||
|
||||
// Corresponding to cusparseIndexType_t defined in cusparse.h.
|
||||
@@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
|
||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||
// descriptor.
|
||||
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
|
||||
auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
|
||||
Value stream = adaptor.getAsyncDependencies().empty()
|
||||
? nullPtr
|
||||
: adaptor.getAsyncDependencies().front();
|
||||
|
||||
auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
|
||||
loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
|
||||
auto isHostShared = mlir::LLVM::ConstantOp::create(
|
||||
rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
|
||||
|
||||
Value allocatedPtr =
|
||||
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
|
||||
@@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
|
||||
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
|
||||
static_cast<uint64_t>(memrefTy.getNumElements());
|
||||
|
||||
Value sizeArg = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, getIndexType(), rewriter.getIndexAttr(staticSize));
|
||||
Value sizeArg = LLVM::ConstantOp::create(
|
||||
rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
|
||||
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
|
||||
llvmArgumentsWithSizes.push_back(sizeArg);
|
||||
}
|
||||
@@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
|
||||
gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
|
||||
adaptor.getClusterSizeZ()};
|
||||
}
|
||||
rewriter.create<gpu::LaunchFuncOp>(
|
||||
launchOp.getLoc(), launchOp.getKernelAttr(),
|
||||
gpu::LaunchFuncOp::create(
|
||||
rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
|
||||
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
|
||||
adaptor.getGridSizeZ()},
|
||||
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
|
||||
@@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc,
|
||||
const LLVMTypeConverter &typeConverter) {
|
||||
auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
|
||||
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
|
||||
sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
|
||||
loc,
|
||||
sourcePtr = LLVM::AddrSpaceCastOp::create(
|
||||
rewriter, loc,
|
||||
LLVM::LLVMPointerType::get(rewriter.getContext(),
|
||||
destinationType.getAddressSpace()),
|
||||
sourcePtr);
|
||||
@@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
|
||||
|
||||
Type elementPtrType = getElementPtrType(memRefType);
|
||||
Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
|
||||
Value gepPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, elementPtrType,
|
||||
Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
|
||||
Value gepPtr = LLVM::GEPOp::create(
|
||||
rewriter, loc, elementPtrType,
|
||||
typeConverter->convertType(memRefType.getElementType()), nullPtr,
|
||||
numElements);
|
||||
auto sizeBytes =
|
||||
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
|
||||
LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
|
||||
|
||||
auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
|
||||
srcDesc.alignedPtr(rewriter, loc),
|
||||
@@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
|
||||
|
||||
auto value =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
|
||||
LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
|
||||
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
|
||||
dstDesc.alignedPtr(rewriter, loc),
|
||||
*getTypeConverter());
|
||||
@@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
template <typename T>
|
||||
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
|
||||
Type llvmInt32Type = builder.getIntegerType(32);
|
||||
return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||
static_cast<int32_t>(tValue));
|
||||
return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
|
||||
static_cast<int32_t>(tValue));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
|
||||
Type llvmFloat32Type = builder.getF32Type();
|
||||
return builder.create<LLVM::ConstantOp>(
|
||||
loc, llvmFloat32Type,
|
||||
return LLVM::ConstantOp::create(
|
||||
builder, loc, llvmFloat32Type,
|
||||
builder.getF32FloatAttr(static_cast<float>(tValue)));
|
||||
}
|
||||
|
||||
@@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
// the dnmat is used with spmat with 2:4 sparsity
|
||||
if (dims.size() == 2) {
|
||||
if (isSpMMCusparseLtOp(op.getDnTensor())) {
|
||||
auto handleSz = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, getIndexType(), rewriter.getIndexAttr(11032));
|
||||
handle = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
|
||||
handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
|
||||
auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(11032));
|
||||
handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
|
||||
llvmInt8Type, handleSz, /*alignment=*/16);
|
||||
handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
|
||||
|
||||
createLtDnMatCallBuilder
|
||||
.create(loc, rewriter,
|
||||
@@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
|
||||
|
||||
// CUDA runner asserts the size is 44104 bytes.
|
||||
auto handleSz = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, getIndexType(), rewriter.getIndexAttr(44104));
|
||||
Value handle = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
|
||||
handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
|
||||
auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(44104));
|
||||
Value handle = LLVM::AllocaOp::create(
|
||||
rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
|
||||
handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
|
||||
|
||||
create2To4SpMatCallBuilder
|
||||
.create(loc, rewriter,
|
||||
@@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
|
||||
auto computeType = genConstInt32From(
|
||||
rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
|
||||
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||
rewriter.getIndexAttr(3));
|
||||
auto bufferSize = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
|
||||
auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(3));
|
||||
auto bufferSize =
|
||||
LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
|
||||
three, /*alignment=*/16);
|
||||
createCuSparseLtSpMMBufferSizeBuilder
|
||||
.create(loc, rewriter,
|
||||
{bufferSize, modeA, modeB, adaptor.getSpmatA(),
|
||||
@@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
pruneFlag, stream})
|
||||
.getResult();
|
||||
|
||||
auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvmPointerType, llvmPointerType, bufferSize,
|
||||
ValueRange{rewriter.create<LLVM::ConstantOp>(
|
||||
loc, getIndexType(), rewriter.getIndexAttr(1))});
|
||||
auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvmPointerType, llvmPointerType, bufferSize,
|
||||
ValueRange{rewriter.create<LLVM::ConstantOp>(
|
||||
loc, getIndexType(), rewriter.getIndexAttr(2))});
|
||||
auto bufferSizePtr1 = LLVM::GEPOp::create(
|
||||
rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
|
||||
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(1))});
|
||||
auto bufferSizePtr2 = LLVM::GEPOp::create(
|
||||
rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
|
||||
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(2))});
|
||||
auto bufferSize0 =
|
||||
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
|
||||
LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
|
||||
auto bufferSize1 =
|
||||
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
|
||||
LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
|
||||
auto bufferSize2 =
|
||||
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
|
||||
LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
|
||||
|
||||
rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
|
||||
} else {
|
||||
@@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
Location loc = op.getLoc();
|
||||
auto stream = adaptor.getAsyncDependencies().front();
|
||||
|
||||
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||
rewriter.getIndexAttr(3));
|
||||
auto buffer = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
|
||||
auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(3));
|
||||
auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
|
||||
llvmInt64Type, three, /*alignment=*/16);
|
||||
|
||||
auto rowsPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvmPointerType, llvmPointerType, buffer,
|
||||
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||
rewriter.getIndexAttr(0))});
|
||||
auto colsPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvmPointerType, llvmPointerType, buffer,
|
||||
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||
rewriter.getIndexAttr(1))});
|
||||
auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvmPointerType, llvmPointerType, buffer,
|
||||
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
|
||||
rewriter.getIndexAttr(2))});
|
||||
auto rowsPtr = LLVM::GEPOp::create(
|
||||
rewriter, loc, llvmPointerType, llvmPointerType, buffer,
|
||||
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(0))});
|
||||
auto colsPtr = LLVM::GEPOp::create(
|
||||
rewriter, loc, llvmPointerType, llvmPointerType, buffer,
|
||||
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(1))});
|
||||
auto nnzsPtr = LLVM::GEPOp::create(
|
||||
rewriter, loc, llvmPointerType, llvmPointerType, buffer,
|
||||
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
|
||||
rewriter.getIndexAttr(2))});
|
||||
createSpMatGetSizeBuilder.create(
|
||||
loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
|
||||
auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
|
||||
auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
|
||||
auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
|
||||
auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
|
||||
auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
|
||||
auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
|
||||
|
||||
rewriter.replaceOp(op, {rows, cols, nnzs, stream});
|
||||
return success();
|
||||
|
||||
@@ -59,13 +59,13 @@ public:
|
||||
Operation *newOp;
|
||||
switch (op.getDimension()) {
|
||||
case gpu::Dimension::x:
|
||||
newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
|
||||
newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32));
|
||||
break;
|
||||
case gpu::Dimension::y:
|
||||
newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
|
||||
newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32));
|
||||
break;
|
||||
case gpu::Dimension::z:
|
||||
newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
|
||||
newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -124,11 +124,13 @@ public:
|
||||
rewriter.getContext(), 32, min, max));
|
||||
}
|
||||
if (indexBitwidth > 32) {
|
||||
newOp = rewriter.create<LLVM::SExtOp>(
|
||||
loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
|
||||
newOp = LLVM::SExtOp::create(rewriter, loc,
|
||||
IntegerType::get(context, indexBitwidth),
|
||||
newOp->getResult(0));
|
||||
} else if (indexBitwidth < 32) {
|
||||
newOp = rewriter.create<LLVM::TruncOp>(
|
||||
loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
|
||||
newOp = LLVM::TruncOp::create(rewriter, loc,
|
||||
IntegerType::get(context, indexBitwidth),
|
||||
newOp->getResult(0));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
|
||||
@@ -103,7 +103,7 @@ public:
|
||||
|
||||
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
|
||||
auto callOp =
|
||||
rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
|
||||
LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
|
||||
|
||||
if (resultType == adaptor.getOperands().front().getType()) {
|
||||
rewriter.replaceOp(op, {callOp.getResult()});
|
||||
@@ -115,19 +115,20 @@ public:
|
||||
// there is no guarantee of a specific value being used to indicate true,
|
||||
// compare for inequality with zero (rather than truncate or shift).
|
||||
if (isResultBool) {
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
op->getLoc(), rewriter.getIntegerType(32),
|
||||
rewriter.getI32IntegerAttr(0));
|
||||
Value truncated = rewriter.create<LLVM::ICmpOp>(
|
||||
op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
|
||||
Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
|
||||
rewriter.getIntegerType(32),
|
||||
rewriter.getI32IntegerAttr(0));
|
||||
Value truncated =
|
||||
LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
|
||||
callOp.getResult(), zero);
|
||||
rewriter.replaceOp(op, {truncated});
|
||||
return success();
|
||||
}
|
||||
|
||||
assert(callOp.getResult().getType().isF32() &&
|
||||
"only f32 types are supposed to be truncated back");
|
||||
Value truncated = rewriter.create<LLVM::FPTruncOp>(
|
||||
op->getLoc(), adaptor.getOperands().front().getType(),
|
||||
Value truncated = LLVM::FPTruncOp::create(
|
||||
rewriter, op->getLoc(), adaptor.getOperands().front().getType(),
|
||||
callOp.getResult());
|
||||
rewriter.replaceOp(op, {truncated});
|
||||
return success();
|
||||
@@ -142,8 +143,9 @@ public:
|
||||
if (!f16Func.empty() && isa<Float16Type>(type))
|
||||
return operand;
|
||||
|
||||
return rewriter.create<LLVM::FPExtOp>(
|
||||
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
|
||||
return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
|
||||
Float32Type::get(rewriter.getContext()),
|
||||
operand);
|
||||
}
|
||||
|
||||
Type getFunctionType(Type resultType, ValueRange operands) const {
|
||||
@@ -169,7 +171,7 @@ public:
|
||||
// location as debug info metadata inside of a function cannot be used
|
||||
// outside of that function.
|
||||
auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
|
||||
return b.create<LLVMFuncOp>(globalloc, funcName, funcType);
|
||||
return LLVMFuncOp::create(b, globalloc, funcName, funcType);
|
||||
}
|
||||
|
||||
StringRef getFunctionName(Type type, SourceOp op) const {
|
||||
|
||||
@@ -54,8 +54,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
|
||||
SymbolTable::lookupSymbolIn(symbolTable, name));
|
||||
if (!func) {
|
||||
OpBuilder b(symbolTable->getRegion(0));
|
||||
func = b.create<LLVM::LLVMFuncOp>(
|
||||
symbolTable->getLoc(), name,
|
||||
func = LLVM::LLVMFuncOp::create(
|
||||
b, symbolTable->getLoc(), name,
|
||||
LLVM::LLVMFunctionType::get(resultType, paramTypes));
|
||||
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
|
||||
func.setNoUnwind(true);
|
||||
@@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
LLVM::LLVMFuncOp func,
|
||||
ValueRange args) {
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
|
||||
auto call = LLVM::CallOp::create(rewriter, loc, func, args);
|
||||
call.setCConv(func.getCConv());
|
||||
call.setConvergentAttr(func.getConvergentAttr());
|
||||
call.setNoUnwindAttr(func.getNoUnwindAttr());
|
||||
@@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
|
||||
constexpr int64_t localMemFenceFlag = 1;
|
||||
Location loc = op->getLoc();
|
||||
Value flag =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
|
||||
LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag);
|
||||
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
|
||||
return success();
|
||||
}
|
||||
@@ -162,8 +162,8 @@ struct LaunchConfigConversion : ConvertToLLVMPattern {
|
||||
|
||||
Location loc = op->getLoc();
|
||||
gpu::Dimension dim = getDimension(op);
|
||||
Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
|
||||
static_cast<int64_t>(dim));
|
||||
Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
|
||||
static_cast<int64_t>(dim));
|
||||
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
|
||||
return success();
|
||||
}
|
||||
@@ -291,13 +291,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
return TypeSwitch<Type, Value>(oldVal.getType())
|
||||
.Case([&](BFloat16Type) {
|
||||
return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
|
||||
oldVal);
|
||||
return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
|
||||
oldVal);
|
||||
})
|
||||
.Case([&](IntegerType intTy) -> Value {
|
||||
if (intTy.getWidth() == 1)
|
||||
return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
|
||||
oldVal);
|
||||
return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
|
||||
oldVal);
|
||||
return oldVal;
|
||||
})
|
||||
.Default(oldVal);
|
||||
@@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
return TypeSwitch<Type, Value>(newTy)
|
||||
.Case([&](BFloat16Type) {
|
||||
return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
|
||||
return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
|
||||
})
|
||||
.Case([&](IntegerType intTy) -> Value {
|
||||
if (intTy.getWidth() == 1)
|
||||
return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
|
||||
return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
|
||||
return oldVal;
|
||||
})
|
||||
.Default(oldVal);
|
||||
@@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
|
||||
|
||||
Value trueVal =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
|
||||
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true);
|
||||
rewriter.replaceOp(op, {resultOrConversion, trueVal});
|
||||
return success();
|
||||
}
|
||||
@@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
|
||||
if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
|
||||
return failure();
|
||||
}
|
||||
result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
|
||||
result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
@@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
|
||||
Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
|
||||
Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
|
||||
|
||||
auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
|
||||
mode.value(), offset);
|
||||
auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
|
||||
op.getValue(), mode.value(), offset);
|
||||
|
||||
rewriter.replaceOp(op, reduxOp->getResult(0));
|
||||
return success();
|
||||
@@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
|
||||
auto predTy = IntegerType::get(rewriter.getContext(), 1);
|
||||
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
|
||||
Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
|
||||
Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
|
||||
Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
|
||||
loc, int32Type, thirtyTwo, adaptor.getWidth());
|
||||
Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
|
||||
Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
|
||||
Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
|
||||
Value numLeadInactiveLane = LLVM::SubOp::create(
|
||||
rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
|
||||
// Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
|
||||
Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
|
||||
numLeadInactiveLane);
|
||||
Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
|
||||
numLeadInactiveLane);
|
||||
Value maskAndClamp;
|
||||
if (op.getMode() == gpu::ShuffleMode::UP) {
|
||||
// Clamp lane: `32 - activeWidth`
|
||||
maskAndClamp = numLeadInactiveLane;
|
||||
} else {
|
||||
// Clamp lane: `activeWidth - 1`
|
||||
maskAndClamp =
|
||||
rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
|
||||
maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
|
||||
adaptor.getWidth(), one);
|
||||
}
|
||||
|
||||
bool predIsUsed = !op->getResult(1).use_empty();
|
||||
@@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
|
||||
{valueTy, predTy});
|
||||
}
|
||||
Value shfl = rewriter.create<NVVM::ShflOp>(
|
||||
loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
|
||||
maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
|
||||
Value shfl = NVVM::ShflOp::create(
|
||||
rewriter, loc, resultTy, activeMask, adaptor.getValue(),
|
||||
adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
|
||||
returnValueAndIsValidAttr);
|
||||
if (predIsUsed) {
|
||||
Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
|
||||
Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
|
||||
Value isActiveSrcLane =
|
||||
rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
|
||||
LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
|
||||
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
|
||||
} else {
|
||||
rewriter.replaceOp(op, {shfl, nullptr});
|
||||
@@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
|
||||
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
|
||||
/*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
|
||||
Value newOp =
|
||||
rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
|
||||
NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
|
||||
// Truncate or extend the result depending on the index bitwidth specified
|
||||
// by the LLVMTypeConverter options.
|
||||
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
|
||||
if (indexBitwidth > 32) {
|
||||
newOp = rewriter.create<LLVM::SExtOp>(
|
||||
loc, IntegerType::get(context, indexBitwidth), newOp);
|
||||
newOp = LLVM::SExtOp::create(
|
||||
rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
|
||||
} else if (indexBitwidth < 32) {
|
||||
newOp = rewriter.create<LLVM::TruncOp>(
|
||||
loc, IntegerType::get(context, indexBitwidth), newOp);
|
||||
newOp = LLVM::TruncOp::create(
|
||||
rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
|
||||
}
|
||||
rewriter.replaceOp(op, {newOp});
|
||||
return success();
|
||||
@@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering
|
||||
Block *afterBlock =
|
||||
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
|
||||
rewriter.setInsertionPointToEnd(beforeBlock);
|
||||
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
|
||||
assertBlock);
|
||||
cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
|
||||
assertBlock);
|
||||
rewriter.setInsertionPointToEnd(assertBlock);
|
||||
rewriter.create<cf::BranchOp>(loc, afterBlock);
|
||||
cf::BranchOp::create(rewriter, loc, afterBlock);
|
||||
|
||||
// Continue cf.assert lowering.
|
||||
rewriter.setInsertionPoint(assertOp);
|
||||
@@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering
|
||||
// Create constants.
|
||||
auto getGlobal = [&](LLVM::GlobalOp global) {
|
||||
// Get a pointer to the format string's first element.
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
|
||||
Value globalPtr = LLVM::AddressOfOp::create(
|
||||
rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
|
||||
global.getSymNameAttr());
|
||||
Value start =
|
||||
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
|
||||
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
return start;
|
||||
};
|
||||
Value assertMessage = getGlobal(getOrCreateStringConstant(
|
||||
@@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering
|
||||
Value assertFunc = getGlobal(getOrCreateStringConstant(
|
||||
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
|
||||
Value assertLine =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
|
||||
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
|
||||
LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
|
||||
Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
|
||||
|
||||
// Insert function call to __assertfail.
|
||||
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
|
||||
|
||||
@@ -126,8 +126,8 @@ struct WmmaLoadOpToNVVMLowering
|
||||
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
|
||||
adaptor.getSrcMemref(), adaptor.getIndices());
|
||||
|
||||
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, rewriter.getI32Type(),
|
||||
Value leadingDim = LLVM::ConstantOp::create(
|
||||
rewriter, loc, rewriter.getI32Type(),
|
||||
subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
|
||||
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
|
||||
op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
|
||||
@@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering
|
||||
auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
|
||||
for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
|
||||
Value toUse =
|
||||
rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
|
||||
LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
|
||||
storeOpOperands.push_back(toUse);
|
||||
}
|
||||
|
||||
@@ -181,8 +181,8 @@ struct WmmaStoreOpToNVVMLowering
|
||||
rewriter, loc,
|
||||
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
|
||||
adaptor.getDstMemref(), adaptor.getIndices());
|
||||
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, rewriter.getI32Type(),
|
||||
Value leadingDim = LLVM::ConstantOp::create(
|
||||
rewriter, loc, rewriter.getI32Type(),
|
||||
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
|
||||
rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
|
||||
op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
|
||||
@@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering
|
||||
auto unpackOp = [&](Value operand) {
|
||||
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
|
||||
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
|
||||
Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
|
||||
Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
|
||||
unpackedOps.push_back(toUse);
|
||||
}
|
||||
};
|
||||
@@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering
|
||||
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
|
||||
// If the element type is a vector create a vector from the operand.
|
||||
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
|
||||
Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
|
||||
Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
|
||||
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
|
||||
Value idx = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, rewriter.getI32Type(), vecEl);
|
||||
vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
|
||||
cst, idx);
|
||||
Value idx = LLVM::ConstantOp::create(rewriter, loc,
|
||||
rewriter.getI32Type(), vecEl);
|
||||
vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
|
||||
cst, idx);
|
||||
}
|
||||
cst = vecCst;
|
||||
}
|
||||
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
|
||||
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
|
||||
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
|
||||
matrixStruct =
|
||||
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
|
||||
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
|
||||
}
|
||||
rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
|
||||
return success();
|
||||
@@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
|
||||
Type i1Type = builder.getI1Type();
|
||||
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
|
||||
i1Type = VectorType::get(vecType.getShape(), i1Type);
|
||||
Value cmp = builder.create<LLVM::FCmpOp>(
|
||||
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
|
||||
lhs, rhs);
|
||||
Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
|
||||
Value isNan = builder.create<LLVM::FCmpOp>(
|
||||
loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
|
||||
Value nan = builder.create<LLVM::ConstantOp>(
|
||||
loc, lhs.getType(),
|
||||
Value cmp = LLVM::FCmpOp::create(
|
||||
builder, loc, i1Type,
|
||||
isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
|
||||
Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
|
||||
Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
|
||||
LLVM::FCmpPredicate::uno, lhs, rhs);
|
||||
Value nan = LLVM::ConstantOp::create(
|
||||
builder, loc, lhs.getType(),
|
||||
builder.getFloatAttr(floatType,
|
||||
APFloat::getQNaN(floatType.getFloatSemantics())));
|
||||
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
|
||||
return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
|
||||
}
|
||||
|
||||
static Value createScalarOp(OpBuilder &builder, Location loc,
|
||||
@@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
|
||||
ArrayRef<Value> operands) {
|
||||
switch (op) {
|
||||
case gpu::MMAElementwiseOp::ADDF:
|
||||
return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
|
||||
return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
|
||||
case gpu::MMAElementwiseOp::MULF:
|
||||
return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
|
||||
return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
|
||||
case gpu::MMAElementwiseOp::DIVF:
|
||||
return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
|
||||
return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
|
||||
case gpu::MMAElementwiseOp::MAXF:
|
||||
return createMinMaxF(builder, loc, operands[0], operands[1],
|
||||
/*isMin=*/false);
|
||||
@@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering
|
||||
size_t numOperands = adaptor.getOperands().size();
|
||||
LLVM::LLVMStructType destType = convertMMAToLLVMType(
|
||||
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
|
||||
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
|
||||
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
|
||||
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
|
||||
SmallVector<Value> extractedOperands;
|
||||
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, adaptor.getOperands()[opIdx], i));
|
||||
extractedOperands.push_back(LLVM::ExtractValueOp::create(
|
||||
rewriter, loc, adaptor.getOperands()[opIdx], i));
|
||||
}
|
||||
Value element =
|
||||
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
|
||||
extractedOperands);
|
||||
matrixStruct =
|
||||
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
|
||||
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
|
||||
}
|
||||
rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
|
||||
return success();
|
||||
|
||||
@@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
|
||||
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
|
||||
// TODO: use <=> in C++20.
|
||||
if (indexBitwidth > intWidth) {
|
||||
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
|
||||
return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
|
||||
}
|
||||
if (indexBitwidth < intWidth) {
|
||||
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
|
||||
return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@@ -82,12 +82,12 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
|
||||
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const unsigned indexBitwidth) {
|
||||
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
|
||||
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
|
||||
Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
|
||||
ValueRange{minus1, zero});
|
||||
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
|
||||
ValueRange{minus1, mbcntLo});
|
||||
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
|
||||
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
|
||||
Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
|
||||
ValueRange{minus1, zero});
|
||||
Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
|
||||
ValueRange{minus1, mbcntLo});
|
||||
return laneId;
|
||||
}
|
||||
static constexpr StringLiteral amdgcnDataLayout =
|
||||
@@ -110,21 +110,21 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
|
||||
// followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
|
||||
|
||||
Type intTy = IntegerType::get(context, 32);
|
||||
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
|
||||
Value mbcntLo =
|
||||
rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
|
||||
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
|
||||
loc, intTy, ValueRange{minus1, mbcntLo});
|
||||
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
|
||||
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
|
||||
Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
|
||||
ValueRange{minus1, zero});
|
||||
Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
|
||||
ValueRange{minus1, mbcntLo});
|
||||
// Truncate or extend the result depending on the index bitwidth specified
|
||||
// by the LLVMTypeConverter options.
|
||||
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
|
||||
if (indexBitwidth > 32) {
|
||||
laneId = rewriter.create<LLVM::SExtOp>(
|
||||
loc, IntegerType::get(context, indexBitwidth), laneId);
|
||||
laneId = LLVM::SExtOp::create(
|
||||
rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
|
||||
} else if (indexBitwidth < 32) {
|
||||
laneId = rewriter.create<LLVM::TruncOp>(
|
||||
loc, IntegerType::get(context, indexBitwidth), laneId);
|
||||
laneId = LLVM::TruncOp::create(
|
||||
rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
|
||||
}
|
||||
rewriter.replaceOp(op, {laneId});
|
||||
return success();
|
||||
@@ -149,8 +149,8 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
|
||||
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
|
||||
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
|
||||
}
|
||||
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
|
||||
op.getLoc(), rewriter.getI32Type(), bounds);
|
||||
Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
|
||||
rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
|
||||
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
|
||||
*getTypeConverter());
|
||||
rewriter.replaceOp(op, {wavefrontOp});
|
||||
@@ -190,44 +190,44 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
|
||||
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
|
||||
Value width = adaptor.getWidth();
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
|
||||
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
|
||||
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
|
||||
Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
|
||||
Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
|
||||
Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
|
||||
Value widthOrZeroIfOutside =
|
||||
rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
|
||||
LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
|
||||
Value dstLane;
|
||||
|
||||
switch (op.getMode()) {
|
||||
case gpu::ShuffleMode::UP:
|
||||
dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
|
||||
adaptor.getOffset());
|
||||
dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
|
||||
adaptor.getOffset());
|
||||
break;
|
||||
case gpu::ShuffleMode::DOWN:
|
||||
dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
|
||||
adaptor.getOffset());
|
||||
dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
|
||||
adaptor.getOffset());
|
||||
break;
|
||||
case gpu::ShuffleMode::XOR:
|
||||
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
|
||||
adaptor.getOffset());
|
||||
dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
|
||||
adaptor.getOffset());
|
||||
break;
|
||||
case gpu::ShuffleMode::IDX:
|
||||
dstLane = adaptor.getOffset();
|
||||
break;
|
||||
}
|
||||
Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
|
||||
Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
|
||||
dstLane, srcLaneId);
|
||||
Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
|
||||
Value isActiveSrcLane = LLVM::ICmpOp::create(
|
||||
rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
|
||||
Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
|
||||
dstLane, srcLaneId);
|
||||
Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
|
||||
Value dwordAlignedDstLane =
|
||||
rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
|
||||
LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
|
||||
|
||||
SmallVector<Value> decomposed =
|
||||
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
|
||||
SmallVector<Value> swizzled;
|
||||
for (Value v : decomposed) {
|
||||
Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
|
||||
dwordAlignedDstLane, v);
|
||||
Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
|
||||
dwordAlignedDstLane, v);
|
||||
swizzled.emplace_back(res);
|
||||
}
|
||||
Value shflValue =
|
||||
|
||||
@@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
||||
|
||||
Value vector =
|
||||
spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
|
||||
Value dim = rewriter.create<spirv::CompositeExtractOp>(
|
||||
op.getLoc(), builtinType, vector,
|
||||
Value dim = spirv::CompositeExtractOp::create(
|
||||
rewriter, op.getLoc(), builtinType, vector,
|
||||
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
|
||||
if (forShader && builtinType != indexType)
|
||||
dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
|
||||
dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
|
||||
rewriter.replaceOp(op, dim);
|
||||
return success();
|
||||
}
|
||||
@@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
||||
Value builtinValue =
|
||||
spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
|
||||
if (i32Type != indexType)
|
||||
builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
|
||||
builtinValue);
|
||||
builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
|
||||
builtinValue);
|
||||
rewriter.replaceOp(op, builtinValue);
|
||||
return success();
|
||||
}
|
||||
@@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
|
||||
signatureConverter.addInputs(argType.index(), convertedType);
|
||||
}
|
||||
}
|
||||
auto newFuncOp = rewriter.create<spirv::FuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(),
|
||||
auto newFuncOp = spirv::FuncOp::create(
|
||||
rewriter, funcOp.getLoc(), funcOp.getName(),
|
||||
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
|
||||
for (const auto &namedAttr : funcOp->getAttrs()) {
|
||||
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
|
||||
@@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
|
||||
|
||||
// Add a keyword to the module name to avoid symbolic conflict.
|
||||
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
|
||||
auto spvModule = rewriter.create<spirv::ModuleOp>(
|
||||
moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
|
||||
auto spvModule = spirv::ModuleOp::create(
|
||||
rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
|
||||
StringRef(spvModuleName));
|
||||
|
||||
// Move the region from the module op into the SPIR-V module.
|
||||
@@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
|
||||
|
||||
switch (shuffleOp.getMode()) {
|
||||
case gpu::ShuffleMode::XOR: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
result = spirv::GroupNonUniformShuffleXorOp::create(
|
||||
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
|
||||
shuffleOp.getLoc(), rewriter);
|
||||
break;
|
||||
}
|
||||
case gpu::ShuffleMode::IDX: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
result = spirv::GroupNonUniformShuffleOp::create(
|
||||
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
|
||||
shuffleOp.getLoc(), rewriter);
|
||||
break;
|
||||
}
|
||||
case gpu::ShuffleMode::DOWN: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
result = spirv::GroupNonUniformShuffleDownOp::create(
|
||||
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
|
||||
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
|
||||
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
|
||||
Value resultLaneId =
|
||||
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
|
||||
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
resultLaneId, adaptor.getWidth());
|
||||
arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
|
||||
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
|
||||
resultLaneId, adaptor.getWidth());
|
||||
break;
|
||||
}
|
||||
case gpu::ShuffleMode::UP: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
result = spirv::GroupNonUniformShuffleUpOp::create(
|
||||
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
|
||||
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
|
||||
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
|
||||
Value resultLaneId =
|
||||
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
|
||||
arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
|
||||
auto i32Type = rewriter.getIntegerType(32);
|
||||
validVal = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, resultLaneId,
|
||||
rewriter.create<arith::ConstantOp>(
|
||||
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
|
||||
validVal = arith::CmpIOp::create(
|
||||
rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
|
||||
arith::ConstantOp::create(rewriter, loc, i32Type,
|
||||
rewriter.getIntegerAttr(i32Type, 0)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -516,15 +516,16 @@ LogicalResult GPURotateConversion::matchAndRewrite(
|
||||
|
||||
Location loc = rotateOp.getLoc();
|
||||
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
|
||||
Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
|
||||
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
|
||||
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
|
||||
adaptor.getWidth());
|
||||
Value validVal;
|
||||
if (widthAttr.getValue().getZExtValue() == subgroupSize) {
|
||||
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
|
||||
} else {
|
||||
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
|
||||
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
laneId, adaptor.getWidth());
|
||||
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
|
||||
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
|
||||
laneId, adaptor.getWidth());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
|
||||
@@ -548,14 +549,14 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
|
||||
? spirv::GroupOperation::ClusteredReduce
|
||||
: spirv::GroupOperation::Reduce);
|
||||
if (isUniform) {
|
||||
return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
|
||||
return UniformOp::create(builder, loc, type, scope, groupOp, arg)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value clusterSizeValue;
|
||||
if (clusterSize.has_value())
|
||||
clusterSizeValue = builder.create<spirv::ConstantOp>(
|
||||
loc, builder.getI32Type(),
|
||||
clusterSizeValue = spirv::ConstantOp::create(
|
||||
builder, loc, builder.getI32Type(),
|
||||
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
|
||||
|
||||
return builder
|
||||
@@ -740,8 +741,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
|
||||
std::string specCstName =
|
||||
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
|
||||
|
||||
return rewriter.create<spirv::SpecConstantOp>(
|
||||
loc, rewriter.getStringAttr(specCstName), attr);
|
||||
return spirv::SpecConstantOp::create(
|
||||
rewriter, loc, rewriter.getStringAttr(specCstName), attr);
|
||||
};
|
||||
{
|
||||
Operation *parent =
|
||||
@@ -774,8 +775,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
|
||||
std::string specCstCompositeName =
|
||||
(llvm::Twine(globalVarName) + "_scc").str();
|
||||
|
||||
specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
|
||||
loc, TypeAttr::get(globalType),
|
||||
specCstComposite = spirv::SpecConstantCompositeOp::create(
|
||||
rewriter, loc, TypeAttr::get(globalType),
|
||||
rewriter.getStringAttr(specCstCompositeName),
|
||||
rewriter.getArrayAttr(constituents));
|
||||
|
||||
@@ -785,23 +786,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
|
||||
// Define a GlobalVarOp initialized using specialized constants
|
||||
// that is used to specify the printf format string
|
||||
// to be passed to the SPIRV CLPrintfOp.
|
||||
globalVar = rewriter.create<spirv::GlobalVariableOp>(
|
||||
loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
|
||||
globalVar = spirv::GlobalVariableOp::create(
|
||||
rewriter, loc, ptrType, globalVarName,
|
||||
FlatSymbolRefAttr::get(specCstComposite));
|
||||
|
||||
globalVar->setAttr("Constant", rewriter.getUnitAttr());
|
||||
}
|
||||
// Get SSA value of Global variable and create pointer to i8 to point to
|
||||
// the format string.
|
||||
Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
|
||||
Value fmtStr = rewriter.create<spirv::BitcastOp>(
|
||||
loc,
|
||||
Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
|
||||
Value fmtStr = spirv::BitcastOp::create(
|
||||
rewriter, loc,
|
||||
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
|
||||
globalPtr);
|
||||
|
||||
// Get printf arguments.
|
||||
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
|
||||
|
||||
rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
|
||||
spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
|
||||
|
||||
// Need to erase the gpu.printf op as gpu.printf does not use result vs
|
||||
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
|
||||
|
||||
@@ -144,11 +144,12 @@ void GPUToSPIRVPass::runOnOperation() {
|
||||
if (targetEnvSupportsKernelCapability(moduleOp)) {
|
||||
moduleOp.walk([&](gpu::GPUFuncOp funcOp) {
|
||||
builder.setInsertionPoint(funcOp);
|
||||
auto newFuncOp = builder.create<func::FuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
|
||||
auto newFuncOp =
|
||||
func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(),
|
||||
funcOp.getFunctionType());
|
||||
auto entryBlock = newFuncOp.addEntryBlock();
|
||||
builder.setInsertionPointToEnd(entryBlock);
|
||||
builder.create<func::ReturnOp>(funcOp.getLoc());
|
||||
func::ReturnOp::create(builder, funcOp.getLoc());
|
||||
newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
funcOp.erase();
|
||||
|
||||
@@ -283,8 +283,8 @@ struct WmmaLoadOpToSPIRVLowering final
|
||||
|
||||
int64_t stride = op.getLeadDimension().getSExtValue();
|
||||
IntegerType i32Type = rewriter.getI32Type();
|
||||
auto strideValue = rewriter.create<spirv::ConstantOp>(
|
||||
loc, i32Type, IntegerAttr::get(i32Type, stride));
|
||||
auto strideValue = spirv::ConstantOp::create(
|
||||
rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
|
||||
|
||||
bool isColMajor = op.getTranspose().value_or(false);
|
||||
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
|
||||
@@ -315,8 +315,8 @@ struct WmmaStoreOpToSPIRVLowering final
|
||||
|
||||
int64_t stride = op.getLeadDimension().getSExtValue();
|
||||
IntegerType i32Type = rewriter.getI32Type();
|
||||
auto strideValue = rewriter.create<spirv::ConstantOp>(
|
||||
loc, i32Type, IntegerAttr::get(i32Type, stride));
|
||||
auto strideValue = spirv::ConstantOp::create(
|
||||
rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
|
||||
|
||||
bool isColMajor = op.getTranspose().value_or(false);
|
||||
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
|
||||
|
||||
Reference in New Issue
Block a user