CodeExtractor : Add ability to preserve profile data.
authorSean Silva <chisophugis@gmail.com>
Tue, 2 Aug 2016 02:15:45 +0000 (02:15 +0000)
committerSean Silva <chisophugis@gmail.com>
Tue, 2 Aug 2016 02:15:45 +0000 (02:15 +0000)
Added ability to estimate the entry count of the extracted function and
the branch probabilities of the exit branches.

Patch by River Riddle!

Differential Revision: https://reviews.llvm.org/D22744

llvm-svn: 277411

llvm/include/llvm/Analysis/BlockFrequencyInfo.h
llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
llvm/include/llvm/CodeGen/MachineBlockFrequencyInfo.h
llvm/include/llvm/Transforms/Utils/CodeExtractor.h
llvm/lib/Analysis/BlockFrequencyInfo.cpp
llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp
llvm/lib/CodeGen/MachineBlockFrequencyInfo.cpp
llvm/lib/Transforms/IPO/PartialInlining.cpp
llvm/lib/Transforms/Utils/CodeExtractor.cpp
llvm/test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll [new file with mode: 0644]
llvm/test/Transforms/CodeExtractor/MultipleExitBranchProb.ll [new file with mode: 0644]

index 7d48dfc..d7e76a1 100644 (file)
@@ -61,6 +61,11 @@ public:
   /// the enclosing function's count (if available) and returns the value.
   Optional<uint64_t> getBlockProfileCount(const BasicBlock *BB) const;
 
+  /// \brief Returns the estimated profile count of \p Freq.
+  /// This scales \p Freq by the enclosing function's entry count relative to
+  /// the entry frequency; returns None if no analysis or count is available.
+  Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+
   // Set the frequency of the given basic block.
   void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
 
index 7ed06b1..8a0ced9 100644 (file)
@@ -482,6 +482,8 @@ public:
   BlockFrequency getBlockFreq(const BlockNode &Node) const;
   Optional<uint64_t> getBlockProfileCount(const Function &F,
                                           const BlockNode &Node) const;
+  Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+                                             uint64_t Freq) const;
 
   void setBlockFreq(const BlockNode &Node, uint64_t Freq);
 
@@ -925,6 +927,10 @@ public:
                                           const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
   }
+  /// Forwards to BlockFrequencyInfoImplBase::getProfileCountFromFreq.
+  Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+                                             uint64_t Freq) const {
+    return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
+  }
   void setBlockFreq(const BlockT *BB, uint64_t Freq);
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
index 7a23608..bfa5bf6 100644 (file)
@@ -52,6 +52,7 @@ public:
   BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
 
   Optional<uint64_t> getBlockProfileCount(const MachineBasicBlock *MBB) const;
+  Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
 
   const MachineFunction *getFunction() const;
   const MachineBranchProbabilityInfo *getMBPI() const;
index 970b3a1..a297866 100644 (file)
@@ -20,6 +20,9 @@
 namespace llvm {
 template <typename T> class ArrayRef;
   class BasicBlock;
+  class BlockFrequency;
+  class BlockFrequencyInfo;
+  class BranchProbabilityInfo;
   class DominatorTree;
   class Function;
   class Loop;
@@ -47,6 +50,8 @@ template <typename T> class ArrayRef;
     // Various bits of state computed on construction.
     DominatorTree *const DT;
     const bool AggregateArgs;
+    BlockFrequencyInfo *BFI;
+    BranchProbabilityInfo *BPI;
 
     // Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
@@ -64,7 +69,9 @@ template <typename T> class ArrayRef;
     ///
     /// In this formation, we don't require a dominator tree. The given basic
     /// block is set up for extraction.
-    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
+    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a sequence of blocks.
     ///
@@ -73,20 +80,24 @@ template <typename T> class ArrayRef;
     /// sequence out into its new function. When a DominatorTree is also given,
     /// extra checking and transformations are enabled.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a loop body.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the loop.
-    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
+    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a region node.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the region node passed in.
     CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Perform the extraction, returning the new function.
     ///
@@ -122,6 +133,11 @@ template <typename T> class ArrayRef;
 
     void moveCodeToFunction(Function *newFunction);
 
+    void calculateNewCallTerminatorWeights(
+        BasicBlock *CodeReplacer,
+        DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+        BranchProbabilityInfo *BPI);
+
     void emitCallAndSwitchStatement(Function *newFunction,
                                     BasicBlock *newHeader,
                                     ValueSet &inputs,
index 1dd8f4f..5f7060a 100644 (file)
@@ -162,6 +162,13 @@ BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB) const {
   return BFI->getBlockProfileCount(*getFunction(), BB);
 }
 
+/// Estimate the profile count that corresponds to the raw frequency \p Freq.
+/// Yields None when BFI is null; otherwise the result comes from the
+/// implementation object, evaluated against this analysis' function.
+Optional<uint64_t>
+BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  return BFI ? BFI->getProfileCountFromFreq(*getFunction(), Freq) : None;
+}
+
 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
   assert(BFI && "Expected analysis to be available");
   BFI->setBlockFreq(BB, Freq);
index c2039e1..9d3045c 100644 (file)
@@ -533,12 +533,18 @@ BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const {
 Optional<uint64_t>
 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
                                                  const BlockNode &Node) const {
+  // Share the count computation with getProfileCountFromFreq by passing it
+  // the frequency recorded for this node.
+  return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
+}
+
+Optional<uint64_t>
+BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
+                                                    uint64_t Freq) const {
   auto EntryCount = F.getEntryCount();
   if (!EntryCount)
     return None;
   // Use 128 bit APInt to do the arithmetic to avoid overflow.
   APInt BlockCount(128, EntryCount.getValue());
-  APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
+  APInt BlockFreq(128, Freq);
   APInt EntryFreq(128, getEntryFreq());
   BlockCount *= BlockFreq;
   BlockCount = BlockCount.udiv(EntryFreq);
index 6c0f99f..faf9ecc 100644 (file)
@@ -175,6 +175,12 @@ Optional<uint64_t> MachineBlockFrequencyInfo::getBlockProfileCount(
   return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
 }
 
+/// Returns the estimated profile count for the frequency \p Freq, or None
+/// when MBFI is null.
+Optional<uint64_t>
+MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  // MBFI must be tested before it is dereferenced to look up the function;
+  // otherwise the null guard can never take effect.
+  if (!MBFI)
+    return None;
+  const Function *F = MBFI->getFunction()->getFunction();
+  return MBFI->getProfileCountFromFreq(*F, Freq);
+}
+
 const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
   return MBFI ? MBFI->getFunction() : nullptr;
 }
index 6c762e4..7ef3fc1 100644 (file)
@@ -14,6 +14,9 @@
 
 #include "llvm/Transforms/IPO/PartialInlining.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
@@ -133,9 +136,15 @@ Function *PartialInlinerImpl::unswitchFunction(Function *F) {
   DominatorTree DT;
   DT.recalculate(*DuplicateFunction);
 
+  // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo.
+  LoopInfo LI(DT);
+  BranchProbabilityInfo BPI(*DuplicateFunction, LI);
+  BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI);
+
   // Extract the body of the if.
   Function *ExtractedFunction =
-      CodeExtractor(ToExtract, &DT).extractCodeRegion();
+      CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI)
+          .extractCodeRegion();
 
   // Inline the top-level if test into all callers.
   std::vector<User *> Users(DuplicateFunction->user_begin(),
@@ -181,8 +190,8 @@ bool PartialInlinerImpl::run(Module &M) {
     if (Recursive)
       continue;
 
-    if (Function *newFunc = unswitchFunction(CurrFunc)) {
-      Worklist.push_back(newFunc);
+    if (Function *NewFunc = unswitchFunction(CurrFunc)) {
+      Worklist.push_back(NewFunc);
       Changed = true;
     }
   }
index 8d0bc03..c514c9c 100644 (file)
@@ -17,6 +17,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/RegionIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -119,23 +124,30 @@ buildExtractionBlockSet(const RegionNode &RN) {
   return buildExtractionBlockSet(R.block_begin(), R.block_end());
 }
 
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
-  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
-                             bool AggregateArgs)
-  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
-
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
+      NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                             bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -687,6 +699,51 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
   }
 }
 
+void CodeExtractor::calculateNewCallTerminatorWeights(
+    BasicBlock *CodeReplacer,
+    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+    BranchProbabilityInfo *BPI) {
+  typedef BlockFrequencyInfoImplBase::Distribution Distribution;
+  typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
+
+  // Update the branch weights for the exit block.
+  TerminatorInst *TI = CodeReplacer->getTerminator();
+  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+
+  // Block Frequency distribution with dummy node.
+  Distribution BranchDist;
+
+  // Add each of the frequencies of the successors.
+  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+    BlockNode ExitNode(i);
+    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+    if (ExitFreq != 0)
+      BranchDist.addExit(ExitNode, ExitFreq);
+    else
+      BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
+  }
+
+  // Check for no total weight.
+  if (BranchDist.Total == 0)
+    return;
+
+  // Normalize the distribution so that they can fit in unsigned.
+  BranchDist.normalize();
+
+  // Create normalized branch weights and set the metadata.
+  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+    const auto &Weight = BranchDist.Weights[I];
+
+    // Get the weight and update the current BFI.
+    BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+    BranchProbability BP(Weight.Amount, BranchDist.Total);
+    BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
+  }
+  TI->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
@@ -697,6 +754,19 @@ Function *CodeExtractor::extractCodeRegion() {
   // block in the region.
   BasicBlock *header = *Blocks.begin();
 
+  // Calculate the entry frequency of the new function before we change the root
+  //   block.
+  BlockFrequency EntryFreq;
+  if (BFI) {
+    assert(BPI && "Both BPI and BFI are required to preserve profile info");
+    for (BasicBlock *Pred : predecessors(header)) {
+      if (Blocks.count(Pred))
+        continue;
+      EntryFreq +=
+          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+    }
+  }
+
   // If we have to split PHI nodes or the entry block, do so now.
   severSplitPHINodes(header);
 
@@ -720,12 +790,23 @@ Function *CodeExtractor::extractCodeRegion() {
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs);
 
+  // Calculate the exit blocks for the extracted region and the total exit
+  //  weights for each of those blocks.
+  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
   SmallPtrSet<BasicBlock *, 1> ExitBlocks;
-  for (BasicBlock *Block : Blocks)
+  for (BasicBlock *Block : Blocks) {
     for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
-         ++SI)
-      if (!Blocks.count(*SI))
+         ++SI) {
+      if (!Blocks.count(*SI)) {
+        // Update the branch weight for this successor.
+        if (BFI) {
+          BlockFrequency &BF = ExitWeights[*SI];
+          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+        }
         ExitBlocks.insert(*SI);
+      }
+    }
+  }
   NumExitBlocks = ExitBlocks.size();
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
@@ -734,10 +815,23 @@ Function *CodeExtractor::extractCodeRegion() {
                                             codeReplacer, oldFunction,
                                             oldFunction->getParent());
 
+  // Update the entry count of the function.
+  if (BFI) {
+    Optional<uint64_t> EntryCount =
+        BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+    if (EntryCount.hasValue())
+      newFunction->setEntryCount(EntryCount.getValue());
+    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+  }
+
   emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
 
   moveCodeToFunction(newFunction);
 
+  // Update the branch weights for the exit block.
+  if (BFI && NumExitBlocks > 1)
+    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
   // Loop over all of the PHI nodes in the header block, and change any
   // references to the old incoming edge to be the new incoming edge.
   for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
diff --git a/llvm/test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll b/llvm/test/Transforms/CodeExtractor/ExtractedFnEntryCount.ll
new file mode 100644 (file)
index 0000000..509a4d7
--- /dev/null
@@ -0,0 +1,33 @@
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that the CodeExtractor
+;  properly sets the entry count for the function that is
+;  extracted based on the root block being extracted, and that
+;  it ignores edges into the root block that come from a block
+;  that is also being extracted (here, the if.then self-loop).
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+  br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+  br i1 %cond, label %if.then, label %return, !prof !3
+return:             ; preds = %if.then, %entry
+  ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+  %val = call i32 @inlinedFunc(i1 %cond)
+  ret i32 %val
+}
+
+; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]]
+
+
+!llvm.module.flags = !{!0}
+; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250}
+!0 = !{i32 1, !"MaxFunctionCount", i32 1000}
+!1 = !{!"function_entry_count", i64 1000}
+!2 = !{!"branch_weights", i32 250, i32 750}
+!3 = !{!"branch_weights", i32 125, i32 125}
diff --git a/llvm/test/Transforms/CodeExtractor/MultipleExitBranchProb.ll b/llvm/test/Transforms/CodeExtractor/MultipleExitBranchProb.ll
new file mode 100644 (file)
index 0000000..e37b7e6
--- /dev/null
@@ -0,0 +1,34 @@
+; RUN: opt < %s -partial-inliner -S | FileCheck %s
+
+; This test checks to make sure that CodeExtractor updates
+;  the exit branch probabilities for multiple exit blocks.
+
+define i32 @inlinedFunc(i1 %cond) !prof !1 {
+entry:
+  br i1 %cond, label %if.then, label %return, !prof !2
+if.then:
+  br i1 %cond, label %return, label %return.2, !prof !3
+return.2:
+  ret i32 10
+return:             ; preds = %if.then, %entry
+  ret i32 0
+}
+
+
+define internal i32 @dummyCaller(i1 %cond) !prof !1 {
+entry:
+%val = call i32 @inlinedFunc(i1 %cond)
+ret i32 %val
+
+; CHECK-LABEL: @dummyCaller
+; CHECK: call
+; CHECK-NEXT: br i1 {{.*}}!prof [[COUNT1:![0-9]+]]
+}
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"MaxFunctionCount", i32 10000}
+!1 = !{!"function_entry_count", i64 10000}
+!2 = !{!"branch_weights", i32 5, i32 5}
+!3 = !{!"branch_weights", i32 4, i32 1}
+
+; CHECK: [[COUNT1]] = !{!"branch_weights", i32 8, i32 31}