From 3f23c7f5bedc8786d3f4567d2331a7efcbb2a77e Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Tue, 21 Mar 2023 18:00:08 -0700 Subject: [PATCH] [InstSimplify] Actually use NewOps for calls in simplifyInstructionWithOperands Resolves a TODO. Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D146599 --- llvm/include/llvm/Analysis/InstructionSimplify.h | 5 +- llvm/lib/Analysis/InstructionSimplify.cpp | 118 ++++++++++----------- .../Transforms/InstCombine/InstCombineCalls.cpp | 10 +- llvm/unittests/Transforms/Utils/LocalTest.cpp | 3 +- 4 files changed, 72 insertions(+), 64 deletions(-) diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h index 861fa3b..826bd45 100644 --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -302,8 +302,9 @@ Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); -/// Given a callsite, fold the result or return null. -Value *simplifyCall(CallBase *Call, const SimplifyQuery &Q); +/// Given a callsite, callee, and arguments, fold the result or return null. +Value *simplifyCall(CallBase *Call, Value *Callee, ArrayRef<Value *> Args, + const SimplifyQuery &Q); /// Given a constrained FP intrinsic call, tries to compute its simplified /// version. Returns a simplified result or null. 
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index ecb0cdb..eaf0af9 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -6391,10 +6391,13 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return nullptr; } -static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { - - unsigned NumOperands = Call->arg_size(); - Function *F = cast(Call->getCalledFunction()); +static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, + ArrayRef Args, + const SimplifyQuery &Q) { + // Operand bundles should not be in Args. + assert(Call->arg_size() == Args.size()); + unsigned NumOperands = Args.size(); + Function *F = cast(Callee); Intrinsic::ID IID = F->getIntrinsicID(); // Most of the intrinsics with no operands have some kind of side effect. @@ -6420,18 +6423,17 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } if (NumOperands == 1) - return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q); + return simplifyUnaryIntrinsic(F, Args[0], Q); if (NumOperands == 2) - return simplifyBinaryIntrinsic(F, Call->getArgOperand(0), - Call->getArgOperand(1), Q); + return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q); // Handle intrinsics with 3 or more arguments. switch (IID) { case Intrinsic::masked_load: case Intrinsic::masked_gather: { - Value *MaskArg = Call->getArgOperand(2); - Value *PassthruArg = Call->getArgOperand(3); + Value *MaskArg = Args[2]; + Value *PassthruArg = Args[3]; // If the mask is all zeros or undef, the "passthru" argument is the result. 
if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; @@ -6439,8 +6441,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } case Intrinsic::fshl: case Intrinsic::fshr: { - Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1), - *ShAmtArg = Call->getArgOperand(2); + Value *Op0 = Args[0], *Op1 = Args[1], *ShAmtArg = Args[2]; // If both operands are undef, the result is undef. if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1)) @@ -6448,14 +6449,14 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { // If shift amount is undef, assume it is zero. if (Q.isUndefValue(ShAmtArg)) - return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); + return Args[IID == Intrinsic::fshl ? 0 : 1]; const APInt *ShAmtC; if (match(ShAmtArg, m_APInt(ShAmtC))) { // If there's effectively no shift, return the 1st arg or 2nd arg. APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); if (ShAmtC->urem(BitWidth).isZero()) - return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); + return Args[IID == Intrinsic::fshl ? 0 : 1]; } // Rotating zero by anything is zero. 
@@ -6469,31 +6470,24 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { return nullptr; } case Intrinsic::experimental_constrained_fma: { - Value *Op0 = Call->getArgOperand(0); - Value *Op1 = Call->getArgOperand(1); - Value *Op2 = Call->getArgOperand(2); auto *FPI = cast(Call); - if (Value *V = - simplifyFPOp({Op0, Op1, Op2}, {}, Q, *FPI->getExceptionBehavior(), - *FPI->getRoundingMode())) + if (Value *V = simplifyFPOp(Args, {}, Q, *FPI->getExceptionBehavior(), + *FPI->getRoundingMode())) return V; return nullptr; } case Intrinsic::fma: case Intrinsic::fmuladd: { - Value *Op0 = Call->getArgOperand(0); - Value *Op1 = Call->getArgOperand(1); - Value *Op2 = Call->getArgOperand(2); - if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q, fp::ebIgnore, + if (Value *V = simplifyFPOp(Args, {}, Q, fp::ebIgnore, RoundingMode::NearestTiesToEven)) return V; return nullptr; } case Intrinsic::smul_fix: case Intrinsic::smul_fix_sat: { - Value *Op0 = Call->getArgOperand(0); - Value *Op1 = Call->getArgOperand(1); - Value *Op2 = Call->getArgOperand(2); + Value *Op0 = Args[0]; + Value *Op1 = Args[1]; + Value *Op2 = Args[2]; Type *ReturnType = F->getReturnType(); // Canonicalize constant operand as Op1 (ConstantFolding handles the case @@ -6520,9 +6514,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { return nullptr; } case Intrinsic::vector_insert: { - Value *Vec = Call->getArgOperand(0); - Value *SubVec = Call->getArgOperand(1); - Value *Idx = Call->getArgOperand(2); + Value *Vec = Args[0]; + Value *SubVec = Args[1]; + Value *Idx = Args[2]; Type *ReturnType = F->getReturnType(); // (insert_vector Y, (extract_vector X, 0), 0) -> X @@ -6539,51 +6533,52 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } case Intrinsic::experimental_constrained_fadd: { auto *FPI = cast(Call); - return simplifyFAddInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, 
*FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFAddInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_fsub: { auto *FPI = cast(Call); - return simplifyFSubInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFSubInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_fmul: { auto *FPI = cast(Call); - return simplifyFMulInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFMulInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_fdiv: { auto *FPI = cast(Call); - return simplifyFDivInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFDivInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_frem: { auto *FPI = cast(Call); - return simplifyFRemInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFRemInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } default: return nullptr; } } -static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { - auto *F = dyn_cast(Call->getCalledOperand()); +static Value *tryConstantFoldCall(CallBase *Call, Value *Callee, + ArrayRef Args, + const SimplifyQuery &Q) { + auto *F = dyn_cast(Callee); if (!F || !canConstantFoldCallTo(Call, F)) 
return nullptr; SmallVector ConstantArgs; - unsigned NumArgs = Call->arg_size(); - ConstantArgs.reserve(NumArgs); - for (auto &Arg : Call->args()) { - Constant *C = dyn_cast(&Arg); + ConstantArgs.reserve(Args.size()); + for (Value *Arg : Args) { + Constant *C = dyn_cast(Arg); if (!C) { - if (isa(Arg.get())) + if (isa(Arg)) continue; return nullptr; } @@ -6593,7 +6588,11 @@ static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } -Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { +Value *llvm::simplifyCall(CallBase *Call, Value *Callee, ArrayRef Args, + const SimplifyQuery &Q) { + // Args should not contain operand bundle operands. + assert(Call->arg_size() == Args.size()); + // musttail calls can only be simplified if they are also DCEd. // As we can't guarantee this here, don't simplify them. if (Call->isMustTailCall()) @@ -6601,16 +6600,15 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { // call undef -> poison // call null -> poison - Value *Callee = Call->getCalledOperand(); if (isa(Callee) || isa(Callee)) return PoisonValue::get(Call->getType()); - if (Value *V = tryConstantFoldCall(Call, Q)) + if (Value *V = tryConstantFoldCall(Call, Callee, Args, Q)) return V; auto *F = dyn_cast(Callee); if (F && F->isIntrinsic()) - if (Value *Ret = simplifyIntrinsic(Call, Q)) + if (Value *Ret = simplifyIntrinsic(Call, Callee, Args, Q)) return Ret; return nullptr; @@ -6618,9 +6616,10 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { Value *llvm::simplifyConstrainedFPCall(CallBase *Call, const SimplifyQuery &Q) { assert(isa(Call)); - if (Value *V = tryConstantFoldCall(Call, Q)) + SmallVector Args(Call->args()); + if (Value *V = tryConstantFoldCall(Call, Call->getCalledOperand(), Args, Q)) return V; - if (Value *Ret = simplifyIntrinsic(Call, Q)) + if (Value *Ret = simplifyIntrinsic(Call, Call->getCalledOperand(), Args, Q)) return Ret; return 
nullptr; } @@ -6775,8 +6774,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I, case Instruction::PHI: return simplifyPHINode(cast(I), NewOps, Q); case Instruction::Call: - // TODO: Use NewOps - return simplifyCall(cast(I), Q); + return simplifyCall( + cast(I), NewOps.back(), + NewOps.drop_back(1 + cast(I)->getNumTotalBundleOperands()), Q); case Instruction::Freeze: return llvm::simplifyFreezeInst(NewOps[0], Q); #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 2b61b58..0fbd62e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1288,9 +1288,15 @@ foldShuffledIntrinsicOperands(IntrinsicInst *II, Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Don't try to simplify calls without uses. It will not do anything useful, // but will result in the following folds being skipped. - if (!CI.use_empty()) - if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI))) + if (!CI.use_empty()) { + SmallVector Args; + Args.reserve(CI.arg_size()); + for (Value *Op : CI.args()) + Args.push_back(Op); + if (Value *V = simplifyCall(&CI, CI.getCalledOperand(), Args, + SQ.getWithInstruction(&CI))) return replaceInstUsesWith(CI, V); + } if (Value *FreedOp = getFreedOperand(&CI, &TLI)) return visitFree(CI, FreedOp); diff --git a/llvm/unittests/Transforms/Utils/LocalTest.cpp b/llvm/unittests/Transforms/Utils/LocalTest.cpp index d6b09b3..443f1f0 100644 --- a/llvm/unittests/Transforms/Utils/LocalTest.cpp +++ b/llvm/unittests/Transforms/Utils/LocalTest.cpp @@ -598,7 +598,8 @@ TEST(Local, SimplifyVScaleWithRange) { // Test that simplifyCall won't try to query it's parent function for // vscale_range attributes in order to simplify llvm.vscale -> constant. 
- EXPECT_EQ(simplifyCall(CI, SimplifyQuery(M.getDataLayout())), nullptr); + EXPECT_EQ(simplifyCall(CI, VScale, {}, SimplifyQuery(M.getDataLayout())), + nullptr); delete CI; } -- 2.7.4