#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/InstVisitor.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
// Just a shorter abbreviation to improve indentation.
using Cost = InstructionCost;
+// Map of known constants found during the specialization bonus estimation.
+using ConstMap = DenseMap<Value *, Constant *>;
+
// Specialization signature, used to uniquely designate a specialization within
// a function.
struct SpecSig {
: F(F), Sig(S), Score(Score) {}
};
+class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> { // Estimates the bonus of instructions that fold once a value is constant.
+ const DataLayout &DL; // Passed to the constant-folding helpers.
+ BlockFrequencyInfo &BFI; // Block frequencies used to weight each cost.
+ TargetTransformInfo &TTI; // Source of the per-instruction costs.
+ SCCPSolver &Solver; // Queried for block executability.
+ // Values proven constant so far, including the (Use, C) pairs passed in.
+ ConstMap KnownConstants;
+ // Entry of the (Use, C) pair for the getUserBonus call in progress.
+ ConstMap::iterator LastVisited;
+
+public:
+ InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI,
+ TargetTransformInfo &TTI, SCCPSolver &Solver)
+ : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
+ // Bonus of \p User given that its operand \p Use equals \p C; recurses into the users of \p User.
+ Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
+
+private:
+ friend class InstVisitor<InstCostVisitor, Constant *>; // The base class dispatches to the private visit* handlers.
+ // Cost of the blocks that become dead when the terminator's condition is constant.
+ Cost estimateSwitchInst(SwitchInst &I);
+ Cost estimateBranchInst(BranchInst &I);
+ // Handlers return the folded constant, or nullptr when nothing folds.
+ Constant *visitInstruction(Instruction &I) { return nullptr; }
+ Constant *visitLoadInst(LoadInst &I);
+ Constant *visitGetElementPtrInst(GetElementPtrInst &I);
+ Constant *visitSelectInst(SelectInst &I);
+ Constant *visitCastInst(CastInst &I);
+ Constant *visitCmpInst(CmpInst &I);
+ Constant *visitUnaryOperator(UnaryOperator &I);
+ Constant *visitBinaryOperator(BinaryOperator &I);
+};
+
class FunctionSpecializer {
/// The IPSCCP Solver.
bool run();
+  /// Build a cost visitor for \p F from its block frequency and target
+  /// transform info, the module data layout and the solver.
+  InstCostVisitor getInstCostVisitorFor(Function *F) {
+    BlockFrequencyInfo &FreqInfo = (GetBFI)(*F);
+    TargetTransformInfo &TargetInfo = (GetTTI)(*F);
+    const DataLayout &Layout = M.getDataLayout();
+    return InstCostVisitor(Layout, FreqInfo, TargetInfo, Solver);
+  }
+
+ /// Compute a bonus for replacing argument \p A with constant \p C.
+ Cost getSpecializationBonus(Argument *A, Constant *C,
+ InstCostVisitor &Visitor);
+
private:
Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call);
/// Compute and return the cost of specializing function \p F.
Cost getSpecializationCost(Function *F);
- /// Compute a bonus for replacing argument \p A with constant \p C.
- Cost getSpecializationBonus(Argument *A, Constant *C);
-
/// Determine if it is possible to specialise the function for constant values
/// of the formal parameter \p A.
bool isArgumentInteresting(Argument *A);
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/ConstantFold.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
"Enable specialization of functions that take a literal constant as an "
"argument"));
+// Estimates the instruction cost of all the basic blocks in \p WorkList.
+// The successors of such blocks are added to the list as long as they are
+// executable and they have a unique predecessor. \p WorkList represents
+// the basic blocks of a specialization which become dead once we replace
+// instructions that are known to be constants. The aim here is to estimate
+// the combination of size and latency savings in comparison to the non
+// specialized version of the function.
+//
+// Every block contributes at most once, even when it is queued several
+// times. Instructions found in \p KnownConstants and ssa_copy intrinsics are
+// not charged. \p WorkList is drained by this call. Returns the accumulated,
+// frequency-weighted bonus.
+static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
+                                ConstMap &KnownConstants, SCCPSolver &Solver,
+                                BlockFrequencyInfo &BFI,
+                                TargetTransformInfo &TTI) {
+  Cost Bonus = 0;
+
+  // Blocks whose cost has already been accumulated. A "unique predecessor"
+  // may still reach a block through several edges (two switch cases sharing
+  // a destination, a conditional branch with identical targets), so the same
+  // block can be queued more than once and must not be counted twice.
+  SmallPtrSet<BasicBlock *, 8> Counted;
+
+  // Accumulate the instruction cost of each basic block weighted by frequency.
+  while (!WorkList.empty()) {
+    BasicBlock *BB = WorkList.pop_back_val();
+    if (!Counted.insert(BB).second)
+      continue;
+
+    uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() /
+                      BFI.getEntryFreq();
+    // A block colder than the entry rounds down to zero weight; neither it
+    // nor its successors contribute to the bonus.
+    if (!Weight)
+      continue;
+
+    for (Instruction &I : *BB) {
+      // Disregard SSA copies.
+      if (auto *II = dyn_cast<IntrinsicInst>(&I))
+        if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+          continue;
+      // If it's a known constant we have already accounted for it.
+      if (KnownConstants.contains(&I))
+        continue;
+
+      Bonus += Weight *
+          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+
+      LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
+                        << " after user " << I << "\n");
+    }
+
+    // Keep adding dead successors to the list as long as they are
+    // executable and they have a unique predecessor.
+    for (BasicBlock *SuccBB : successors(BB))
+      if (Solver.isBlockExecutable(SuccBB) &&
+          SuccBB->getUniquePredecessor() == BB)
+        WorkList.push_back(SuccBB);
+  }
+  return Bonus;
+}
+
+// Returns the constant recorded for \p V in \p KnownConstants, or nullptr
+// when \p V has no entry.
+static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
+  // lookup() yields a default-constructed (null) pointer for a missing key.
+  return KnownConstants.lookup(V);
+}
+
+// Computes the bonus obtained from \p User given that its operand \p Use is
+// known to be the constant \p C. Switches and branches are priced by the
+// blocks they make dead. Any other instruction is priced only if it folds to
+// a constant, in which case its executable users are visited recursively
+// with the folded value. Returns 0 when nothing can be gained.
+Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+  // A user that has already been folded has been paid for, together with
+  // its own users. It can be reached again when it reads the same value
+  // through more than one operand (users() then lists it repeatedly) or when
+  // a second operand becomes known. Do not count it twice.
+  if (KnownConstants.contains(User))
+    return 0;
+
+  // Cache the iterator before visiting.
+  LastVisited = KnownConstants.insert({Use, C}).first;
+
+  if (auto *I = dyn_cast<SwitchInst>(User))
+    return estimateSwitchInst(*I);
+
+  if (auto *I = dyn_cast<BranchInst>(User))
+    return estimateBranchInst(*I);
+
+  C = visit(*User);
+  if (!C)
+    return 0;
+
+  KnownConstants.insert({User, C});
+
+  uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
+                    BFI.getEntryFreq();
+  // The user is recorded as constant above even when its block is too cold
+  // to contribute to the bonus.
+  if (!Weight)
+    return 0;
+
+  Cost Bonus = Weight *
+      TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);
+
+  LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
+                    << " for user " << *User << "\n");
+
+  // Propagate the folded value to the users living in executable blocks.
+  for (auto *U : User->users())
+    if (auto *UI = dyn_cast<Instruction>(U))
+      if (Solver.isBlockExecutable(UI->getParent()))
+        Bonus += getUserBonus(UI, User, C);
+
+  return Bonus;
+}
+
+// Estimates the bonus of the blocks that become dead when the condition of
+// the switch \p I is the last visited constant. Returns 0 if the last
+// visited value is not the condition, or if the constant is not an integer
+// and so cannot select a case.
+Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
+  if (I.getCondition() != LastVisited->first)
+    return 0;
+
+  // The known constant is not necessarily a ConstantInt (it may be undef,
+  // poison or a constant expression); there is no case to select then.
+  auto *C = dyn_cast<ConstantInt>(LastVisited->second);
+  if (!C)
+    return 0;
+
+  BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor();
+  // Initialize the worklist with the dead basic blocks. These are the
+  // destination labels which are different from the one corresponding
+  // to \p C. They should be executable and have a unique predecessor.
+  SmallVector<BasicBlock *> WorkList;
+  for (const auto &Case : I.cases()) {
+    BasicBlock *BB = Case.getCaseSuccessor();
+    if (BB == Succ || !Solver.isBlockExecutable(BB) ||
+        BB->getUniquePredecessor() != I.getParent())
+      continue;
+    WorkList.push_back(BB);
+  }
+
+  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+}
+
+// Estimates the bonus of the successor of \p I that becomes dead when its
+// condition is the last visited constant. Returns 0 if the last visited
+// value is not the condition of the branch.
+Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
+  if (LastVisited->first != I.getCondition())
+    return 0;
+
+  // Successor index 1 is treated as dead when the condition is one, and
+  // index 0 otherwise.
+  const bool CondIsOne = LastVisited->second->isOneValue();
+  BasicBlock *DeadSucc = I.getSuccessor(CondIsOne);
+
+  // The successor only seeds the worklist if it is executable and is
+  // reached exclusively from the block of the branch.
+  SmallVector<BasicBlock *> WorkList;
+  const bool Removable = Solver.isBlockExecutable(DeadSucc) &&
+                         DeadSucc->getUniquePredecessor() == I.getParent();
+  if (Removable)
+    WorkList.push_back(DeadSucc);
+
+  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+}
+
+// Folds a load whose pointer operand is the last visited constant. A null
+// pointer is never folded. Returns the loaded constant or nullptr.
+Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
+  Constant *Ptr = LastVisited->second;
+  return isa<ConstantPointerNull>(Ptr)
+             ? nullptr
+             : ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL);
+}
+
+// Folds a GEP when every operand is either a constant or a value recorded
+// in KnownConstants. Returns the folded constant, or nullptr as soon as one
+// operand is unknown.
+Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
+  SmallVector<Value *, 8> FoldedOps;
+  FoldedOps.reserve(I.getNumOperands());
+
+  for (Value *Op : I.operands()) {
+    Constant *Known = dyn_cast<Constant>(Op);
+    if (!Known)
+      Known = findConstantFor(Op, KnownConstants);
+    if (!Known)
+      return nullptr;
+    FoldedOps.push_back(Known);
+  }
+
+  // The first operand is the base pointer, the remaining ones the indices.
+  auto *Base = cast<Constant>(FoldedOps.front());
+  auto Indices = ArrayRef(FoldedOps.begin() + 1, FoldedOps.end());
+  return ConstantFoldGetElementPtr(I.getSourceElementType(), Base,
+                                   I.isInBounds(), std::nullopt, Indices);
+}
+
+// Folds a select whose condition is the last visited constant. Returns the
+// selected operand if it is a constant or a known constant, else nullptr.
+Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
+  if (LastVisited->first != I.getCondition())
+    return nullptr;
+
+  // A zero condition picks the false operand, anything else the true one.
+  Value *Picked = I.getTrueValue();
+  if (LastVisited->second->isZeroValue())
+    Picked = I.getFalseValue();
+
+  if (auto *C = dyn_cast<Constant>(Picked))
+    return C;
+  return findConstantFor(Picked, KnownConstants);
+}
+
+// Folds a cast of the last visited constant to the type of \p I.
+Constant *InstCostVisitor::visitCastInst(CastInst &I) {
+  Constant *Src = LastVisited->second;
+  return ConstantFoldCastOperand(I.getOpcode(), Src, I.getType(), DL);
+}
+
+// Folds a comparison between the last visited constant and its other
+// operand, which must be a constant or a known constant. Returns nullptr
+// when the other operand is unknown.
+Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
+  // The last visited value may sit on either side of the comparison.
+  const bool KnownIsRHS = I.getOperand(1) == LastVisited->first;
+  Value *OtherOp = I.getOperand(KnownIsRHS ? 0 : 1);
+
+  Constant *Other = dyn_cast<Constant>(OtherOp);
+  if (!Other)
+    Other = findConstantFor(OtherOp, KnownConstants);
+  if (!Other)
+    return nullptr;
+
+  // Keep the operands in their original order when folding.
+  Constant *Known = LastVisited->second;
+  Constant *LHS = KnownIsRHS ? Other : Known;
+  Constant *RHS = KnownIsRHS ? Known : Other;
+  return ConstantFoldCompareInstOperands(I.getPredicate(), LHS, RHS, DL);
+}
+
+// Folds a unary operator applied to the last visited constant.
+Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
+  Constant *Operand = LastVisited->second;
+  return ConstantFoldUnaryOpOperand(I.getOpcode(), Operand, DL);
+}
+
+// Simplifies a binary operator whose operands are the last visited constant
+// and another constant or known constant. Returns nullptr when the other
+// operand is unknown or the simplified value is not a constant.
+Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
+  // The last visited value may be either operand.
+  const bool KnownIsRHS = I.getOperand(1) == LastVisited->first;
+  Value *OtherOp = I.getOperand(KnownIsRHS ? 0 : 1);
+
+  Constant *Other = dyn_cast<Constant>(OtherOp);
+  if (!Other)
+    Other = findConstantFor(OtherOp, KnownConstants);
+  if (!Other)
+    return nullptr;
+
+  // Keep the operands in their original order when simplifying.
+  Constant *Known = LastVisited->second;
+  Constant *LHS = KnownIsRHS ? Other : Known;
+  Constant *RHS = KnownIsRHS ? Known : Other;
+  Value *Simplified = simplifyBinOp(I.getOpcode(), LHS, RHS, SimplifyQuery(DL));
+  return dyn_cast_or_null<Constant>(Simplified);
+}
+
Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
CallInst *Call) {
Value *StoreValue = nullptr;
CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
for (BasicBlock &BB : *F)
Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
-
- LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
- << F->getName() << " is " << Metrics.NumInsts
- << " instructions\n");
}
return Metrics;
}
} else {
// Calculate the specialisation gain.
Cost Score = 0 - SpecCost;
+ InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
- Score += getSpecializationBonus(A.Formal, A.Actual);
+ Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
// Discard unprofitable specialisations.
if (!ForceSpecialization && Score <= 0)
// Otherwise, set the specialization cost to be the cost of all the
// instructions in the function.
- return Metrics.NumInsts * InlineConstants::getInstrCost();
-}
-
-static Cost getUserBonus(User *U, TargetTransformInfo &TTI,
- BlockFrequencyInfo &BFI) {
- auto *I = dyn_cast_or_null<Instruction>(U);
- // If not an instruction we do not know how to evaluate.
- // Keep minimum possible cost for now so that it doesnt affect
- // specialization.
- if (!I)
- return 0;
-
- uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() /
- BFI.getEntryFreq();
- if (!Weight)
- return 0;
-
- Cost Bonus = Weight *
- TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency);
-
- // Traverse recursively if there are more uses.
- // TODO: Any other instructions to be added here?
- if (I->mayReadFromMemory() || I->isCast())
- for (auto *User : I->users())
- Bonus += getUserBonus(User, TTI, BFI);
-
- return Bonus;
+ return Metrics.NumInsts;
}
/// Compute a bonus for replacing argument \p A with constant \p C.
-Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C) {
- Function *F = A->getParent();
- auto &TTI = (GetTTI)(*F);
- auto &BFI = (GetBFI)(*F);
+Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
+ InstCostVisitor &Visitor) {
LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
<< C->getNameOrAsOperand() << "\n");
Cost TotalCost = 0;
- for (auto *U : A->users()) {
- TotalCost += getUserBonus(U, TTI, BFI);
- LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
- TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
- }
+ for (auto *U : A->users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ if (Solver.isBlockExecutable(UI->getParent()))
+ TotalCost += Visitor.getUserBonus(UI, A, C);
+
+ LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
+ << TotalCost << " for argument " << *A << "\n");
// The below heuristic is only concerned with exposing inlining
// opportunities via indirect call promotion. If the argument is not a
LowerTypeTests.cpp
WholeProgramDevirt.cpp
AttributorTest.cpp
+ FunctionSpecializationTest.cpp
)
set_property(TARGET IPOTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
--- /dev/null
+//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/IPO/FunctionSpecialization.h"
+#include "llvm/Transforms/Utils/SCCPSolver.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+namespace llvm {
+
+class FunctionSpecializationTest : public testing::Test { // Fixture: parses IR, runs the solver on one function and builds a specializer.
+protected:
+ LLVMContext Ctx;
+ FunctionAnalysisManager FAM;
+ std::unique_ptr<Module> M;
+ std::unique_ptr<SCCPSolver> Solver;
+ // Register the function analyses that are made available through FAM.
+ FunctionSpecializationTest() {
+ FAM.registerPass([&] { return TargetLibraryAnalysis(); });
+ FAM.registerPass([&] { return TargetIRAnalysis(); });
+ FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
+ FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
+ FAM.registerPass([&] { return LoopAnalysis(); });
+ FAM.registerPass([&] { return AssumptionAnalysis(); });
+ FAM.registerPass([&] { return DominatorTreeAnalysis(); });
+ FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
+ FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+ }
+ // Parses \p ModuleString into M. NOTE(review): *M is dereferenced even if parsing failed.
+ Module &parseModule(const char *ModuleString) {
+ SMDiagnostic Err;
+ M = parseAssemblyString(ModuleString, Err, Ctx);
+ EXPECT_TRUE(M);
+ return *M;
+ }
+ // Runs the solver over \p F with overdefined arguments and returns a specializer built on it.
+ FunctionSpecializer getSpecializerFor(Function *F) {
+ auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
+ return FAM.getResult<TargetLibraryAnalysis>(F);
+ };
+ auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
+ return FAM.getResult<TargetIRAnalysis>(F);
+ };
+ auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
+ return FAM.getResult<BlockFrequencyAnalysis>(F);
+ };
+ auto GetAC = [this](Function &F) -> AssumptionCache & {
+ return FAM.getResult<AssumptionAnalysis>(F);
+ };
+ auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
+ DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+ return { std::make_unique<PredicateInfo>(F, DT,
+ FAM.getResult<AssumptionAnalysis>(F)),
+ &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) };
+ };
+ // The solver is owned by the fixture; the specializer below is constructed from *Solver.
+ Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
+ // Seed the solver: entry block executable, every argument overdefined.
+ Solver->addAnalysis(*F, GetAnalysis(*F));
+ Solver->markBlockExecutable(&F->front());
+ for (Argument &Arg : F->args())
+ Solver->markOverdefined(&Arg);
+ Solver->solveWhileResolvedUndefsIn(*M);
+
+ return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
+ GetAC);
+ }
+ // Reference cost of \p I: block frequency relative to the entry times its size-and-latency cost.
+ Cost getInstCost(Instruction &I) {
+ auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
+ auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
+
+ return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
+ TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+ }
+};
+
+} // namespace llvm
+
+using namespace llvm;
+
+TEST_F(FunctionSpecializationTest, SwitchInst) { // Bonus of each argument of a function that switches on %i.
+ const char *ModuleString = R"(
+ define void @foo(i32 %a, i32 %b, i32 %i) {
+ entry:
+ switch i32 %i, label %default
+ [ i32 1, label %case1
+ i32 2, label %case2 ]
+ case1:
+ %0 = mul i32 %a, 2
+ %1 = sub i32 6, 5
+ br label %bb1
+ case2:
+ %2 = and i32 %b, 3
+ %3 = sdiv i32 8, 2
+ br label %bb2
+ bb1:
+ %4 = add i32 %0, %b
+ br label %default
+ bb2:
+ %5 = or i32 %2, %a
+ br label %default
+ default:
+ ret void
+ }
+ )";
+
+ Module &M = parseModule(ModuleString);
+ Function *F = M.getFunction("foo");
+ FunctionSpecializer Specializer = getSpecializerFor(F);
+ InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+ // Every query below specializes its argument to the constant 1.
+ Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+ // Blocks after entry, in textual order: case1, case2, bb1, bb2.
+ auto FuncIter = F->begin();
+ BasicBlock &Case1 = *++FuncIter;
+ BasicBlock &Case2 = *++FuncIter;
+ BasicBlock &BB1 = *++FuncIter;
+ BasicBlock &BB2 = *++FuncIter;
+
+ Instruction &Mul = Case1.front();
+ Instruction &And = Case2.front();
+ Instruction &Sdiv = *++Case2.begin();
+ Instruction &BrBB2 = Case2.back();
+ Instruction &Add = BB1.front();
+ Instruction &Or = BB2.front();
+ Instruction &BrDefault = BB2.back();
+ // The same Visitor serves all three queries, so constants found earlier stay known; order matters.
+ // mul: %a == 1 folds only the mul; the add and the or each still have an unknown operand.
+ Cost Ref = getInstCost(Mul);
+ Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+
+ // and + or + add: %a and the mul are already known, so the or and the add fold as well.
+ Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
+ Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+
+ // sdiv + br + br: %i == 1 makes case2 and bb2 dead; their and/or are already known and not charged.
+ Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault);
+ Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+}
+
+TEST_F(FunctionSpecializationTest, BranchInst) { // Bonus of each argument of a function that branches on %cond.
+ const char *ModuleString = R"(
+ define void @foo(i32 %a, i32 %b, i1 %cond) {
+ entry:
+ br i1 %cond, label %bb0, label %bb2
+ bb0:
+ %0 = mul i32 %a, 2
+ %1 = sub i32 6, 5
+ br label %bb1
+ bb1:
+ %2 = add i32 %0, %b
+ %3 = sdiv i32 8, 2
+ br label %bb2
+ bb2:
+ ret void
+ }
+ )";
+
+ Module &M = parseModule(ModuleString);
+ Function *F = M.getFunction("foo");
+ FunctionSpecializer Specializer = getSpecializerFor(F);
+ InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+ // %a and %b are specialized to 1, %cond to false.
+ Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+ Constant *False = ConstantInt::getFalse(M.getContext());
+ // Blocks after entry, in textual order: bb0, bb1.
+ auto FuncIter = F->begin();
+ BasicBlock &BB0 = *++FuncIter;
+ BasicBlock &BB1 = *++FuncIter;
+
+ Instruction &Mul = BB0.front();
+ Instruction &Sub = *++BB0.begin();
+ Instruction &BrBB1 = BB0.back();
+ Instruction &Add = BB1.front();
+ Instruction &Sdiv = *++BB1.begin();
+ Instruction &BrBB2 = BB1.back();
+ // The same Visitor serves all three queries, so constants found earlier stay known; order matters.
+ // mul: %a == 1 folds only the mul; the add still needs %b.
+ Cost Ref = getInstCost(Mul);
+ Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+
+ // add: its other operand, the mul, is already known from the previous query.
+ Ref = getInstCost(Add);
+ Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+
+ // sub + br + sdiv + br: %cond == false makes bb0 and bb1 dead; the known mul/add are not charged.
+ Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
+ getInstCost(BrBB2);
+ Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+}
+
+TEST_F(FunctionSpecializationTest, Misc) { // Bonus through a chain of icmp, zext, select, gep and load.
+ const char *ModuleString = R"(
+ @g = constant [2 x i32] zeroinitializer, align 4
+
+ define i32 @foo(i8 %a, i1 %cond, ptr %b) {
+ %cmp = icmp eq i8 %a, 10
+ %ext = zext i1 %cmp to i32
+ %sel = select i1 %cond, i32 %ext, i32 1
+ %gep = getelementptr i32, ptr %b, i32 %sel
+ %ld = load i32, ptr %gep
+ ret i32 %ld
+ }
+ )";
+
+ Module &M = parseModule(ModuleString);
+ Function *F = M.getFunction("foo");
+ FunctionSpecializer Specializer = getSpecializerFor(F);
+ InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+ // %a is specialized to 1, %cond to true and %b to the constant global @g.
+ GlobalVariable *GV = M.getGlobalVariable("g");
+ Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
+ Constant *True = ConstantInt::getTrue(M.getContext());
+ // The five instructions preceding the ret, in order.
+ auto BlockIter = F->front().begin();
+ Instruction &Icmp = *BlockIter++;
+ Instruction &Zext = *BlockIter++;
+ Instruction &Select = *BlockIter++;
+ Instruction &Gep = *BlockIter++;
+ Instruction &Load = *BlockIter++;
+ // The same Visitor serves all three queries, so constants found earlier stay known; order matters.
+ // icmp + zext: the select is not folded because the visited value is not its condition.
+ Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
+ Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+
+ // select: picks %ext, already known from the previous query; the gep still lacks %b.
+ Ref = getInstCost(Select);
+ Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+
+ // gep + load: %b and %sel are both known, so the gep folds and the load is expected to fold too.
+ Ref = getInstCost(Gep) + getInstCost(Load);
+ Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
+ EXPECT_EQ(Bonus, Ref);
+}