[FuncSpec] Improve the accuracy of the cost model.
author: Alexandros Lamprineas <alexandros.lamprineas@arm.com>
Thu, 11 May 2023 23:07:49 +0000 (00:07 +0100)
committer: Alexandros Lamprineas <alexandros.lamprineas@arm.com>
Wed, 24 May 2023 10:40:12 +0000 (11:40 +0100)
Instead of blindly traversing the use-def chain of constant arguments,
compute known constants along the way. Stop as soon as a user cannot
be replaced by a constant. Keep it light-weight by handling some basic
instruction types.

Differential Revision: https://reviews.llvm.org/D150464

llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
llvm/unittests/Transforms/IPO/CMakeLists.txt
llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp [new file with mode: 0644]

index 02e73e2..349d5a7 100644 (file)
@@ -52,6 +52,7 @@
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/InstVisitor.h"
 #include "llvm/Transforms/Scalar/SCCP.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/SCCPSolver.h"
@@ -69,6 +70,9 @@ using SpecMap = DenseMap<Function *, std::pair<unsigned, unsigned>>;
 // Just a shorter abbreviation to improve indentation.
 using Cost = InstructionCost;
 
+// Map of known constants found during the specialization bonus estimation.
+using ConstMap = DenseMap<Value *, Constant *>;
+
 // Specialization signature, used to uniquely designate a specialization within
 // a function.
 struct SpecSig {
@@ -115,6 +119,39 @@ struct Spec {
       : F(F), Sig(S), Score(Score) {}
 };
 
+class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> { // Estimates the bonus of folding the users of a known constant.
+  const DataLayout &DL;     // Forwarded to the constant-folding helpers.
+  BlockFrequencyInfo &BFI;  // Block frequencies used to weight each bonus.
+  TargetTransformInfo &TTI; // Source of the per-instruction costs.
+  SCCPSolver &Solver;       // Queried for block executability.
+
+  ConstMap KnownConstants; // Values found to be constant so far, with their constants.
+
+  ConstMap::iterator LastVisited; // Entry for the (Use, C) pair most recently passed to getUserBonus().
+
+public:
+  InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI,
+                  TargetTransformInfo &TTI, SCCPSolver &Solver)
+      : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}
+
+  Cost getUserBonus(Instruction *User, Value *Use, Constant *C); // Bonus of replacing Use with C in User, followed through User's own users.
+
+private:
+  friend class InstVisitor<InstCostVisitor, Constant *>;
+
+  Cost estimateSwitchInst(SwitchInst &I); // Cost of the blocks made dead by a switch on a known constant.
+  Cost estimateBranchInst(BranchInst &I); // Cost of the blocks made dead by a branch on a known constant.
+
+  Constant *visitInstruction(Instruction &I) { return nullptr; } // Fallback: instructions without a dedicated visitor are not folded.
+  Constant *visitLoadInst(LoadInst &I);
+  Constant *visitGetElementPtrInst(GetElementPtrInst &I);
+  Constant *visitSelectInst(SelectInst &I);
+  Constant *visitCastInst(CastInst &I);
+  Constant *visitCmpInst(CmpInst &I);
+  Constant *visitUnaryOperator(UnaryOperator &I);
+  Constant *visitBinaryOperator(BinaryOperator &I);
+};
+
 class FunctionSpecializer {
 
   /// The IPSCCP Solver.
@@ -151,6 +188,16 @@ public:
 
   bool run();
 
+  /// Build a cost visitor for \p F from the module's data layout, the block
+  /// frequency and target transform information of \p F, and the solver.
+  InstCostVisitor getInstCostVisitorFor(Function *F) {
+    BlockFrequencyInfo &FreqInfo = (GetBFI)(*F);
+    TargetTransformInfo &TargetInfo = (GetTTI)(*F);
+    return {M.getDataLayout(), FreqInfo, TargetInfo, Solver};
+  }
+
+  /// Compute a bonus for replacing argument \p A with constant \p C, using \p Visitor to evaluate the users of \p A.
+  Cost getSpecializationBonus(Argument *A, Constant *C,
+                              InstCostVisitor &Visitor);
+
 private:
   Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call);
 
@@ -194,9 +241,6 @@ private:
   /// Compute and return the cost of specializing function \p F.
   Cost getSpecializationCost(Function *F);
 
-  /// Compute a bonus for replacing argument \p A with constant \p C.
-  Cost getSpecializationBonus(Argument *A, Constant *C);
-
   /// Determine if it is possible to specialise the function for constant values
   /// of the formal parameter \p A.
   bool isArgumentInteresting(Argument *A);
index 87cc0f6..a970253 100644 (file)
 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueLattice.h"
 #include "llvm/Analysis/ValueLatticeUtils.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Transforms/Scalar/SCCP.h"
 #include "llvm/Transforms/Utils/Cloning.h"
@@ -94,6 +97,210 @@ static cl::opt<bool> SpecializeLiteralConstant(
     "Enable specialization of functions that take a literal constant as an "
     "argument"));
 
+// Estimates the instruction cost of all the basic blocks in \p WorkList.
+// The successors of such blocks are added to the list as long as they are
+// executable and they have a unique predecessor. \p WorkList represents
+// the basic blocks of a specialization which become dead once we replace
+// instructions that are known to be constants. The aim here is to estimate
+// the combination of size and latency savings in comparison to the non
+// specialized version of the function.
+static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
+                                ConstMap &KnownConstants, SCCPSolver &Solver,
+                                BlockFrequencyInfo &BFI,
+                                TargetTransformInfo &TTI) {
+  Cost Bonus = 0;
+
+  // Accumulate the instruction cost of each basic block weighted by frequency.
+  while (!WorkList.empty()) {
+    BasicBlock *BB = WorkList.pop_back_val();
+
+    uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() /
+                      BFI.getEntryFreq();
+    if (!Weight)
+      continue;
+
+    for (Instruction &I : *BB) {
+      // Disregard SSA copies.
+      if (auto *II = dyn_cast<IntrinsicInst>(&I))
+        if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+          continue;
+      // If it's a known constant we have already accounted for it.
+      if (KnownConstants.contains(&I))
+        continue;
+
+      Bonus += Weight *
+          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+
+      LLVM_DEBUG(dbgs() << "FnSpecialization:     Bonus " << Bonus
+                        << " after user " << I << "\n");
+    }
+
+    // Keep adding dead successors to the list as long as they are
+    // executable and they have a unique predecessor.
+    for (BasicBlock *SuccBB : successors(BB))
+      if (Solver.isBlockExecutable(SuccBB) &&
+          SuccBB->getUniquePredecessor() == BB)
+        WorkList.push_back(SuccBB);
+  }
+  return Bonus;
+}
+
+// Returns the constant recorded for \p V in \p KnownConstants, or null when
+// there is no entry for \p V.
+static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
+  // lookup() yields a default-constructed (null) pointer for a missing key.
+  return KnownConstants.lookup(V);
+}
+
+Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
+  // Cache the iterator before visiting.
+  LastVisited = KnownConstants.insert({Use, C}).first;
+
+  if (auto *I = dyn_cast<SwitchInst>(User))
+    return estimateSwitchInst(*I);
+
+  if (auto *I = dyn_cast<BranchInst>(User))
+    return estimateBranchInst(*I);
+
+  C = visit(*User);
+  if (!C)
+    return 0;
+
+  KnownConstants.insert({User, C});
+
+  uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() /
+                    BFI.getEntryFreq();
+  if (!Weight)
+    return 0;
+
+  Cost Bonus = Weight *
+      TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency);
+
+  LLVM_DEBUG(dbgs() << "FnSpecialization:     Bonus " << Bonus
+                    << " for user " << *User << "\n");
+
+  for (auto *U : User->users())
+    if (auto *UI = dyn_cast<Instruction>(U))
+      if (Solver.isBlockExecutable(UI->getParent()))
+        Bonus += getUserBonus(UI, User, C);
+
+  return Bonus;
+}
+
+Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
+  if (I.getCondition() != LastVisited->first)
+    return 0;
+
+  auto *C = cast<ConstantInt>(LastVisited->second);
+  BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor();
+  // Initialize the worklist with the dead basic blocks. These are the
+  // destination labels which are different from the one corresponding
+  // to \p C. They should be executable and have a unique predecessor.
+  SmallVector<BasicBlock *> WorkList;
+  for (const auto &Case : I.cases()) {
+    BasicBlock *BB = Case.getCaseSuccessor();
+    if (BB == Succ || !Solver.isBlockExecutable(BB) ||
+        BB->getUniquePredecessor() != I.getParent())
+      continue;
+    WorkList.push_back(BB);
+  }
+
+  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+}
+
+// Estimates the bonus of resolving branch \p I when its condition is the
+// last visited value: the cost of the successor that is no longer taken,
+// provided it is executable and only reachable from this branch. Returns 0
+// when the condition is some other value.
+Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
+  if (I.getCondition() != LastVisited->first)
+    return 0;
+
+  // A condition equal to one selects successor 1 as the dead one; any other
+  // constant selects successor 0.
+  const bool CondIsOne = LastVisited->second->isOneValue();
+  BasicBlock *Dead = I.getSuccessor(CondIsOne ? 1 : 0);
+
+  SmallVector<BasicBlock *> WorkList;
+  const bool OnlyReachedFromHere = Dead->getUniquePredecessor() == I.getParent();
+  if (Solver.isBlockExecutable(Dead) && OnlyReachedFromHere)
+    WorkList.push_back(Dead);
+
+  return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
+}
+
+// Folds a load from the last visited constant. A null pointer constant is
+// never folded; otherwise the result of ConstantFoldLoadFromConstPtr() is
+// returned as is.
+Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
+  Constant *Ptr = LastVisited->second;
+  return isa<ConstantPointerNull>(Ptr)
+             ? nullptr
+             : ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL);
+}
+
+Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
+  SmallVector<Value *, 8> Operands;
+  Operands.reserve(I.getNumOperands());
+
+  for (unsigned Idx = 0, E = I.getNumOperands(); Idx != E; ++Idx) {
+    Value *V = I.getOperand(Idx);
+    auto *C = dyn_cast<Constant>(V);
+    if (!C)
+      C = findConstantFor(V, KnownConstants);
+    if (!C)
+      return nullptr;
+    Operands.push_back(C);
+  }
+
+  auto *Ptr = cast<Constant>(Operands[0]);
+  auto Ops = ArrayRef(Operands.begin() + 1, Operands.end());
+  return ConstantFoldGetElementPtr(I.getSourceElementType(), Ptr,
+                                   I.isInBounds(), std::nullopt, Ops);
+}
+
+Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
+  if (I.getCondition() != LastVisited->first)
+    return nullptr;
+
+  Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
+                                                : I.getTrueValue();
+  auto *C = dyn_cast<Constant>(V);
+  if (!C)
+    C = findConstantFor(V, KnownConstants);
+  return C;
+}
+
+// Folds a cast of the last visited constant to the type of \p I.
+Constant *InstCostVisitor::visitCastInst(CastInst &I) {
+  Constant *Operand = LastVisited->second;
+  Type *DestTy = I.getType();
+  return ConstantFoldCastOperand(I.getOpcode(), Operand, DestTy, DL);
+}
+
+Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
+  bool Swap = I.getOperand(1) == LastVisited->first;
+  Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
+  auto *Other = dyn_cast<Constant>(V);
+  if (!Other)
+    Other = findConstantFor(V, KnownConstants);
+
+  if (!Other)
+    return nullptr;
+
+  Constant *Const = LastVisited->second;
+  return Swap ?
+        ConstantFoldCompareInstOperands(I.getPredicate(), Other, Const, DL)
+      : ConstantFoldCompareInstOperands(I.getPredicate(), Const, Other, DL);
+}
+
+// Folds a unary operator applied to the last visited constant.
+Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
+  Constant *Operand = LastVisited->second;
+  return ConstantFoldUnaryOpOperand(I.getOpcode(), Operand, DL);
+}
+
+Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
+  bool Swap = I.getOperand(1) == LastVisited->first;
+  Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
+  auto *Other = dyn_cast<Constant>(V);
+  if (!Other)
+    Other = findConstantFor(V, KnownConstants);
+
+  if (!Other)
+    return nullptr;
+
+  Constant *Const = LastVisited->second;
+  return dyn_cast_or_null<Constant>(Swap ?
+        simplifyBinOp(I.getOpcode(), Other, Const, SimplifyQuery(DL))
+      : simplifyBinOp(I.getOpcode(), Const, Other, SimplifyQuery(DL)));
+}
+
 Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
                                                    CallInst *Call) {
   Value *StoreValue = nullptr;
@@ -412,10 +619,6 @@ CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) {
     CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
     for (BasicBlock &BB : *F)
       Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
-
-    LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
-                      << F->getName() << " is " << Metrics.NumInsts
-                      << " instructions\n");
   }
   return Metrics;
 }
@@ -496,8 +699,9 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
     } else {
       // Calculate the specialisation gain.
       Cost Score = 0 - SpecCost;
+      InstCostVisitor Visitor = getInstCostVisitorFor(F);
       for (ArgInfo &A : S.Args)
-        Score += getSpecializationBonus(A.Formal, A.Actual);
+        Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
 
       // Discard unprofitable specialisations.
       if (!ForceSpecialization && Score <= 0)
@@ -584,49 +788,23 @@ Cost FunctionSpecializer::getSpecializationCost(Function *F) {
 
   // Otherwise, set the specialization cost to be the cost of all the
   // instructions in the function.
-  return Metrics.NumInsts * InlineConstants::getInstrCost();
-}
-
-static Cost getUserBonus(User *U, TargetTransformInfo &TTI,
-                         BlockFrequencyInfo &BFI) {
-  auto *I = dyn_cast_or_null<Instruction>(U);
-  // If not an instruction we do not know how to evaluate.
-  // Keep minimum possible cost for now so that it doesnt affect
-  // specialization.
-  if (!I)
-    return 0;
-
-  uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() /
-                    BFI.getEntryFreq();
-  if (!Weight)
-    return 0;
-
-  Cost Bonus = Weight *
-      TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency);
-
-  // Traverse recursively if there are more uses.
-  // TODO: Any other instructions to be added here?
-  if (I->mayReadFromMemory() || I->isCast())
-    for (auto *User : I->users())
-      Bonus += getUserBonus(User, TTI, BFI);
-
-  return Bonus;
+  return Metrics.NumInsts;
 }
 
 /// Compute a bonus for replacing argument \p A with constant \p C.
-Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C) {
-  Function *F = A->getParent();
-  auto &TTI = (GetTTI)(*F);
-  auto &BFI = (GetBFI)(*F);
+Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
+                                                 InstCostVisitor &Visitor) {
   LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
                     << C->getNameOrAsOperand() << "\n");
 
   Cost TotalCost = 0;
-  for (auto *U : A->users()) {
-    TotalCost += getUserBonus(U, TTI, BFI);
-    LLVM_DEBUG(dbgs() << "FnSpecialization:   User cost ";
-               TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
-  }
+  for (auto *U : A->users())
+    if (auto *UI = dyn_cast<Instruction>(U))
+      if (Solver.isBlockExecutable(UI->getParent()))
+        TotalCost += Visitor.getUserBonus(UI, A, C);
+
+  LLVM_DEBUG(dbgs() << "FnSpecialization:   Accumulated user bonus "
+                    << TotalCost << " for argument " << *A << "\n");
 
   // The below heuristic is only concerned with exposing inlining
   // opportunities via indirect call promotion. If the argument is not a
index 3b16d81..4e43721 100644 (file)
@@ -12,6 +12,7 @@ add_llvm_unittest(IPOTests
   LowerTypeTests.cpp
   WholeProgramDevirt.cpp
   AttributorTest.cpp
+  FunctionSpecializationTest.cpp
   )
 
 set_property(TARGET IPOTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
new file mode 100644 (file)
index 0000000..16c9a50
--- /dev/null
@@ -0,0 +1,258 @@
+//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/IPO/FunctionSpecialization.h"
+#include "llvm/Transforms/Utils/SCCPSolver.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+namespace llvm {
+
+class FunctionSpecializationTest : public testing::Test { // Fixture that builds a FunctionSpecializer over IR parsed from a string.
+protected:
+  LLVMContext Ctx;
+  FunctionAnalysisManager FAM;
+  std::unique_ptr<Module> M; // Module produced by the last parseModule() call.
+  std::unique_ptr<SCCPSolver> Solver; // Recreated by every getSpecializerFor() call.
+
+  FunctionSpecializationTest() { // Registers the analyses requested through FAM below.
+    FAM.registerPass([&] { return TargetLibraryAnalysis(); });
+    FAM.registerPass([&] { return TargetIRAnalysis(); });
+    FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
+    FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
+    FAM.registerPass([&] { return LoopAnalysis(); });
+    FAM.registerPass([&] { return AssumptionAnalysis(); });
+    FAM.registerPass([&] { return DominatorTreeAnalysis(); });
+    FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
+    FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  }
+
+  Module &parseModule(const char *ModuleString) { // Parses the IR into M and returns a reference to it.
+    SMDiagnostic Err;
+    M = parseAssemblyString(ModuleString, Err, Ctx);
+    EXPECT_TRUE(M); // NOTE(review): non-fatal, so a parse failure still reaches the dereference below.
+    return *M;
+  }
+
+  FunctionSpecializer getSpecializerFor(Function *F) { // Runs a fresh solver over F, then wraps it in a specializer.
+    auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
+      return FAM.getResult<TargetLibraryAnalysis>(F);
+    };
+    auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
+      return FAM.getResult<TargetIRAnalysis>(F);
+    };
+    auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
+      return FAM.getResult<BlockFrequencyAnalysis>(F);
+    };
+    auto GetAC = [this](Function &F) -> AssumptionCache & {
+      return FAM.getResult<AssumptionAnalysis>(F);
+    };
+    auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
+      DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+      return { std::make_unique<PredicateInfo>(F, DT,
+                                FAM.getResult<AssumptionAnalysis>(F)),
+               &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) };
+    };
+
+    Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
+
+    Solver->addAnalysis(*F, GetAnalysis(*F));
+    Solver->markBlockExecutable(&F->front()); // Seed the solver with the entry block.
+    for (Argument &Arg : F->args())
+      Solver->markOverdefined(&Arg); // No argument is treated as a known constant by the solver.
+    Solver->solveWhileResolvedUndefsIn(*M);
+
+    return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
+                               GetAC);
+  }
+
+  Cost getInstCost(Instruction &I) { // Size-and-latency cost of I scaled by block frequency over entry frequency.
+    auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
+    auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
+
+    return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
+         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+  }
+};
+
+} // namespace llvm
+
+using namespace llvm;
+
+TEST_F(FunctionSpecializationTest, SwitchInst) { // Bonus for each argument of a function whose body is a switch.
+  const char *ModuleString = R"(
+    define void @foo(i32 %a, i32 %b, i32 %i) {
+    entry:
+      switch i32 %i, label %default
+      [ i32 1, label %case1
+        i32 2, label %case2 ]
+    case1:
+      %0 = mul i32 %a, 2
+      %1 = sub i32 6, 5
+      br label %bb1
+    case2:
+      %2 = and i32 %b, 3
+      %3 = sdiv i32 8, 2
+      br label %bb2
+    bb1:
+      %4 = add i32 %0, %b
+      br label %default
+    bb2:
+      %5 = or i32 %2, %a
+      br label %default
+    default:
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); // Shared by all three queries, so known constants carry over.
+
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+
+  auto FuncIter = F->begin(); // Starts at entry; each ++ below steps to the next block in layout order.
+  BasicBlock &Case1 = *++FuncIter;
+  BasicBlock &Case2 = *++FuncIter;
+  BasicBlock &BB1 = *++FuncIter;
+  BasicBlock &BB2 = *++FuncIter;
+
+  Instruction &Mul = Case1.front();
+  Instruction &And = Case2.front();
+  Instruction &Sdiv = *++Case2.begin();
+  Instruction &BrBB2 = Case2.back();
+  Instruction &Add = BB1.front();
+  Instruction &Or = BB2.front();
+  Instruction &BrDefault = BB2.back();
+
+  // %a = 1: only the mul folds; the add and the or each still have an unknown operand.
+  Cost Ref = getInstCost(Mul);
+  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+
+  // %b = 1: and, or and add all fold, because %a and the mul are still known from the previous query.
+  Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+
+  // %i = 1: case2 and bb2 become dead; the and/or are already known, leaving sdiv + br + br.
+  Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+}
+
+TEST_F(FunctionSpecializationTest, BranchInst) { // Bonus for each argument of a function guarded by a conditional branch.
+  const char *ModuleString = R"(
+    define void @foo(i32 %a, i32 %b, i1 %cond) {
+    entry:
+      br i1 %cond, label %bb0, label %bb2
+    bb0:
+      %0 = mul i32 %a, 2
+      %1 = sub i32 6, 5
+      br label %bb1
+    bb1:
+      %2 = add i32 %0, %b
+      %3 = sdiv i32 8, 2
+      br label %bb2
+    bb2:
+      ret void
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); // Shared by all three queries, so known constants carry over.
+
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+  Constant *False = ConstantInt::getFalse(M.getContext());
+
+  auto FuncIter = F->begin(); // Starts at entry; each ++ below steps to the next block in layout order.
+  BasicBlock &BB0 = *++FuncIter;
+  BasicBlock &BB1 = *++FuncIter;
+
+  Instruction &Mul = BB0.front();
+  Instruction &Sub = *++BB0.begin();
+  Instruction &BrBB1 = BB0.back();
+  Instruction &Add = BB1.front();
+  Instruction &Sdiv = *++BB1.begin();
+  Instruction &BrBB2 = BB1.back();
+
+  // %a = 1: only the mul folds; the add still has the unknown operand %b.
+  Cost Ref = getInstCost(Mul);
+  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+
+  // %b = 1: the add folds, because the mul is still known from the previous query.
+  Ref = getInstCost(Add);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+
+  // %cond = false: bb0 and bb1 become dead; mul/add are already known, leaving sub + br + sdiv + br.
+  Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
+        getInstCost(BrBB2);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+}
+
+TEST_F(FunctionSpecializationTest, Misc) { // Exercises the cmp, cast, select, GEP and load visitors.
+  const char *ModuleString = R"(
+    @g = constant [2 x i32] zeroinitializer, align 4
+
+    define i32 @foo(i8 %a, i1 %cond, ptr %b) {
+      %cmp = icmp eq i8 %a, 10
+      %ext = zext i1 %cmp to i32
+      %sel = select i1 %cond, i32 %ext, i32 1
+      %gep = getelementptr i32, ptr %b, i32 %sel
+      %ld = load i32, ptr %gep
+      ret i32 %ld
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); // Shared by all three queries, so known constants carry over.
+
+  GlobalVariable *GV = M.getGlobalVariable("g");
+  Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
+  Constant *True = ConstantInt::getTrue(M.getContext());
+
+  auto BlockIter = F->front().begin(); // Walks the single block's instructions in order.
+  Instruction &Icmp = *BlockIter++;
+  Instruction &Zext = *BlockIter++;
+  Instruction &Select = *BlockIter++;
+  Instruction &Gep = *BlockIter++;
+  Instruction &Load = *BlockIter++;
+
+  // %a = 1: icmp and zext fold; the select is not folded because %ext is not its condition.
+  Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
+  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+
+  // %cond = true: the select folds to %ext, known from the previous query; the gep still needs %b.
+  Ref = getInstCost(Select);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+
+  // %b = @g: the gep folds using the known %sel, and the load from the constant global folds after it.
+  Ref = getInstCost(Gep) + getInstCost(Load);
+  Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
+  EXPECT_EQ(Bonus, Ref);
+}