#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
if (!Visited.insert(U).second)
continue;
- ImmutableCallSite CS(U);
+ const auto *CB = dyn_cast<CallBase>(U);
for (const auto &OI : U->operands()) {
const User *Operand = dyn_cast<User>(OI);
// We have a reference to a global value. This should be added to
// the reference set unless it is a callee. Callees are handled
// specially by WriteFunction and are added to a separate list.
- if (!(CS && CS.isCallee(&OI)))
+ if (!(CB && CB->isCallee(&OI)))
RefEdges.insert(Index.getOrInsertValueInfo(GV));
continue;
}
SetVector<FunctionSummary::ConstVCall> &ConstVCalls) {
std::vector<uint64_t> Args;
// Start from the second argument to skip the "this" pointer.
- for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) {
+ for (auto &Arg : make_range(Call.CB.arg_begin() + 1, Call.CB.arg_end())) {
auto *CI = dyn_cast<ConstantInt>(Arg);
if (!CI || CI->getBitWidth() > 64) {
VCalls.insert({Guid, Call.Offset});
}
}
findRefEdges(Index, &I, RefEdges, Visited);
- auto CS = ImmutableCallSite(&I);
- if (!CS)
+ const auto *CB = dyn_cast<CallBase>(&I);
+ if (!CB)
continue;
const auto *CI = dyn_cast<CallInst>(&I);
if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm())
HasInlineAsmMaybeReferencingInternal = true;
- auto *CalledValue = CS.getCalledValue();
- auto *CalledFunction = CS.getCalledFunction();
+ auto *CalledValue = CB->getCalledValue();
+ auto *CalledFunction = CB->getCalledFunction();
if (CalledValue && !CalledFunction) {
CalledValue = CalledValue->stripPointerCasts();
// Stripping pointer casts can reveal a called function.
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
- Value *VTable;
- CallSite CS;
+ Value *VTable = nullptr;
+ CallBase &CB;
// If non-null, this field points to the associated unsafe use count stored in
// the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
// of that field for details.
- unsigned *NumUnsafeUses;
+ unsigned *NumUnsafeUses = nullptr;
void
emitRemark(const StringRef OptName, const StringRef TargetName,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
- Function *F = CS.getCaller();
- DebugLoc DLoc = CS->getDebugLoc();
- BasicBlock *Block = CS.getParent();
+ Function *F = CB.getCaller();
+ DebugLoc DLoc = CB.getDebugLoc();
+ BasicBlock *Block = CB.getParent();
using namespace ore;
OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
Value *New) {
if (RemarksEnabled)
emitRemark(OptName, TargetName, OREGetter);
- CS->replaceAllUsesWith(New);
- if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
- BranchInst::Create(II->getNormalDest(), CS.getInstruction());
+ CB.replaceAllUsesWith(New);
+ if (auto *II = dyn_cast<InvokeInst>(&CB)) {
+ BranchInst::Create(II->getNormalDest(), &CB);
II->getUnwindDest()->removePredecessor(II->getParent());
}
- CS->eraseFromParent();
+ CB.eraseFromParent();
// This use is no longer unsafe.
if (NumUnsafeUses)
--*NumUnsafeUses;
// "this"), grouped by argument list.
std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
- void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+ void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);
private:
- CallSiteInfo &findCallSiteInfo(CallSite CS);
+ CallSiteInfo &findCallSiteInfo(CallBase &CB);
};
-CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
std::vector<uint64_t> Args;
- auto *CI = dyn_cast<IntegerType>(CS.getType());
- if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
+ auto *CBType = dyn_cast<IntegerType>(CB.getType());
+ if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
return CSInfo;
- for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+ for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) {
auto *CI = dyn_cast<ConstantInt>(Arg);
if (!CI || CI->getBitWidth() > 64)
return CSInfo;
return ConstCSInfo[Args];
}
-void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
unsigned *NumUnsafeUses) {
- auto &CSI = findCallSiteInfo(CS);
+ auto &CSI = findCallSiteInfo(CB);
CSI.AllCallSitesDevirted = false;
- CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
+ CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
}
struct DevirtModule {
if (RemarksEnabled)
VCallSite.emitRemark("single-impl",
TheFn->stripPointerCasts()->getName(), OREGetter);
- VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
- TheFn, VCallSite.CS.getCalledValue()->getType()));
+ VCallSite.CB.setCalledOperand(ConstantExpr::getBitCast(
+ TheFn, VCallSite.CB.getCalledValue()->getType()));
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
if (CSInfo.AllCallSitesDevirted)
return;
for (auto &&VCallSite : CSInfo.CallSites) {
- CallSite CS = VCallSite.CS;
+ CallBase &CB = VCallSite.CB;
// Jump tables are only profitable if the retpoline mitigation is enabled.
- Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+ Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
if (FSAttr.hasAttribute(Attribute::None) ||
!FSAttr.getValueAsString().contains("+retpoline"))
continue;
// x86_64.
std::vector<Type *> NewArgs;
NewArgs.push_back(Int8PtrTy);
- for (Type *T : CS.getFunctionType()->params())
+ for (Type *T : CB.getFunctionType()->params())
NewArgs.push_back(T);
FunctionType *NewFT =
- FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
- CS.getFunctionType()->isVarArg());
+ FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
+ CB.getFunctionType()->isVarArg());
PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
- IRBuilder<> IRB(CS.getInstruction());
+ IRBuilder<> IRB(&CB);
std::vector<Value *> Args;
Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
- for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
- Args.push_back(CS.getArgOperand(I));
+ Args.insert(Args.end(), CB.arg_begin(), CB.arg_end());
- CallSite NewCS;
- if (CS.isCall())
+ CallBase *NewCS = nullptr;
+ if (isa<CallInst>(CB))
NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
else
- NewCS = IRB.CreateInvoke(
- NewFT, IRB.CreateBitCast(JT, NewFTPtr),
- cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
- cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
- NewCS.setCallingConv(CS.getCallingConv());
+ NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
+ cast<InvokeInst>(CB).getNormalDest(),
+ cast<InvokeInst>(CB).getUnwindDest(), Args);
+ NewCS->setCallingConv(CB.getCallingConv());
- AttributeList Attrs = CS.getAttributes();
+ AttributeList Attrs = CB.getAttributes();
std::vector<AttributeSet> NewArgAttrs;
NewArgAttrs.push_back(AttributeSet::get(
M.getContext(), ArrayRef<Attribute>{Attribute::get(
M.getContext(), Attribute::Nest)}));
for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
NewArgAttrs.push_back(Attrs.getParamAttributes(I));
- NewCS.setAttributes(
+ NewCS->setAttributes(
AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
Attrs.getRetAttributes(), NewArgAttrs));
- CS->replaceAllUsesWith(NewCS.getInstruction());
- CS->eraseFromParent();
+ CB.replaceAllUsesWith(NewCS);
+ CB.eraseFromParent();
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
for (auto Call : CSInfo.CallSites)
Call.replaceAndErase(
"uniform-ret-val", FnName, RemarksEnabled, OREGetter,
- ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
+ ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
CSInfo.markDevirt();
}
bool IsOne,
Constant *UniqueMemberAddr) {
for (auto &&Call : CSInfo.CallSites) {
- IRBuilder<> B(Call.CS.getInstruction());
+ IRBuilder<> B(&Call.CB);
Value *Cmp =
B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
- Cmp = B.CreateZExt(Cmp, Call.CS->getType());
+ Cmp = B.CreateZExt(Cmp, Call.CB.getType());
Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
Cmp);
}
void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
Constant *Byte, Constant *Bit) {
for (auto Call : CSInfo.CallSites) {
- auto *RetType = cast<IntegerType>(Call.CS.getType());
- IRBuilder<> B(Call.CS.getInstruction());
+ auto *RetType = cast<IntegerType>(Call.CB.getType());
+ IRBuilder<> B(&Call.CB);
Value *Addr =
B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
if (RetType->getBitWidth() == 1) {
// points to a member of the type identifier %md. Group calls by (type ID,
// offset) pair (effectively the identity of the virtual function) and store
// to CallSlots.
- DenseSet<CallSite> SeenCallSites;
+ DenseSet<CallBase *> SeenCallSites;
for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
I != E;) {
auto CI = dyn_cast<CallInst>(I->getUser());
// and we don't want to process call sites multiple times. We can't
// just skip the vtable Ptr if it has been seen before, however, since
// it may be shared by type tests that dominate different calls.
- if (SeenCallSites.insert(Call.CS).second)
- CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
+ if (SeenCallSites.insert(&Call.CB).second)
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
}
}
if (HasNonCallUses)
++NumUnsafeUses;
for (DevirtCallSite Call : DevirtCalls) {
- CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
&NumUnsafeUses);
}