[mlir][NFC] update Conversion create APIs (5/n) (#149887)

See https://github.com/llvm/llvm-project/pull/147168 for more info.
This commit is contained in:
Maksim Levental
2025-07-22 09:40:45 -05:00
committed by GitHub
parent a415d68e48
commit eaa67a3cf0
20 changed files with 1074 additions and 1031 deletions

View File

@@ -23,41 +23,43 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe,
Value *resultIm) {
Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
Value rhsSqNorm = LLVM::FAddOp::create(
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
Value realNumerator = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
Value realNumerator = LLVM::FAddOp::create(
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
Value imagNumerator = rewriter.create<LLVM::FSubOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
Value imagNumerator = LLVM::FSubOp::create(
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
*resultRe = rewriter.create<LLVM::FDivOp>(loc, realNumerator, rhsSqNorm, fmf);
*resultIm = rewriter.create<LLVM::FDivOp>(loc, imagNumerator, rhsSqNorm, fmf);
*resultRe =
LLVM::FDivOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
*resultIm =
LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
}
void mlir::complex::convertDivToStandardUsingAlgebraic(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
Value *resultIm) {
Value rhsSqNorm = rewriter.create<arith::AddFOp>(
loc, rewriter.create<arith::MulFOp>(loc, rhsRe, rhsRe, fmf),
rewriter.create<arith::MulFOp>(loc, rhsIm, rhsIm, fmf), fmf);
Value rhsSqNorm = arith::AddFOp::create(
rewriter, loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf),
arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf);
Value realNumerator = rewriter.create<arith::AddFOp>(
loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRe, fmf),
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsIm, fmf), fmf);
Value imagNumerator = rewriter.create<arith::SubFOp>(
loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRe, fmf),
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsIm, fmf), fmf);
Value realNumerator = arith::AddFOp::create(
rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf),
arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf);
Value imagNumerator = arith::SubFOp::create(
rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
*resultRe =
rewriter.create<arith::DivFOp>(loc, realNumerator, rhsSqNorm, fmf);
arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf);
*resultIm =
rewriter.create<arith::DivFOp>(loc, imagNumerator, rhsSqNorm, fmf);
arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf);
}
// Smith's algorithm to divide complex numbers. It is just a bit smarter
@@ -94,181 +96,185 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction(
auto elementType = cast<FloatType>(rhsRe.getType());
Value rhsRealImagRatio =
rewriter.create<LLVM::FDivOp>(loc, rhsRe, rhsIm, fmf);
Value rhsRealImagDenom = rewriter.create<LLVM::FAddOp>(
loc, rhsIm,
rewriter.create<LLVM::FMulOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
Value realNumerator1 = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealImagRatio, fmf),
lhsIm, fmf);
Value resultReal1 =
rewriter.create<LLVM::FDivOp>(loc, realNumerator1, rhsRealImagDenom, fmf);
Value imagNumerator1 = rewriter.create<LLVM::FSubOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealImagRatio, fmf),
lhsRe, fmf);
Value resultImag1 =
rewriter.create<LLVM::FDivOp>(loc, imagNumerator1, rhsRealImagDenom, fmf);
LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
Value rhsRealImagDenom = LLVM::FAddOp::create(
rewriter, loc, rhsIm,
LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
Value realNumerator1 = LLVM::FAddOp::create(
rewriter, loc,
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
fmf);
Value resultReal1 = LLVM::FDivOp::create(rewriter, loc, realNumerator1,
rhsRealImagDenom, fmf);
Value imagNumerator1 = LLVM::FSubOp::create(
rewriter, loc,
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
fmf);
Value resultImag1 = LLVM::FDivOp::create(rewriter, loc, imagNumerator1,
rhsRealImagDenom, fmf);
Value rhsImagRealRatio =
rewriter.create<LLVM::FDivOp>(loc, rhsIm, rhsRe, fmf);
Value rhsImagRealDenom = rewriter.create<LLVM::FAddOp>(
loc, rhsRe,
rewriter.create<LLVM::FMulOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
Value realNumerator2 = rewriter.create<LLVM::FAddOp>(
loc, lhsRe,
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
Value resultReal2 =
rewriter.create<LLVM::FDivOp>(loc, realNumerator2, rhsImagRealDenom, fmf);
Value imagNumerator2 = rewriter.create<LLVM::FSubOp>(
loc, lhsIm,
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
Value resultImag2 =
rewriter.create<LLVM::FDivOp>(loc, imagNumerator2, rhsImagRealDenom, fmf);
LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
Value rhsImagRealDenom = LLVM::FAddOp::create(
rewriter, loc, rhsRe,
LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
Value realNumerator2 = LLVM::FAddOp::create(
rewriter, loc, lhsRe,
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
Value resultReal2 = LLVM::FDivOp::create(rewriter, loc, realNumerator2,
rhsImagRealDenom, fmf);
Value imagNumerator2 = LLVM::FSubOp::create(
rewriter, loc, lhsIm,
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
Value resultImag2 = LLVM::FDivOp::create(rewriter, loc, imagNumerator2,
rhsImagRealDenom, fmf);
// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
Value zero = rewriter.create<LLVM::ConstantOp>(
loc, elementType, rewriter.getZeroAttr(elementType));
Value rhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsRe, fmf);
Value rhsRealIsZero = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
Value rhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsIm, fmf);
Value rhsImagIsZero = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
Value lhsRealIsNotNaN =
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
Value lhsImagIsNotNaN =
rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType,
rewriter.getZeroAttr(elementType));
Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf);
Value rhsRealIsZero = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf);
Value rhsImagIsZero = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
Value lhsRealIsNotNaN = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
Value lhsImagIsNotNaN = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
Value lhsContainsNotNaNValue =
rewriter.create<LLVM::OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
Value resultIsInfinity = rewriter.create<LLVM::AndOp>(
loc, lhsContainsNotNaNValue,
rewriter.create<LLVM::AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
Value inf = rewriter.create<LLVM::ConstantOp>(
loc, elementType,
LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
Value resultIsInfinity = LLVM::AndOp::create(
rewriter, loc, lhsContainsNotNaNValue,
LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
Value inf = LLVM::ConstantOp::create(
rewriter, loc, elementType,
rewriter.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
Value infWithSignOfrhsReal =
rewriter.create<LLVM::CopySignOp>(loc, inf, rhsRe);
LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe);
Value infinityResultReal =
rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsRe, fmf);
LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf);
Value infinityResultImag =
rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsIm, fmf);
LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf);
// Case 2. Infinite numerator, finite denominator.
Value rhsRealFinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
Value rhsImagFinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
Value rhsRealFinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
Value rhsImagFinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
Value rhsFinite =
rewriter.create<LLVM::AndOp>(loc, rhsRealFinite, rhsImagFinite);
Value lhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsRe, fmf);
Value lhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
Value lhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsIm, fmf);
Value lhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf);
Value lhsRealInfinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf);
Value lhsImagInfinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
Value lhsInfinite =
rewriter.create<LLVM::OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
Value infNumFiniteDenom =
rewriter.create<LLVM::AndOp>(loc, lhsInfinite, rhsFinite);
Value one = rewriter.create<LLVM::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 1));
Value lhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
loc, rewriter.create<LLVM::SelectOp>(loc, lhsRealInfinite, one, zero),
lhsRe);
Value lhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
loc, rewriter.create<LLVM::SelectOp>(loc, lhsImagInfinite, one, zero),
lhsIm);
LLVM::AndOp::create(rewriter, loc, lhsInfinite, rhsFinite);
Value one = LLVM::ConstantOp::create(rewriter, loc, elementType,
rewriter.getFloatAttr(elementType, 1));
Value lhsRealIsInfWithSign = LLVM::CopySignOp::create(
rewriter, loc,
LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe);
Value lhsImagIsInfWithSign = LLVM::CopySignOp::create(
rewriter, loc,
LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm);
Value lhsRealIsInfWithSignTimesrhsReal =
rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
Value lhsImagIsInfWithSignTimesrhsImag =
rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
Value resultReal3 = rewriter.create<LLVM::FMulOp>(
loc, inf,
rewriter.create<LLVM::FAddOp>(loc, lhsRealIsInfWithSignTimesrhsReal,
lhsImagIsInfWithSignTimesrhsImag, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
Value resultReal3 = LLVM::FMulOp::create(
rewriter, loc, inf,
LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal,
lhsImagIsInfWithSignTimesrhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesrhsImag =
rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
Value lhsImagIsInfWithSignTimesrhsReal =
rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
Value resultImag3 = rewriter.create<LLVM::FMulOp>(
loc, inf,
rewriter.create<LLVM::FSubOp>(loc, lhsImagIsInfWithSignTimesrhsReal,
lhsRealIsInfWithSignTimesrhsImag, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
Value resultImag3 = LLVM::FMulOp::create(
rewriter, loc, inf,
LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal,
lhsRealIsInfWithSignTimesrhsImag, fmf),
fmf);
// Case 3: Finite numerator, infinite denominator.
Value lhsRealFinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
Value lhsImagFinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
Value lhsRealFinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf);
Value lhsImagFinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf);
Value lhsFinite =
rewriter.create<LLVM::AndOp>(loc, lhsRealFinite, lhsImagFinite);
Value rhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
Value rhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
Value rhsRealInfinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf);
Value rhsImagInfinite = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf);
Value rhsInfinite =
rewriter.create<LLVM::OrOp>(loc, rhsRealInfinite, rhsImagInfinite);
LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
Value finiteNumInfiniteDenom =
rewriter.create<LLVM::AndOp>(loc, lhsFinite, rhsInfinite);
Value rhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
loc, rewriter.create<LLVM::SelectOp>(loc, rhsRealInfinite, one, zero),
rhsRe);
Value rhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
loc, rewriter.create<LLVM::SelectOp>(loc, rhsImagInfinite, one, zero),
rhsIm);
LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite);
Value rhsRealIsInfWithSign = LLVM::CopySignOp::create(
rewriter, loc,
LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe);
Value rhsImagIsInfWithSign = LLVM::CopySignOp::create(
rewriter, loc,
LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm);
Value rhsRealIsInfWithSignTimeslhsReal =
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimeslhsImag =
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
Value resultReal4 = rewriter.create<LLVM::FMulOp>(
loc, zero,
rewriter.create<LLVM::FAddOp>(loc, rhsRealIsInfWithSignTimeslhsReal,
rhsImagIsInfWithSignTimeslhsImag, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
Value resultReal4 = LLVM::FMulOp::create(
rewriter, loc, zero,
LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal,
rhsImagIsInfWithSignTimeslhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimeslhsImag =
rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimeslhsReal =
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
Value resultImag4 = rewriter.create<LLVM::FMulOp>(
loc, zero,
rewriter.create<LLVM::FSubOp>(loc, rhsRealIsInfWithSignTimeslhsImag,
rhsImagIsInfWithSignTimeslhsReal, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
Value resultImag4 = LLVM::FMulOp::create(
rewriter, loc, zero,
LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag,
rhsImagIsInfWithSignTimeslhsReal, fmf),
fmf);
Value realAbsSmallerThanImagAbs = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
Value resultReal5 = rewriter.create<LLVM::SelectOp>(
loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
Value resultImag5 = rewriter.create<LLVM::SelectOp>(
loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
Value resultRealSpecialCase3 = rewriter.create<LLVM::SelectOp>(
loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
Value resultImagSpecialCase3 = rewriter.create<LLVM::SelectOp>(
loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
Value resultRealSpecialCase2 = rewriter.create<LLVM::SelectOp>(
loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
Value resultImagSpecialCase2 = rewriter.create<LLVM::SelectOp>(
loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
Value resultRealSpecialCase1 = rewriter.create<LLVM::SelectOp>(
loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
Value resultImagSpecialCase1 = rewriter.create<LLVM::SelectOp>(
loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs);
Value resultReal5 = LLVM::SelectOp::create(
rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
Value resultImag5 = LLVM::SelectOp::create(
rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
Value resultRealSpecialCase3 = LLVM::SelectOp::create(
rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
Value resultImagSpecialCase3 = LLVM::SelectOp::create(
rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
Value resultRealSpecialCase2 = LLVM::SelectOp::create(
rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
Value resultImagSpecialCase2 = LLVM::SelectOp::create(
rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
Value resultRealSpecialCase1 =
LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
infinityResultReal, resultRealSpecialCase2);
Value resultImagSpecialCase1 =
LLVM::SelectOp::create(rewriter, loc, resultIsInfinity,
infinityResultImag, resultImagSpecialCase2);
Value resultRealIsNaN = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
Value resultImagIsNaN = rewriter.create<LLVM::FCmpOp>(
loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
Value resultRealIsNaN = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero);
Value resultImagIsNaN = LLVM::FCmpOp::create(
rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero);
Value resultIsNaN =
rewriter.create<LLVM::AndOp>(loc, resultRealIsNaN, resultImagIsNaN);
LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
*resultRe = rewriter.create<LLVM::SelectOp>(
loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
*resultIm = rewriter.create<LLVM::SelectOp>(
loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
*resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
resultRealSpecialCase1, resultReal5);
*resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN,
resultImagSpecialCase1, resultImag5);
}
void mlir::complex::convertDivToStandardUsingRangeReduction(
@@ -278,179 +284,187 @@ void mlir::complex::convertDivToStandardUsingRangeReduction(
auto elementType = cast<FloatType>(rhsRe.getType());
Value rhsRealImagRatio =
rewriter.create<arith::DivFOp>(loc, rhsRe, rhsIm, fmf);
Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
loc, rhsIm,
rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
Value realNumerator1 = rewriter.create<arith::AddFOp>(
loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealImagRatio, fmf),
lhsIm, fmf);
Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
rhsRealImagDenom, fmf);
Value imagNumerator1 = rewriter.create<arith::SubFOp>(
loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealImagRatio, fmf),
lhsRe, fmf);
Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
rhsRealImagDenom, fmf);
arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf);
Value rhsRealImagDenom = arith::AddFOp::create(
rewriter, loc, rhsIm,
arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf);
Value realNumerator1 = arith::AddFOp::create(
rewriter, loc,
arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm,
fmf);
Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1,
rhsRealImagDenom, fmf);
Value imagNumerator1 = arith::SubFOp::create(
rewriter, loc,
arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe,
fmf);
Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1,
rhsRealImagDenom, fmf);
Value rhsImagRealRatio =
rewriter.create<arith::DivFOp>(loc, rhsIm, rhsRe, fmf);
Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
loc, rhsRe,
rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
Value realNumerator2 = rewriter.create<arith::AddFOp>(
loc, lhsRe,
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
rhsImagRealDenom, fmf);
Value imagNumerator2 = rewriter.create<arith::SubFOp>(
loc, lhsIm,
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
rhsImagRealDenom, fmf);
arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf);
Value rhsImagRealDenom = arith::AddFOp::create(
rewriter, loc, rhsRe,
arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf);
Value realNumerator2 = arith::AddFOp::create(
rewriter, loc, lhsRe,
arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf);
Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2,
rhsImagRealDenom, fmf);
Value imagNumerator2 = arith::SubFOp::create(
rewriter, loc, lhsIm,
arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf);
Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2,
rhsImagRealDenom, fmf);
// Consider corner cases.
// Case 1. Zero denominator, numerator contains at most one NaN value.
Value zero = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getZeroAttr(elementType));
Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsRe, fmf);
Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsIm, fmf);
Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ORD, lhsRe, zero);
Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ORD, lhsIm, zero);
Value zero = arith::ConstantOp::create(rewriter, loc, elementType,
rewriter.getZeroAttr(elementType));
Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf);
Value rhsRealIsZero = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf);
Value rhsImagIsZero = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
Value lhsRealIsNotNaN = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero);
Value lhsImagIsNotNaN = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero);
Value lhsContainsNotNaNValue =
rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
Value resultIsInfinity = rewriter.create<arith::AndIOp>(
loc, lhsContainsNotNaNValue,
rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
Value inf = rewriter.create<arith::ConstantOp>(
loc, elementType,
arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
Value resultIsInfinity = arith::AndIOp::create(
rewriter, loc, lhsContainsNotNaNValue,
arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero));
Value inf = arith::ConstantOp::create(
rewriter, loc, elementType,
rewriter.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
Value infWithSignOfRhsReal =
rewriter.create<math::CopySignOp>(loc, inf, rhsRe);
math::CopySignOp::create(rewriter, loc, inf, rhsRe);
Value infinityResultReal =
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsRe, fmf);
arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf);
Value infinityResultImag =
rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsIm, fmf);
arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsIm, fmf);
// Case 2. Infinite numerator, finite denominator.
Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
Value rhsRealFinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
Value rhsImagFinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
Value rhsFinite =
rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsRe, fmf);
Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsIm, fmf);
Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite);
Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf);
Value lhsRealInfinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf);
Value lhsImagInfinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
Value lhsInfinite =
rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite);
Value infNumFiniteDenom =
rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
Value one = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 1));
Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite);
Value one = arith::ConstantOp::create(rewriter, loc, elementType,
rewriter.getFloatAttr(elementType, 1));
Value lhsRealIsInfWithSign = math::CopySignOp::create(
rewriter, loc,
arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero),
lhsRe);
Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
Value lhsImagIsInfWithSign = math::CopySignOp::create(
rewriter, loc,
arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero),
lhsIm);
Value lhsRealIsInfWithSignTimesRhsReal =
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf);
Value lhsImagIsInfWithSignTimesRhsImag =
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
Value resultReal3 = rewriter.create<arith::MulFOp>(
loc, inf,
rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
lhsImagIsInfWithSignTimesRhsImag, fmf),
arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf);
Value resultReal3 = arith::MulFOp::create(
rewriter, loc, inf,
arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal,
lhsImagIsInfWithSignTimesRhsImag, fmf),
fmf);
Value lhsRealIsInfWithSignTimesRhsImag =
rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf);
arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf);
Value lhsImagIsInfWithSignTimesRhsReal =
rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf);
Value resultImag3 = rewriter.create<arith::MulFOp>(
loc, inf,
rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
lhsRealIsInfWithSignTimesRhsImag, fmf),
arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf);
Value resultImag3 = arith::MulFOp::create(
rewriter, loc, inf,
arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal,
lhsRealIsInfWithSignTimesRhsImag, fmf),
fmf);
// Case 3: Finite numerator, infinite denominator.
Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
Value lhsRealFinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
Value lhsImagFinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
Value lhsFinite =
rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite);
Value rhsRealInfinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
Value rhsImagInfinite = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
Value rhsInfinite =
rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite);
Value finiteNumInfiniteDenom =
rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite);
Value rhsRealIsInfWithSign = math::CopySignOp::create(
rewriter, loc,
arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero),
rhsRe);
Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
Value rhsImagIsInfWithSign = math::CopySignOp::create(
rewriter, loc,
arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero),
rhsIm);
Value rhsRealIsInfWithSignTimesLhsReal =
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf);
arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsImag =
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf);
Value resultReal4 = rewriter.create<arith::MulFOp>(
loc, zero,
rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
rhsImagIsInfWithSignTimesLhsImag, fmf),
arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf);
Value resultReal4 = arith::MulFOp::create(
rewriter, loc, zero,
arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal,
rhsImagIsInfWithSignTimesLhsImag, fmf),
fmf);
Value rhsRealIsInfWithSignTimesLhsImag =
rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf);
arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf);
Value rhsImagIsInfWithSignTimesLhsReal =
rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf);
Value resultImag4 = rewriter.create<arith::MulFOp>(
loc, zero,
rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
rhsImagIsInfWithSignTimesLhsReal, fmf),
arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf);
Value resultImag4 = arith::MulFOp::create(
rewriter, loc, zero,
arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag,
rhsImagIsInfWithSignTimesLhsReal, fmf),
fmf);
Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
Value resultReal5 = rewriter.create<arith::SelectOp>(
loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
Value resultImag5 = rewriter.create<arith::SelectOp>(
loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
Value realAbsSmallerThanImagAbs = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
Value resultReal5 = arith::SelectOp::create(
rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
Value resultImag5 = arith::SelectOp::create(
rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
Value resultRealSpecialCase3 = arith::SelectOp::create(
rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5);
Value resultImagSpecialCase3 = arith::SelectOp::create(
rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5);
Value resultRealSpecialCase2 = arith::SelectOp::create(
rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
Value resultImagSpecialCase2 = arith::SelectOp::create(
rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
Value resultRealSpecialCase1 =
arith::SelectOp::create(rewriter, loc, resultIsInfinity,
infinityResultReal, resultRealSpecialCase2);
Value resultImagSpecialCase1 =
arith::SelectOp::create(rewriter, loc, resultIsInfinity,
infinityResultImag, resultImagSpecialCase2);
Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, resultReal5, zero);
Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, resultImag5, zero);
Value resultRealIsNaN = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero);
Value resultImagIsNaN = arith::CmpFOp::create(
rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero);
Value resultIsNaN =
rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN);
*resultRe = rewriter.create<arith::SelectOp>(
loc, resultIsNaN, resultRealSpecialCase1, resultReal5);
*resultIm = rewriter.create<arith::SelectOp>(
loc, resultIsNaN, resultImagSpecialCase1, resultImag5);
*resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN,
resultRealSpecialCase1, resultReal5);
*resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN,
resultImagSpecialCase1, resultImag5);
}

View File

@@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder,
Location loc, Type type) {
Value val = builder.create<LLVM::PoisonOp>(loc, type);
Value val = LLVM::PoisonOp::create(builder, loc, type);
return ComplexStructBuilder(val);
}
@@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
Value sqNorm = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
Value sqNorm = LLVM::FAddOp::create(
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
return success();
@@ -191,10 +191,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
Value real =
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
arg.rhs.real(), fmf);
Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
Value lhsRe = arg.lhs.real();
Value lhsIm = arg.lhs.imag();
Value real = rewriter.create<LLVM::FSubOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
Value real = LLVM::FSubOp::create(
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
Value imag = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
Value imag = LLVM::FAddOp::create(
rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -313,10 +313,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
op.getContext(),
convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
Value real =
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
arg.rhs.real(), fmf);
Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);

View File

@@ -84,8 +84,8 @@ LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
auto opFunctionTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
opFunctionTy);
opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name,
opFunctionTy);
opFunc.setPrivate();
}
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));

View File

@@ -44,8 +44,8 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
auto funcTy = FunctionType::get(
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
funcTy);
opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(),
funcName, funcTy);
opFunc.setPrivate();
}
rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),

File diff suppressed because it is too large Load Diff

View File

@@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}
rewriter.create<LLVM::CallOp>(loc, abortFunc, ValueRange());
rewriter.create<LLVM::UnreachableOp>(loc);
LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange());
LLVM::UnreachableOp::create(rewriter, loc);
} else {
rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock);
}
// Generate assertion test.

View File

@@ -33,8 +33,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
MutableArrayRef<Region> regions) {
if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
assert(regions.size() == 2);
auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
resultTypes, condBrOp.getCondition());
auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(),
resultTypes, condBrOp.getCondition());
ifOp.getThenRegion().takeBody(regions[0]);
ifOp.getElseRegion().takeBody(regions[1]);
return ifOp.getOperation();
@@ -43,8 +43,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
// `getCFGSwitchValue` returns an i32 that we need to convert to index
// fist.
auto cast = builder.create<arith::IndexCastUIOp>(
controlFlowCondOp->getLoc(), builder.getIndexType(),
auto cast = arith::IndexCastUIOp::create(
builder, controlFlowCondOp->getLoc(), builder.getIndexType(),
switchOp.getFlag());
SmallVector<int64_t> cases;
if (auto caseValues = switchOp.getCaseValues())
@@ -55,8 +55,9 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
assert(regions.size() == cases.size() + 1);
auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
auto indexSwitchOp =
scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(),
resultTypes, cast, cases, cases.size());
indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
for (auto &&[targetRegion, sourceRegion] :
@@ -75,7 +76,7 @@ LogicalResult
ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
Location loc, OpBuilder &builder, Operation *branchRegionOp,
Operation *replacedControlFlowOp, ValueRange results) {
builder.create<scf::YieldOp>(loc, results);
scf::YieldOp::create(builder, loc, results);
return success();
}
@@ -84,23 +85,24 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
Location loc = replacedOp->getLoc();
auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
loopVariablesInit);
auto whileOp = scf::WhileOp::create(
builder, loc, loopVariablesInit.getTypes(), loopVariablesInit);
whileOp.getBefore().takeBody(loopBody);
builder.setInsertionPointToEnd(&whileOp.getBefore().back());
// `getCFGSwitchValue` returns a i32. We therefore need to truncate the
// condition to i1 first. It is guaranteed to be either 0 or 1 already.
builder.create<scf::ConditionOp>(
loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
scf::ConditionOp::create(
builder, loc,
arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition),
loopVariablesNextIter);
Block *afterBlock = builder.createBlock(&whileOp.getAfter());
afterBlock->addArguments(
loopVariablesInit.getTypes(),
SmallVector<Location>(loopVariablesInit.size(), loc));
builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
scf::YieldOp::create(builder, loc, afterBlock->getArguments());
return whileOp.getOperation();
}
@@ -108,8 +110,8 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
OpBuilder &builder,
unsigned int value) {
return builder.create<arith::ConstantOp>(loc,
builder.getI32IntegerAttr(value));
return arith::ConstantOp::create(builder, loc,
builder.getI32IntegerAttr(value));
}
void ControlFlowToSCFTransformation::createCFGSwitchOp(
@@ -117,15 +119,15 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp(
ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseArguments, Block *defaultDest,
ValueRange defaultArgs) {
builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
llvm::to_vector_of<int32_t>(caseValues),
caseDestinations, caseArguments);
cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs,
llvm::to_vector_of<int32_t>(caseValues),
caseDestinations, caseArguments);
}
Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
OpBuilder &builder,
Type type) {
return builder.create<ub::PoisonOp>(loc, type, nullptr);
return ub::PoisonOp::create(builder, loc, type, nullptr);
}
FailureOr<Operation *>

View File

@@ -99,8 +99,8 @@ public:
}
// Create the converted `emitc.func` op.
emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
funcOp.getLoc(), funcOp.getName(),
emitc::FuncOp newFuncOp = emitc::FuncOp::create(
rewriter, funcOp.getLoc(), funcOp.getName(),
FunctionType::get(rewriter.getContext(),
signatureConverter.getConvertedTypes(),
resultType ? TypeRange(resultType) : TypeRange()));

View File

@@ -115,8 +115,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
SmallVector<NamedAttribute> attributes;
filterFuncAttributes(funcOp, attributes);
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
auto wrapperFuncOp = LLVM::LLVMFuncOp::create(
rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
@@ -129,14 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(index + argOffset);
if (auto memrefType = dyn_cast<MemRefType>(argType)) {
Value loaded = rewriter.create<LLVM::LoadOp>(
loc, typeConverter.convertType(memrefType), arg);
Value loaded = LLVM::LoadOp::create(
rewriter, loc, typeConverter.convertType(memrefType), arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
continue;
}
if (isa<UnrankedMemRefType>(argType)) {
Value loaded = rewriter.create<LLVM::LoadOp>(
loc, typeConverter.convertType(argType), arg);
Value loaded = LLVM::LoadOp::create(
rewriter, loc, typeConverter.convertType(argType), arg);
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
continue;
}
@@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
args.push_back(arg);
}
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args);
if (resultStructType) {
rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
wrapperFuncOp.getArgument(0));
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
LLVM::StoreOp::create(rewriter, loc, call.getResult(),
wrapperFuncOp.getArgument(0));
LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
} else {
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
LLVM::ReturnOp::create(rewriter, loc, call.getResults());
}
}
@@ -182,8 +182,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
filterFuncAttributes(funcOp, attributes);
// Create the auxiliary function.
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
auto wrapperFunc = LLVM::LLVMFuncOp::create(
builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
/*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
@@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
if (resultStructType) {
// Allocate the struct on the stack and pass the pointer.
Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
Value one = LLVM::ConstantOp::create(
builder, loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
Value result =
builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one);
args.push_back(result);
}
@@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
wrapperArgsRange.take_front(numToDrop));
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
Value one = LLVM::ConstantOp::create(
builder, loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
Value allocated = builder.create<LLVM::AllocaOp>(
loc, ptrTy, packed.getType(), one, /*alignment=*/0);
builder.create<LLVM::StoreOp>(loc, packed, allocated);
Value allocated = LLVM::AllocaOp::create(
builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0);
LLVM::StoreOp::create(builder, loc, packed, allocated);
arg = allocated;
} else {
arg = wrapperArgsRange[0];
@@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args);
if (resultStructType) {
Value result =
builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
builder.create<LLVM::ReturnOp>(loc, result);
LLVM::LoadOp::create(builder, loc, resultStructType, args.front());
LLVM::ReturnOp::create(builder, loc, result);
} else {
builder.create<LLVM::ReturnOp>(loc, call.getResults());
LLVM::ReturnOp::create(builder, loc, call.getResults());
}
}
@@ -283,7 +283,7 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}
@@ -357,8 +357,8 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
symbolTable.remove(funcOp);
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
auto newFuncOp = LLVM::LLVMFuncOp::create(
rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
attributes);
@@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
return rewriter.notifyMatchFailure(op, "failed to convert result type");
auto newOp =
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue());
for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.getName().strref() == "value")
continue;
@@ -556,9 +556,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
auto promoted = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter, useBarePtrCallConv);
auto newOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promoted, callOp->getAttrs());
auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(),
packedResult ? TypeRange(packedResult)
: TypeRange(),
promoted, callOp->getAttrs());
newOp.getProperties().operandSegmentSizes = {
static_cast<int32_t>(promoted.size()), 0};
@@ -573,8 +574,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
callOp.getLoc(), newOp->getResult(0), i));
results.push_back(LLVM::ExtractValueOp::create(
rewriter, callOp.getLoc(), newOp->getResult(0), i));
}
}
@@ -726,9 +727,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());

View File

@@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}
return ret;
}
@@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
return b.create<LLVM::GlobalOp>(loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
name, attr, alignment, addrSpace);
return LLVM::GlobalOp::create(b, loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
name, attr, alignment, addrSpace);
}
LogicalResult
@@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
gpuFuncOp.getWorkgroupAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
auto globalOp = rewriter.create<LLVM::GlobalOp>(
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
auto globalOp = LLVM::GlobalOp::create(
rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
workgroupAddrSpace);
workgroupBuffers.push_back(globalOp);
@@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
LLVM::CConv callingConvention = gpuFuncOp.isKernel()
? kernelCallingConvention
: nonKernelCallingConvention;
auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
auto llvmFuncOp = LLVM::LLVMFuncOp::create(
rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
/*comdat=*/nullptr, attributes);
@@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
global.getAddrSpace());
Value address = rewriter.create<LLVM::AddressOfOp>(
loc, ptrType, global.getSymNameAttr());
Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
global.getSymNameAttr());
Value memory =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
address, ArrayRef<LLVM::GEPArg>{0, 0});
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
address, ArrayRef<LLVM::GEPArg>{0, 0});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
@@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
Type elementType = typeConverter->convertType(type.getElementType());
auto ptrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
Value numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
Value numElements = LLVM::ConstantOp::create(
rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
uint64_t alignment = 0;
if (auto alignAttr =
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
Value allocated =
LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
elementType, numElements, alignment);
Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
@@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
{llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0);
auto printfBeginCall =
LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
@@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
Value globalPtr = LLVM::AddressOfOp::create(
rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen = LLVM::ConstantOp::create(
rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1);
Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0);
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
loc, ocklAppendStringN,
auto appendFormatCall = LLVM::CallOp::create(
rewriter, loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
adaptor.getArgs().empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult();
@@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = adaptor.getArgs()[i];
if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
if (!floatType.isF64())
arg = rewriter.create<LLVM::FPExtOp>(
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
arg = LLVM::FPExtOp::create(
rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
arg);
arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg);
arguments.push_back(arg);
}
@@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
rewriter.eraseOp(gpuPrintfOp);
@@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
/*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
Value globalPtr = LLVM::AddressOfOp::create(
rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.getArgs();
@@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
"printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
Value stringStart =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
@@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
assert(type.isIntOrFloat());
if (isa<FloatType>(type)) {
type = rewriter.getF64Type();
promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
}
types.push_back(type);
args.push_back(promotedArg);
}
Type structType =
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
rewriter.getIndexAttr(1));
Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
rewriter.getIndexAttr(1));
Value tempAlloc =
rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
/*alignment=*/0);
LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
/*alignment=*/0);
for (auto [index, arg] : llvm::enumerate(args)) {
Value ptr = rewriter.create<LLVM::GEPOp>(
loc, ptrType, structType, tempAlloc,
Value ptr = LLVM::GEPOp::create(
rewriter, loc, ptrType, structType, tempAlloc,
ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
LLVM::StoreOp::create(rewriter, loc, arg, ptr);
}
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
@@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
TypeRange operandTypes(operands);
VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
Location loc = op->getLoc();
Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType);
Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
Type indexType = converter.convertType(rewriter.getIndexType());
StringAttr name = op->getName().getIdentifier();
Type elementType = vectorType.getElementType();
for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
auto extractElement = [&](Value operand) -> Value {
if (!isa<VectorType>(operand.getType()))
return operand;
return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
};
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
Operation *scalarOp =
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
result = rewriter.create<LLVM::InsertElementOp>(
loc, result, scalarOp->getResult(0), index);
result = LLVM::InsertElementOp::create(rewriter, loc, result,
scalarOp->getResult(0), index);
}
return result;
}
@@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
return rewriter.create<LLVM::GlobalOp>(
op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
addressSpace.value());
return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
/*isConstant=*/false, LLVM::Linkage::Internal,
symName, /*value=*/Attribute(), alignmentByte,
addressSpace.value());
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
@@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
Type baseType = basePtr->getResultTypes().front();
// Step 4. Generate GEP using offsets
SmallVector<LLVM::GEPArg> gepArgs = {0};
Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
basePtr, gepArgs);
Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
basePtr, gepArgs);
// Step 5. Create a memref descriptor
SmallVector<Value> shape, strides;
Value sizeBytes;
@@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
op->getAttrs());

View File

@@ -79,8 +79,8 @@ protected:
uint64_t rank = type.getRank();
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
for (unsigned i = 1; i < rank; i++)
numElements = rewriter.create<LLVM::MulOp>(
loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
numElements = LLVM::MulOp::create(rewriter, loc, numElements,
desc.size(rewriter, loc, /*pos=*/i));
return numElements;
}
@@ -582,7 +582,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder::atBlockEnd(module.getBody())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
return builder.create<LLVM::CallOp>(loc, function, arguments);
return LLVM::CallOp::create(builder, loc, function, arguments);
}
// Corresponding to cusparseIndexType_t defined in cusparse.h.
@@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType);
Value stream = adaptor.getAsyncDependencies().empty()
? nullPtr
: adaptor.getAsyncDependencies().front();
auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
auto isHostShared = mlir::LLVM::ConstantOp::create(
rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
Value allocatedPtr =
allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
@@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
static_cast<uint64_t>(memrefTy.getNumElements());
Value sizeArg = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(staticSize));
Value sizeArg = LLVM::ConstantOp::create(
rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize));
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
llvmArgumentsWithSizes.push_back(sizeArg);
}
@@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
adaptor.getClusterSizeZ()};
}
rewriter.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::LaunchFuncOp::create(
rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ()},
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
@@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc,
const LLVMTypeConverter &typeConverter) {
auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc,
sourcePtr = LLVM::AddrSpaceCastOp::create(
rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(),
destinationType.getAddressSpace()),
sourcePtr);
@@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
Value gepPtr = rewriter.create<LLVM::GEPOp>(
loc, elementPtrType,
Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
Value gepPtr = LLVM::GEPOp::create(
rewriter, loc, elementPtrType,
typeConverter->convertType(memRefType.getElementType()), nullPtr,
numElements);
auto sizeBytes =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
srcDesc.alignedPtr(rewriter, loc),
@@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
auto value =
rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue());
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
dstDesc.alignedPtr(rewriter, loc),
*getTypeConverter());
@@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmInt32Type = builder.getIntegerType(32);
return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
static_cast<int32_t>(tValue));
return LLVM::ConstantOp::create(builder, loc, llvmInt32Type,
static_cast<int32_t>(tValue));
}
template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
Type llvmFloat32Type = builder.getF32Type();
return builder.create<LLVM::ConstantOp>(
loc, llvmFloat32Type,
return LLVM::ConstantOp::create(
builder, loc, llvmFloat32Type,
builder.getF32FloatAttr(static_cast<float>(tValue)));
}
@@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
// the dnmat is used with spmat with 2:4 sparsity
if (dims.size() == 2) {
if (isSpMMCusparseLtOp(op.getDnTensor())) {
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(11032));
handle = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(11032));
handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
llvmInt8Type, handleSz, /*alignment=*/16);
handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
createLtDnMatCallBuilder
.create(loc, rewriter,
@@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
// CUDA runner asserts the size is 44104 bytes.
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(44104));
Value handle = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(44104));
Value handle = LLVM::AllocaOp::create(
rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle);
create2To4SpMatCallBuilder
.create(loc, rewriter,
@@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
auto computeType = genConstInt32From(
rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(3));
auto bufferSize = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(3));
auto bufferSize =
LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType,
three, /*alignment=*/16);
createCuSparseLtSpMMBufferSizeBuilder
.create(loc, rewriter,
{bufferSize, modeA, modeB, adaptor.getSpmatA(),
@@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
pruneFlag, stream})
.getResult();
auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, bufferSize,
ValueRange{rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(1))});
auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, bufferSize,
ValueRange{rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(2))});
auto bufferSizePtr1 = LLVM::GEPOp::create(
rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(1))});
auto bufferSizePtr2 = LLVM::GEPOp::create(
rewriter, loc, llvmPointerType, llvmPointerType, bufferSize,
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(2))});
auto bufferSize0 =
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize);
auto bufferSize1 =
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1);
auto bufferSize2 =
rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2);
rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
} else {
@@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op.getLoc();
auto stream = adaptor.getAsyncDependencies().front();
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(3));
auto buffer = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(3));
auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType,
llvmInt64Type, three, /*alignment=*/16);
auto rowsPtr = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(0))});
auto colsPtr = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(1))});
auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(2))});
auto rowsPtr = LLVM::GEPOp::create(
rewriter, loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(0))});
auto colsPtr = LLVM::GEPOp::create(
rewriter, loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(1))});
auto nnzsPtr = LLVM::GEPOp::create(
rewriter, loc, llvmPointerType, llvmPointerType, buffer,
ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
rewriter.getIndexAttr(2))});
createSpMatGetSizeBuilder.create(
loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr);
auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr);
auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr);
rewriter.replaceOp(op, {rows, cols, nnzs, stream});
return success();

View File

@@ -59,13 +59,13 @@ public:
Operation *newOp;
switch (op.getDimension()) {
case gpu::Dimension::x:
newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32));
newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
case gpu::Dimension::y:
newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32));
newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
case gpu::Dimension::z:
newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32));
newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32));
break;
}
@@ -124,11 +124,13 @@ public:
rewriter.getContext(), 32, min, max));
}
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
newOp = LLVM::SExtOp::create(rewriter, loc,
IntegerType::get(context, indexBitwidth),
newOp->getResult(0));
} else if (indexBitwidth < 32) {
newOp = rewriter.create<LLVM::TruncOp>(
loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0));
newOp = LLVM::TruncOp::create(rewriter, loc,
IntegerType::get(context, indexBitwidth),
newOp->getResult(0));
}
rewriter.replaceOp(op, newOp->getResults());

View File

@@ -103,7 +103,7 @@ public:
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp =
rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);
if (resultType == adaptor.getOperands().front().getType()) {
rewriter.replaceOp(op, {callOp.getResult()});
@@ -115,19 +115,20 @@ public:
// there is no guarantee of a specific value being used to indicate true,
// compare for inequality with zero (rather than truncate or shift).
if (isResultBool) {
Value zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getIntegerType(32),
rewriter.getI32IntegerAttr(0));
Value truncated = rewriter.create<LLVM::ICmpOp>(
op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero);
Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
rewriter.getIntegerType(32),
rewriter.getI32IntegerAttr(0));
Value truncated =
LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
callOp.getResult(), zero);
rewriter.replaceOp(op, {truncated});
return success();
}
assert(callOp.getResult().getType().isF32() &&
"only f32 types are supposed to be truncated back");
Value truncated = rewriter.create<LLVM::FPTruncOp>(
op->getLoc(), adaptor.getOperands().front().getType(),
Value truncated = LLVM::FPTruncOp::create(
rewriter, op->getLoc(), adaptor.getOperands().front().getType(),
callOp.getResult());
rewriter.replaceOp(op, {truncated});
return success();
@@ -142,8 +143,9 @@ public:
if (!f16Func.empty() && isa<Float16Type>(type))
return operand;
return rewriter.create<LLVM::FPExtOp>(
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
Float32Type::get(rewriter.getContext()),
operand);
}
Type getFunctionType(Type resultType, ValueRange operands) const {
@@ -169,7 +171,7 @@ public:
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
return b.create<LLVMFuncOp>(globalloc, funcName, funcType);
return LLVMFuncOp::create(b, globalloc, funcName, funcType);
}
StringRef getFunctionName(Type type, SourceOp op) const {

View File

@@ -54,8 +54,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
SymbolTable::lookupSymbolIn(symbolTable, name));
if (!func) {
OpBuilder b(symbolTable->getRegion(0));
func = b.create<LLVM::LLVMFuncOp>(
symbolTable->getLoc(), name,
func = LLVM::LLVMFuncOp::create(
b, symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
func.setNoUnwind(true);
@@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFuncOp func,
ValueRange args) {
auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
auto call = LLVM::CallOp::create(rewriter, loc, func, args);
call.setCConv(func.getCConv());
call.setConvergentAttr(func.getConvergentAttr());
call.setNoUnwindAttr(func.getNoUnwindAttr());
@@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
constexpr int64_t localMemFenceFlag = 1;
Location loc = op->getLoc();
Value flag =
rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag);
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
return success();
}
@@ -162,8 +162,8 @@ struct LaunchConfigConversion : ConvertToLLVMPattern {
Location loc = op->getLoc();
gpu::Dimension dim = getDimension(op);
Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
static_cast<int64_t>(dim));
Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
static_cast<int64_t>(dim));
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
return success();
}
@@ -291,13 +291,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) {
return TypeSwitch<Type, Value>(oldVal.getType())
.Case([&](BFloat16Type) {
return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
oldVal);
return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(),
oldVal);
})
.Case([&](IntegerType intTy) -> Value {
if (intTy.getWidth() == 1)
return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
oldVal);
return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(),
oldVal);
return oldVal;
})
.Default(oldVal);
@@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) {
return TypeSwitch<Type, Value>(newTy)
.Case([&](BFloat16Type) {
return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal);
})
.Case([&](IntegerType intTy) -> Value {
if (intTy.getWidth() == 1)
return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal);
return oldVal;
})
.Default(oldVal);
@@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {resultOrConversion, trueVal});
return success();
}
@@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
return failure();
}
result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result);
}
rewriter.replaceOp(op, result);

View File

@@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering
Location loc = op->getLoc();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
mode.value(), offset);
auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type,
op.getValue(), mode.value(), offset);
rewriter.replaceOp(op, reduxOp->getResult(0));
return success();
@@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
auto predTy = IntegerType::get(rewriter.getContext(), 1);
Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
loc, int32Type, thirtyTwo, adaptor.getWidth());
Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1);
Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1);
Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32);
Value numLeadInactiveLane = LLVM::SubOp::create(
rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth());
// Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
numLeadInactiveLane);
Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne,
numLeadInactiveLane);
Value maskAndClamp;
if (op.getMode() == gpu::ShuffleMode::UP) {
// Clamp lane: `32 - activeWidth`
maskAndClamp = numLeadInactiveLane;
} else {
// Clamp lane: `activeWidth - 1`
maskAndClamp =
rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type,
adaptor.getWidth(), one);
}
bool predIsUsed = !op->getResult(1).use_empty();
@@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueTy, predTy});
}
Value shfl = rewriter.create<NVVM::ShflOp>(
loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
Value shfl = NVVM::ShflOp::create(
rewriter, loc, resultTy, activeMask, adaptor.getValue(),
adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()),
returnValueAndIsValidAttr);
if (predIsUsed) {
Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0);
Value isActiveSrcLane =
rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1);
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
} else {
rewriter.replaceOp(op, {shfl, nullptr});
@@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
/*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
Value newOp =
rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
newOp = rewriter.create<LLVM::SExtOp>(
loc, IntegerType::get(context, indexBitwidth), newOp);
newOp = LLVM::SExtOp::create(
rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
} else if (indexBitwidth < 32) {
newOp = rewriter.create<LLVM::TruncOp>(
loc, IntegerType::get(context, indexBitwidth), newOp);
newOp = LLVM::TruncOp::create(
rewriter, loc, IntegerType::get(context, indexBitwidth), newOp);
}
rewriter.replaceOp(op, {newOp});
return success();
@@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering
Block *afterBlock =
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
rewriter.setInsertionPointToEnd(beforeBlock);
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
assertBlock);
cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock,
assertBlock);
rewriter.setInsertionPointToEnd(assertBlock);
rewriter.create<cf::BranchOp>(loc, afterBlock);
cf::BranchOp::create(rewriter, loc, afterBlock);
// Continue cf.assert lowering.
rewriter.setInsertionPoint(assertOp);
@@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering
// Create constants.
auto getGlobal = [&](LLVM::GlobalOp global) {
// Get a pointer to the format string's first element.
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
Value globalPtr = LLVM::AddressOfOp::create(
rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
global.getSymNameAttr());
Value start =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
return start;
};
Value assertMessage = getGlobal(getOrCreateStringConstant(
@@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering
Value assertFunc = getGlobal(getOrCreateStringConstant(
rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
Value assertLine =
rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine);
Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1);
// Insert function call to __assertfail.
SmallVector<Value> arguments{assertMessage, assertFile, assertLine,

View File

@@ -126,8 +126,8 @@ struct WmmaLoadOpToNVVMLowering
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
adaptor.getSrcMemref(), adaptor.getIndices());
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
Value leadingDim = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32Type(),
subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
@@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering
auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
Value toUse =
rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i);
storeOpOperands.push_back(toUse);
}
@@ -181,8 +181,8 @@ struct WmmaStoreOpToNVVMLowering
rewriter, loc,
cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
adaptor.getDstMemref(), adaptor.getIndices());
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
Value leadingDim = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
@@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering
auto unpackOp = [&](Value operand) {
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
unpackedOps.push_back(toUse);
}
};
@@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element type is a vector create a vector from the operand.
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), vecEl);
vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
cst, idx);
Value idx = LLVM::ConstantOp::create(rewriter, loc,
rewriter.getI32Type(), vecEl);
vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst,
cst, idx);
}
cst = vecCst;
}
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
matrixStruct =
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
return success();
@@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
i1Type = VectorType::get(vecType.getShape(), i1Type);
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
Value isNan = builder.create<LLVM::FCmpOp>(
loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
Value nan = builder.create<LLVM::ConstantOp>(
loc, lhs.getType(),
Value cmp = LLVM::FCmpOp::create(
builder, loc, i1Type,
isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs);
Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs);
Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type,
LLVM::FCmpPredicate::uno, lhs, rhs);
Value nan = LLVM::ConstantOp::create(
builder, loc, lhs.getType(),
builder.getFloatAttr(floatType,
APFloat::getQNaN(floatType.getFloatSemantics())));
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
return LLVM::SelectOp::create(builder, loc, isNan, nan, sel);
}
static Value createScalarOp(OpBuilder &builder, Location loc,
@@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
ArrayRef<Value> operands) {
switch (op) {
case gpu::MMAElementwiseOp::ADDF:
return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MULF:
return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::DIVF:
return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MAXF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/false);
@@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getOperands()[opIdx], i));
extractedOperands.push_back(LLVM::ExtractValueOp::create(
rewriter, loc, adaptor.getOperands()[opIdx], i));
}
Value element =
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
extractedOperands);
matrixStruct =
rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i);
}
rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
return success();

View File

@@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
// TODO: use <=> in C++20.
if (indexBitwidth > intWidth) {
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value);
}
if (indexBitwidth < intWidth) {
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value);
}
return value;
}
@@ -82,12 +82,12 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
const unsigned indexBitwidth) {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
ValueRange{minus1, zero});
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
ValueRange{minus1, mbcntLo});
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
ValueRange{minus1, zero});
Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
ValueRange{minus1, mbcntLo});
return laneId;
}
static constexpr StringLiteral amdgcnDataLayout =
@@ -110,21 +110,21 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
// followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
Type intTy = IntegerType::get(context, 32);
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
Value mbcntLo =
rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
loc, intTy, ValueRange{minus1, mbcntLo});
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
ValueRange{minus1, zero});
Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
ValueRange{minus1, mbcntLo});
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
laneId = rewriter.create<LLVM::SExtOp>(
loc, IntegerType::get(context, indexBitwidth), laneId);
laneId = LLVM::SExtOp::create(
rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
} else if (indexBitwidth < 32) {
laneId = rewriter.create<LLVM::TruncOp>(
loc, IntegerType::get(context, indexBitwidth), laneId);
laneId = LLVM::TruncOp::create(
rewriter, loc, IntegerType::get(context, indexBitwidth), laneId);
}
rewriter.replaceOp(op, {laneId});
return success();
@@ -149,8 +149,8 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
}
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
op.getLoc(), rewriter.getI32Type(), bounds);
Value wavefrontOp = ROCDL::WavefrontSizeOp::create(
rewriter, op.getLoc(), rewriter.getI32Type(), bounds);
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
*getTypeConverter());
rewriter.replaceOp(op, {wavefrontOp});
@@ -190,44 +190,44 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0);
Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width);
Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width);
Value widthOrZeroIfOutside =
rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth);
Value dstLane;
switch (op.getMode()) {
case gpu::ShuffleMode::UP:
dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId,
adaptor.getOffset());
break;
case gpu::ShuffleMode::DOWN:
dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId,
adaptor.getOffset());
break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId,
adaptor.getOffset());
break;
case gpu::ShuffleMode::IDX:
dstLane = adaptor.getOffset();
break;
}
Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane,
dstLane, srcLaneId);
Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
Value isActiveSrcLane = LLVM::ICmpOp::create(
rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside);
Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane,
dstLane, srcLaneId);
Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2);
Value dwordAlignedDstLane =
rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
dwordAlignedDstLane, v);
Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
dwordAlignedDstLane, v);
swizzled.emplace_back(res);
}
Value shflValue =

View File

@@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value vector =
spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
Value dim = rewriter.create<spirv::CompositeExtractOp>(
op.getLoc(), builtinType, vector,
Value dim = spirv::CompositeExtractOp::create(
rewriter, op.getLoc(), builtinType, vector,
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
if (forShader && builtinType != indexType)
dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim);
rewriter.replaceOp(op, dim);
return success();
}
@@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
Value builtinValue =
spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
if (i32Type != indexType)
builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
builtinValue);
builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType,
builtinValue);
rewriter.replaceOp(op, builtinValue);
return success();
}
@@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
signatureConverter.addInputs(argType.index(), convertedType);
}
}
auto newFuncOp = rewriter.create<spirv::FuncOp>(
funcOp.getLoc(), funcOp.getName(),
auto newFuncOp = spirv::FuncOp::create(
rewriter, funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {}));
for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
@@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
auto spvModule = rewriter.create<spirv::ModuleOp>(
moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
auto spvModule = spirv::ModuleOp::create(
rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
@@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR: {
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
result = spirv::GroupNonUniformShuffleXorOp::create(
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::IDX: {
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
result = spirv::GroupNonUniformShuffleOp::create(
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::DOWN: {
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
result = spirv::GroupNonUniformShuffleDownOp::create(
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
resultLaneId, adaptor.getWidth());
arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset());
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
resultLaneId, adaptor.getWidth());
break;
}
case gpu::ShuffleMode::UP: {
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
result = spirv::GroupNonUniformShuffleUpOp::create(
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset());
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
Value resultLaneId =
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset());
auto i32Type = rewriter.getIntegerType(32);
validVal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, resultLaneId,
rewriter.create<arith::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
validVal = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, resultLaneId,
arith::ConstantOp::create(rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, 0)));
break;
}
}
@@ -516,15 +516,16 @@ LogicalResult GPURotateConversion::matchAndRewrite(
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
adaptor.getWidth());
Value validVal;
if (widthAttr.getValue().getZExtValue() == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
laneId, adaptor.getWidth());
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
laneId, adaptor.getWidth());
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -548,14 +549,14 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
? spirv::GroupOperation::ClusteredReduce
: spirv::GroupOperation::Reduce);
if (isUniform) {
return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
return UniformOp::create(builder, loc, type, scope, groupOp, arg)
.getResult();
}
Value clusterSizeValue;
if (clusterSize.has_value())
clusterSizeValue = builder.create<spirv::ConstantOp>(
loc, builder.getI32Type(),
clusterSizeValue = spirv::ConstantOp::create(
builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
return builder
@@ -740,8 +741,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstName =
makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
return rewriter.create<spirv::SpecConstantOp>(
loc, rewriter.getStringAttr(specCstName), attr);
return spirv::SpecConstantOp::create(
rewriter, loc, rewriter.getStringAttr(specCstName), attr);
};
{
Operation *parent =
@@ -774,8 +775,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
std::string specCstCompositeName =
(llvm::Twine(globalVarName) + "_scc").str();
specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
loc, TypeAttr::get(globalType),
specCstComposite = spirv::SpecConstantCompositeOp::create(
rewriter, loc, TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));
@@ -785,23 +786,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite(
// Define a GlobalVarOp initialized using specialized constants
// that is used to specify the printf format string
// to be passed to the SPIRV CLPrintfOp.
globalVar = rewriter.create<spirv::GlobalVariableOp>(
loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
globalVar = spirv::GlobalVariableOp::create(
rewriter, loc, ptrType, globalVarName,
FlatSymbolRefAttr::get(specCstComposite));
globalVar->setAttr("Constant", rewriter.getUnitAttr());
}
// Get SSA value of Global variable and create pointer to i8 to point to
// the format string.
Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
Value fmtStr = rewriter.create<spirv::BitcastOp>(
loc,
Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar);
Value fmtStr = spirv::BitcastOp::create(
rewriter, loc,
spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
globalPtr);
// Get printf arguments.
auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs);
// Need to erase the gpu.printf op as gpu.printf does not use result vs
// spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V

View File

@@ -144,11 +144,12 @@ void GPUToSPIRVPass::runOnOperation() {
if (targetEnvSupportsKernelCapability(moduleOp)) {
moduleOp.walk([&](gpu::GPUFuncOp funcOp) {
builder.setInsertionPoint(funcOp);
auto newFuncOp = builder.create<func::FuncOp>(
funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
auto newFuncOp =
func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(),
funcOp.getFunctionType());
auto entryBlock = newFuncOp.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);
builder.create<func::ReturnOp>(funcOp.getLoc());
func::ReturnOp::create(builder, funcOp.getLoc());
newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
funcOp.erase();

View File

@@ -283,8 +283,8 @@ struct WmmaLoadOpToSPIRVLowering final
int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));
auto strideValue = spirv::ConstantOp::create(
rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
@@ -315,8 +315,8 @@ struct WmmaStoreOpToSPIRVLowering final
int64_t stride = op.getLeadDimension().getSExtValue();
IntegerType i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));
auto strideValue = spirv::ConstantOp::create(
rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride));
bool isColMajor = op.getTranspose().value_or(false);
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor