[InstSimplify] Actually use NewOps for calls in simplifyInstructionWithOperands

Resolves a TODO.

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D146599
Author:  Arthur Eubanks
Date:    2023-03-21 18:00:08 -07:00
Parent:  0528087663
Commit:  3f23c7f5be

4 changed files with 72 additions and 64 deletions
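In short: simplifyCall previously re-read the callee and argument list from the instruction itself, which made it impossible to ask "would this call simplify with these replacement operands?". The new overload threads both through explicitly. A minimal before/after sketch of a call site (Call and Q stand in for any CallBase * and SimplifyQuery; the names are illustrative, not from the patch):

    // Before: callee and arguments are implicitly read from the instruction.
    Value *Simplified = simplifyCall(Call, Q);

    // After: the caller supplies them, so they may differ from the
    // instruction's current operands (e.g. the NewOps case below).
    SmallVector<Value *, 4> Args(Call->args());
    Value *Simplified2 =
        simplifyCall(Call, Call->getCalledOperand(), Args, Q);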

llvm/include/llvm/Analysis/InstructionSimplify.h

@@ -302,8 +302,9 @@ Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
 Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF,
                      const SimplifyQuery &Q);
 
-/// Given a callsite, fold the result or return null.
-Value *simplifyCall(CallBase *Call, const SimplifyQuery &Q);
+/// Given a callsite, callee, and arguments, fold the result or return null.
+Value *simplifyCall(CallBase *Call, Value *Callee, ArrayRef<Value *> Args,
+                    const SimplifyQuery &Q);
 
 /// Given a constrained FP intrinsic call, tries to compute its simplified
 /// version. Returns a simplified result or null.

llvm/lib/Analysis/InstructionSimplify.cpp

@@ -6391,10 +6391,13 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
   return nullptr;
 }
 
-static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
-  unsigned NumOperands = Call->arg_size();
-  Function *F = cast<Function>(Call->getCalledFunction());
+static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
+                                ArrayRef<Value *> Args,
+                                const SimplifyQuery &Q) {
+  // Operand bundles should not be in Args.
+  assert(Call->arg_size() == Args.size());
+  unsigned NumOperands = Args.size();
+  Function *F = cast<Function>(Callee);
   Intrinsic::ID IID = F->getIntrinsicID();
 
   // Most of the intrinsics with no operands have some kind of side effect.
   // Don't simplify.
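The new assert pins down the contract: Args mirrors Call->args() only, so operand bundle operands and the callee itself must not be included. Since CallBase::arg_size() also excludes bundle operands, the check holds exactly when the caller sliced correctly. A hedged sketch of a conforming caller:

    // Copy only the call arguments; bundle operands and the callee stay out.
    SmallVector<Value *, 4> Args(Call->args());
    assert(Call->arg_size() == Args.size()); // holds by construction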
@@ -6420,18 +6423,17 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
   }
 
   if (NumOperands == 1)
-    return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q);
+    return simplifyUnaryIntrinsic(F, Args[0], Q);
 
   if (NumOperands == 2)
-    return simplifyBinaryIntrinsic(F, Call->getArgOperand(0),
-                                   Call->getArgOperand(1), Q);
+    return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q);
 
   // Handle intrinsics with 3 or more arguments.
   switch (IID) {
   case Intrinsic::masked_load:
   case Intrinsic::masked_gather: {
-    Value *MaskArg = Call->getArgOperand(2);
-    Value *PassthruArg = Call->getArgOperand(3);
+    Value *MaskArg = Args[2];
+    Value *PassthruArg = Args[3];
     // If the mask is all zeros or undef, the "passthru" argument is the result.
     if (maskIsAllZeroOrUndef(MaskArg))
       return PassthruArg;
@@ -6439,8 +6441,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
   }
   case Intrinsic::fshl:
   case Intrinsic::fshr: {
-    Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1),
-          *ShAmtArg = Call->getArgOperand(2);
+    Value *Op0 = Args[0], *Op1 = Args[1], *ShAmtArg = Args[2];
 
     // If both operands are undef, the result is undef.
     if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1))
@@ -6448,14 +6449,14 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
     // If shift amount is undef, assume it is zero.
     if (Q.isUndefValue(ShAmtArg))
-      return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);
+      return Args[IID == Intrinsic::fshl ? 0 : 1];
 
     const APInt *ShAmtC;
     if (match(ShAmtArg, m_APInt(ShAmtC))) {
       // If there's effectively no shift, return the 1st arg or 2nd arg.
       APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth());
       if (ShAmtC->urem(BitWidth).isZero())
-        return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);
+        return Args[IID == Intrinsic::fshl ? 0 : 1];
     }
 
     // Rotating zero by anything is zero.
@@ -6469,31 +6470,24 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
     return nullptr;
   }
   case Intrinsic::experimental_constrained_fma: {
-    Value *Op0 = Call->getArgOperand(0);
-    Value *Op1 = Call->getArgOperand(1);
-    Value *Op2 = Call->getArgOperand(2);
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    if (Value *V =
-            simplifyFPOp({Op0, Op1, Op2}, {}, Q, *FPI->getExceptionBehavior(),
-                         *FPI->getRoundingMode()))
+    if (Value *V = simplifyFPOp(Args, {}, Q, *FPI->getExceptionBehavior(),
+                                *FPI->getRoundingMode()))
       return V;
     return nullptr;
   }
   case Intrinsic::fma:
   case Intrinsic::fmuladd: {
-    Value *Op0 = Call->getArgOperand(0);
-    Value *Op1 = Call->getArgOperand(1);
-    Value *Op2 = Call->getArgOperand(2);
-    if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q, fp::ebIgnore,
+    if (Value *V = simplifyFPOp(Args, {}, Q, fp::ebIgnore,
                                 RoundingMode::NearestTiesToEven))
       return V;
     return nullptr;
   }
   case Intrinsic::smul_fix:
   case Intrinsic::smul_fix_sat: {
-    Value *Op0 = Call->getArgOperand(0);
-    Value *Op1 = Call->getArgOperand(1);
-    Value *Op2 = Call->getArgOperand(2);
+    Value *Op0 = Args[0];
+    Value *Op1 = Args[1];
+    Value *Op2 = Args[2];
     Type *ReturnType = F->getReturnType();
 
     // Canonicalize constant operand as Op1 (ConstantFolding handles the case
@@ -6520,9 +6514,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
     return nullptr;
   }
   case Intrinsic::vector_insert: {
-    Value *Vec = Call->getArgOperand(0);
-    Value *SubVec = Call->getArgOperand(1);
-    Value *Idx = Call->getArgOperand(2);
+    Value *Vec = Args[0];
+    Value *SubVec = Args[1];
+    Value *Idx = Args[2];
     Type *ReturnType = F->getReturnType();
 
     // (insert_vector Y, (extract_vector X, 0), 0) -> X
@@ -6539,51 +6533,52 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
   }
   case Intrinsic::experimental_constrained_fadd: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFAddInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFAddInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_fsub: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFSubInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFSubInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_fmul: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFMulInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFMulInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_fdiv: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFDivInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFDivInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
  }
   case Intrinsic::experimental_constrained_frem: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFRemInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFRemInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   default:
     return nullptr;
   }
 }
 
-static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) {
-  auto *F = dyn_cast<Function>(Call->getCalledOperand());
+static Value *tryConstantFoldCall(CallBase *Call, Value *Callee,
+                                  ArrayRef<Value *> Args,
+                                  const SimplifyQuery &Q) {
+  auto *F = dyn_cast<Function>(Callee);
   if (!F || !canConstantFoldCallTo(Call, F))
     return nullptr;
 
   SmallVector<Constant *, 4> ConstantArgs;
-  unsigned NumArgs = Call->arg_size();
-  ConstantArgs.reserve(NumArgs);
-  for (auto &Arg : Call->args()) {
-    Constant *C = dyn_cast<Constant>(&Arg);
+  ConstantArgs.reserve(Args.size());
+  for (Value *Arg : Args) {
+    Constant *C = dyn_cast<Constant>(Arg);
     if (!C) {
-      if (isa<MetadataAsValue>(Arg.get()))
+      if (isa<MetadataAsValue>(Arg))
         continue;
       return nullptr;
     }
@@ -6593,7 +6588,11 @@ static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) {
   return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI);
 }
 
-Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) {
+Value *llvm::simplifyCall(CallBase *Call, Value *Callee, ArrayRef<Value *> Args,
+                          const SimplifyQuery &Q) {
+  // Args should not contain operand bundle operands.
+  assert(Call->arg_size() == Args.size());
+
   // musttail calls can only be simplified if they are also DCEd.
   // As we can't guarantee this here, don't simplify them.
   if (Call->isMustTailCall())
@@ -6601,16 +6600,15 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) {
 
   // call undef -> poison
   // call null -> poison
-  Value *Callee = Call->getCalledOperand();
   if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee))
     return PoisonValue::get(Call->getType());
 
-  if (Value *V = tryConstantFoldCall(Call, Q))
+  if (Value *V = tryConstantFoldCall(Call, Callee, Args, Q))
     return V;
 
   auto *F = dyn_cast<Function>(Callee);
   if (F && F->isIntrinsic())
-    if (Value *Ret = simplifyIntrinsic(Call, Q))
+    if (Value *Ret = simplifyIntrinsic(Call, Callee, Args, Q))
       return Ret;
 
   return nullptr;
@@ -6618,9 +6616,10 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) {
 Value *llvm::simplifyConstrainedFPCall(CallBase *Call, const SimplifyQuery &Q) {
   assert(isa<ConstrainedFPIntrinsic>(Call));
-  if (Value *V = tryConstantFoldCall(Call, Q))
+  SmallVector<Value *, 4> Args(Call->args());
+  if (Value *V = tryConstantFoldCall(Call, Call->getCalledOperand(), Args, Q))
     return V;
 
-  if (Value *Ret = simplifyIntrinsic(Call, Q))
+  if (Value *Ret = simplifyIntrinsic(Call, Call->getCalledOperand(), Args, Q))
     return Ret;
 
   return nullptr;
 }
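A note on this caller: ArrayRef does not own storage, so the arguments are materialized into a local SmallVector once, and that buffer backs both the constant-folding attempt and the intrinsic simplification. Because Call->args() visits only the call arguments, the copy satisfies the arg_size() assertion in both callees.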
@@ -6775,8 +6774,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
   case Instruction::PHI:
     return simplifyPHINode(cast<PHINode>(I), NewOps, Q);
   case Instruction::Call:
-    // TODO: Use NewOps
-    return simplifyCall(cast<CallInst>(I), Q);
+    return simplifyCall(
+        cast<CallInst>(I), NewOps.back(),
+        NewOps.drop_back(1 + cast<CallInst>(I)->getNumTotalBundleOperands()), Q);
   case Instruction::Freeze:
     return llvm::simplifyFreezeInst(NewOps[0], Q);
 #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc:
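The TODO resolution above leans on the operand layout of CallBase: the call arguments come first, then any operand bundle operands, and the called operand is stored last. Given that layout, slicing NewOps recovers exactly the pieces simplifyCall expects. A sketch of the arithmetic (NumBundleOps is an illustrative local, not from the patch):

    // NewOps layout for a call:  [ args... | bundle operands... | callee ]
    Value *Callee = NewOps.back();
    unsigned NumBundleOps = cast<CallInst>(I)->getNumTotalBundleOperands();
    ArrayRef<Value *> Args = NewOps.drop_back(1 + NumBundleOps);
    // Args.size() now equals the call's arg_size(), as the assert requires.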

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

@@ -1288,9 +1288,15 @@ foldShuffledIntrinsicOperands(IntrinsicInst *II,
 Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   // Don't try to simplify calls without uses. It will not do anything useful,
   // but will result in the following folds being skipped.
-  if (!CI.use_empty())
-    if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI)))
+  if (!CI.use_empty()) {
+    SmallVector<Value *, 4> Args;
+    Args.reserve(CI.arg_size());
+    for (Value *Op : CI.args())
+      Args.push_back(Op);
+    if (Value *V = simplifyCall(&CI, CI.getCalledOperand(), Args,
+                                SQ.getWithInstruction(&CI)))
       return replaceInstUsesWith(CI, V);
+  }
 
   if (Value *FreedOp = getFreedOperand(&CI, &TLI))
     return visitFree(CI, FreedOp);
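A side note on this caller: the reserve/push_back loop builds the same vector that the range constructor used in simplifyConstrainedFPCall would produce, since CI.args() visits only the call arguments; either spelling satisfies the new assertion:

    SmallVector<Value *, 4> Args(CI.args()); // equivalent to the loop above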

llvm/unittests/Transforms/Utils/LocalTest.cpp

@@ -598,7 +598,8 @@ TEST(Local, SimplifyVScaleWithRange) {
   // Test that simplifyCall won't try to query its parent function for
   // vscale_range attributes in order to simplify llvm.vscale -> constant.
-  EXPECT_EQ(simplifyCall(CI, SimplifyQuery(M.getDataLayout())), nullptr);
+  EXPECT_EQ(simplifyCall(CI, VScale, {}, SimplifyQuery(M.getDataLayout())),
+            nullptr);
 
   delete CI;
 }