/// the enclosing function's count (if available) and returns the value.
Optional<uint64_t> getBlockProfileCount(const BasicBlock *BB) const;
+ /// \brief Returns the estimated profile count of \p Freq.
+ /// This uses the frequency \p Freq and multiplies it by
+ /// the enclosing function's count (if available) and returns the value.
+ Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+
// Set the frequency of the given basic block.
void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
BlockFrequency getBlockFreq(const BlockNode &Node) const;
Optional<uint64_t> getBlockProfileCount(const Function &F,
const BlockNode &Node) const;
+ Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+ uint64_t Freq) const;
void setBlockFreq(const BlockNode &Node, uint64_t Freq);
const BlockT *BB) const {
return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
}
+ Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+ uint64_t Freq) const {
+ return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
+ }
void setBlockFreq(const BlockT *BB, uint64_t Freq);
Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
Optional<uint64_t> getBlockProfileCount(const MachineBasicBlock *MBB) const;
+ Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
const MachineFunction *getFunction() const;
const MachineBranchProbabilityInfo *getMBPI() const;
namespace llvm {
template <typename T> class ArrayRef;
class BasicBlock;
+ class BlockFrequency;
+ class BlockFrequencyInfo;
+ class BranchProbabilityInfo;
class DominatorTree;
class Function;
class Loop;
// Various bits of state computed on construction.
DominatorTree *const DT;
const bool AggregateArgs;
+ BlockFrequencyInfo *BFI;
+ BranchProbabilityInfo *BPI;
// Bits of intermediate state computed at various phases of extraction.
SetVector<BasicBlock *> Blocks;
///
/// In this formation, we don't require a dominator tree. The given basic
/// block is set up for extraction.
- CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
+ CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
+ BlockFrequencyInfo *BFI = nullptr,
+ BranchProbabilityInfo *BPI = nullptr);
/// \brief Create a code extractor for a sequence of blocks.
///
/// sequence out into its new function. When a DominatorTree is also given,
/// extra checking and transformations are enabled.
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
- bool AggregateArgs = false);
+ bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+ BranchProbabilityInfo *BPI = nullptr);
/// \brief Create a code extractor for a loop body.
///
/// Behaves just like the generic code sequence constructor, but uses the
/// block sequence of the loop.
- CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
+ CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
+ BlockFrequencyInfo *BFI = nullptr,
+ BranchProbabilityInfo *BPI = nullptr);
/// \brief Create a code extractor for a region node.
///
/// Behaves just like the generic code sequence constructor, but uses the
/// block sequence of the region node passed in.
CodeExtractor(DominatorTree &DT, const RegionNode &RN,
- bool AggregateArgs = false);
+ bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+ BranchProbabilityInfo *BPI = nullptr);
/// \brief Perform the extraction, returning the new function.
///
void moveCodeToFunction(Function *newFunction);
+ void calculateNewCallTerminatorWeights(
+ BasicBlock *CodeReplacer,
+ DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+ BranchProbabilityInfo *BPI);
+
void emitCallAndSwitchStatement(Function *newFunction,
BasicBlock *newHeader,
ValueSet &inputs,
return BFI->getBlockProfileCount(*getFunction(), BB);
}
+Optional<uint64_t>
+BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+ if (!BFI)
+ return None;
+ return BFI->getProfileCountFromFreq(*getFunction(), Freq);
+}
+
void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
assert(BFI && "Expected analysis to be available");
BFI->setBlockFreq(BB, Freq);
Optional<uint64_t>
BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
const BlockNode &Node) const {
+ return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
+}
+
+Optional<uint64_t>
+BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
+ uint64_t Freq) const {
auto EntryCount = F.getEntryCount();
if (!EntryCount)
return None;
// Use 128 bit APInt to do the arithmetic to avoid overflow.
APInt BlockCount(128, EntryCount.getValue());
- APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
+ APInt BlockFreq(128, Freq);
APInt EntryFreq(128, getEntryFreq());
BlockCount *= BlockFreq;
BlockCount = BlockCount.udiv(EntryFreq);
return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
}
+Optional<uint64_t>
+MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+ const Function *F = MBFI->getFunction()->getFunction();
+ return MBFI ? MBFI->getProfileCountFromFreq(*F, Freq) : None;
+}
+
const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
return MBFI ? MBFI->getFunction() : nullptr;
}
#include "llvm/Transforms/IPO/PartialInlining.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
DominatorTree DT;
DT.recalculate(*DuplicateFunction);
+ // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo.
+ LoopInfo LI(DT);
+ BranchProbabilityInfo BPI(*DuplicateFunction, LI);
+ BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI);
+
// Extract the body of the if.
Function *ExtractedFunction =
- CodeExtractor(ToExtract, &DT).extractCodeRegion();
+ CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI)
+ .extractCodeRegion();
// Inline the top-level if test into all callers.
std::vector<User *> Users(DuplicateFunction->user_begin(),
if (Recursive)
continue;
- if (Function *newFunc = unswitchFunction(CurrFunc)) {
- Worklist.push_back(newFunc);
+ if (Function *NewFunc = unswitchFunction(CurrFunc)) {
+ Worklist.push_back(NewFunc);
Changed = true;
}
}
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/RegionInfo.h"
#include "llvm/Analysis/RegionIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
return buildExtractionBlockSet(R.block_begin(), R.block_end());
}
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
- : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+ BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
- bool AggregateArgs)
- : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
-
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
- : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+ bool AggregateArgs, BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+ BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
+ NumExitBlocks(~0U) {}
CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
- bool AggregateArgs)
- : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+ bool AggregateArgs, BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
/// definedInRegion - Return true if the specified value is defined in the
/// extracted region.
}
}
+void CodeExtractor::calculateNewCallTerminatorWeights(
+ BasicBlock *CodeReplacer,
+ DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+ BranchProbabilityInfo *BPI) {
+ typedef BlockFrequencyInfoImplBase::Distribution Distribution;
+ typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
+
+ // Update the branch weights for the exit block.
+ TerminatorInst *TI = CodeReplacer->getTerminator();
+ SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+
+ // Block Frequency distribution with dummy node.
+ Distribution BranchDist;
+
+ // Add each of the frequencies of the successors.
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+ BlockNode ExitNode(i);
+ uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+ if (ExitFreq != 0)
+ BranchDist.addExit(ExitNode, ExitFreq);
+ else
+ BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
+ }
+
+ // Check for no total weight.
+ if (BranchDist.Total == 0)
+ return;
+
+ // Normalize the distribution so that they can fit in unsigned.
+ BranchDist.normalize();
+
+ // Create normalized branch weights and set the metadata.
+ for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+ const auto &Weight = BranchDist.Weights[I];
+
+ // Get the weight and update the current BFI.
+ BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+ BranchProbability BP(Weight.Amount, BranchDist.Total);
+ BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
+ }
+ TI->setMetadata(
+ LLVMContext::MD_prof,
+ MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
Function *CodeExtractor::extractCodeRegion() {
if (!isEligible())
return nullptr;
// block in the region.
BasicBlock *header = *Blocks.begin();
+ // Calculate the entry frequency of the new function before we change the root
+ // block.
+ BlockFrequency EntryFreq;
+ if (BFI) {
+ assert(BPI && "Both BPI and BFI are required to preserve profile info");
+ for (BasicBlock *Pred : predecessors(header)) {
+ if (Blocks.count(Pred))
+ continue;
+ EntryFreq +=
+ BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+ }
+ }
+
// If we have to split PHI nodes or the entry block, do so now.
severSplitPHINodes(header);
// Find inputs to, outputs from the code region.
findInputsOutputs(inputs, outputs);
+ // Calculate the exit blocks for the extracted region and the total exit
+ // weights for each of those blocks.
+ DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
SmallPtrSet<BasicBlock *, 1> ExitBlocks;
- for (BasicBlock *Block : Blocks)
+ for (BasicBlock *Block : Blocks) {
for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
- ++SI)
- if (!Blocks.count(*SI))
+ ++SI) {
+ if (!Blocks.count(*SI)) {
+ // Update the branch weight for this successor.
+ if (BFI) {
+ BlockFrequency &BF = ExitWeights[*SI];
+ BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+ }
ExitBlocks.insert(*SI);
+ }
+ }
+ }
NumExitBlocks = ExitBlocks.size();
// Construct new function based on inputs/outputs & add allocas for all defs.
codeReplacer, oldFunction,
oldFunction->getParent());
+ // Update the entry count of the function.
+ if (BFI) {
+ Optional<uint64_t> EntryCount =
+ BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+ if (EntryCount.hasValue())
+ newFunction->setEntryCount(EntryCount.getValue());
+ BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+ }
+
emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
moveCodeToFunction(newFunction);
+ // Update the branch weights for the exit block.
+ if (BFI && NumExitBlocks > 1)
+ calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
// Loop over all of the PHI nodes in the header block, and change any
// references to the old incoming edge to be the new incoming edge.
for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
--- /dev/null
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that the CodeExtractor
+; properly sets the entry count for the function that is
+; extracted based on the root block being extracted and also
+; takes into consideration if the block has edges coming from
+; a block that is also being extracted.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+ br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+ br i1 %cond, label %if.then, label %return, !prof !3
+return: ; preds = %entry
+ ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+ %val = call i32 @inlinedFunc(i1 %cond)
+ ret i32 %val
+}
+
+; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]]
+
+
+!llvm.module.flags = !{!0}
+; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250}
+!0 = !{i32 1, !"MaxFunctionCount", i32 1000}
+!1 = !{!"function_entry_count", i64 1000}
+!2 = !{!"branch_weights", i32 250, i32 750}
+!3 = !{!"branch_weights", i32 125, i32 125}
--- /dev/null
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that CodeExtractor updates
+; the exit branch probabilities for multiple exit blocks.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+ br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+ br i1 %cond, label %return, label %return.2, !prof !3
+return.2:
+ ret i32 10
+return: ; preds = %entry
+ ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+%val = call i32 @inlinedFunc(i1 %cond)
+ret i32 %val
+
+; CHECK-LABEL: @dummyCaller
+; CHECK: call
+; CHECK-NEXT: br i1 {{.*}}!prof [[COUNT1:![0-9]+]]
+}
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"MaxFunctionCount", i32 10000}
+!1 = !{!"function_entry_count", i64 10000}
+!2 = !{!"branch_weights", i32 5, i32 5}
+!3 = !{!"branch_weights", i32 4, i32 1}
+
+; CHECK: [[COUNT1]] = !{!"branch_weights", i32 8, i32 31}