[SampleFDO][NFC] Refactor SampleProfileLoader to reuse in CodeGen
authorRong Xu <xur@google.com>
Wed, 10 Feb 2021 03:43:26 +0000 (19:43 -0800)
committerRong Xu <xur@google.com>
Wed, 10 Feb 2021 21:29:15 +0000 (13:29 -0800)
Break SampleProfileLoader into to a base and a derived class.
Base class (SampleProfileLoaderBaseImpl) includes the common
code for IR and MachineIR (CodeGen) sample loader.
It will be templatelized in the later patch.

Inline and Probe related code will remain in the derived class of
SampleProfileLoader and stays in SampleProfile.cpp.

We need to refactor some functions:
(1) getInstWeight() to enable the code sharing -- put the core into
getInstWeightImpl().
(2) emitAnnotation() and propagateWeights() to carve out the code
specific to SampleProfileLoader.
(3) make getInstWeight() and findFunctionSamples() virtual and override
in SampleProfileLoader as they need to access the fields in the derived
class.

Differential Revision: https://reviews.llvm.org/D95832

llvm/lib/Transforms/IPO/SampleProfile.cpp

index 03e6efd..22ef3e5 100644 (file)
@@ -234,8 +234,6 @@ using EdgeWeightMap = DenseMap<Edge, uint64_t>;
 using BlockEdgeMap =
     DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>>;
 
-class SampleProfileLoader;
-
 class SampleCoverageTracker {
 public:
   bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset,
@@ -389,63 +387,22 @@ using CandidateQueue =
     PriorityQueue<InlineCandidate, std::vector<InlineCandidate>,
                   CandidateComparer>;
 
-/// Sample profile pass.
-///
-/// This pass reads profile data from the file specified by
-/// -sample-profile-file and annotates every affected function with the
-/// profile information found in that file.
-class SampleProfileLoader {
+class SampleProfileLoaderBaseImpl {
 public:
-  SampleProfileLoader(
-      StringRef Name, StringRef RemapName, ThinOrFullLTOPhase LTOPhase,
-      std::function<AssumptionCache &(Function &)> GetAssumptionCache,
-      std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo,
-      std::function<const TargetLibraryInfo &(Function &)> GetTLI)
-      : GetAC(std::move(GetAssumptionCache)),
-        GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)),
-        Filename(std::string(Name)), RemappingFilename(std::string(RemapName)),
-        LTOPhase(LTOPhase) {}
-
-  bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr);
-  bool runOnModule(Module &M, ModuleAnalysisManager *AM,
-                   ProfileSummaryInfo *_PSI, CallGraph *CG);
-
+  SampleProfileLoaderBaseImpl(std::string Name) : Filename(Name) {}
   void dump() { Reader->dump(); }
 
 protected:
   friend class SampleCoverageTracker;
 
-  bool runOnFunction(Function &F, ModuleAnalysisManager *AM);
   unsigned getFunctionLoc(Function &F);
-  bool emitAnnotations(Function &F);
-  ErrorOr<uint64_t> getInstWeight(const Instruction &I);
-  ErrorOr<uint64_t> getProbeWeight(const Instruction &I);
+  virtual ErrorOr<uint64_t> getInstWeight(const Instruction &Inst);
+  ErrorOr<uint64_t> getInstWeightImpl(const Instruction &Inst);
   ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB);
-  const FunctionSamples *findCalleeFunctionSamples(const CallBase &I) const;
-  std::vector<const FunctionSamples *>
-  findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const;
-  mutable DenseMap<const DILocation *, const FunctionSamples *> DILocation2SampleMap;
-  const FunctionSamples *findFunctionSamples(const Instruction &I) const;
-  // Attempt to promote indirect call and also inline the promoted call
-  bool tryPromoteAndInlineCandidate(
-      Function &F, InlineCandidate &Candidate, uint64_t SumOrigin,
-      uint64_t &Sum, DenseSet<Instruction *> &PromotedInsns,
-      SmallVector<CallBase *, 8> *InlinedCallSites = nullptr);
-  bool inlineHotFunctions(Function &F,
-                          DenseSet<GlobalValue::GUID> &InlinedGUIDs);
-  InlineCost shouldInlineCandidate(InlineCandidate &Candidate);
-  bool getInlineCandidate(InlineCandidate *NewCandidate, CallBase *CB);
-  bool
-  tryInlineCandidate(InlineCandidate &Candidate,
-                     SmallVector<CallBase *, 8> *InlinedCallSites = nullptr);
-  bool
-  inlineHotFunctionsWithPriority(Function &F,
-                                 DenseSet<GlobalValue::GUID> &InlinedGUIDs);
-  // Inline cold/small functions in addition to hot ones
-  bool shouldInlineColdCallee(CallBase &CallInst);
-  void emitOptimizationRemarksForInlineCandidates(
-      const SmallVectorImpl<CallBase *> &Candidates, const Function &F,
-      bool Hot);
+  mutable DenseMap<const DILocation *, const FunctionSamples *>
+      DILocation2SampleMap;
+  virtual const FunctionSamples *
+  findFunctionSamples(const Instruction &I) const;
   void printEdgeWeight(raw_ostream &OS, Edge E);
   void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const;
   void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB);
@@ -458,10 +415,13 @@ protected:
   void propagateWeights(Function &F);
   uint64_t visitEdge(Edge E, unsigned *NumUnknownEdges, Edge *UnknownEdge);
   void buildEdges(Function &F);
-  std::vector<Function *> buildFunctionOrder(Module &M, CallGraph *CG);
   bool propagateThroughEdges(Function &F, bool UpdateBlockCount);
-  void computeDominanceAndLoopInfo(Function &F);
   void clearFunctionData();
+  void computeDominanceAndLoopInfo(Function &F);
+  bool
+  computeAndPropagateWeights(Function &F,
+                             const DenseSet<GlobalValue::GUID> &InlinedGUIDs);
+  void emitCoverageRemarks(Function &F);
 
   /// Map basic blocks to their computed weights.
   ///
@@ -489,21 +449,11 @@ protected:
   /// the same number of times.
   EquivalenceClassMap EquivalenceClass;
 
-  /// Map from function name to Function *. Used to find the function from
-  /// the function name. If the function name contains suffix, additional
-  /// entry is added to map from the stripped name to the function if there
-  /// is one-to-one mapping.
-  StringMap<Function *> SymbolMap;
-
   /// Dominance, post-dominance and loop information.
   std::unique_ptr<DominatorTree> DT;
   std::unique_ptr<PostDominatorTree> PDT;
   std::unique_ptr<LoopInfo> LI;
 
-  std::function<AssumptionCache &(Function &)> GetAC;
-  std::function<TargetTransformInfo &(Function &)> GetTTI;
-  std::function<const TargetLibraryInfo &(Function &)> GetTLI;
-
   /// Predecessors for each basic block in the CFG.
   BlockEdgeMap Predecessors;
 
@@ -515,15 +465,86 @@ protected:
   /// Profile reader object.
   std::unique_ptr<SampleProfileReader> Reader;
 
-  /// Profile tracker for different context.
-  std::unique_ptr<SampleContextTracker> ContextTracker;
-
   /// Samples collected for the body of this function.
   FunctionSamples *Samples = nullptr;
 
   /// Name of the profile file to load.
   std::string Filename;
 
+  /// Profile Summary Info computed from sample profile.
+  ProfileSummaryInfo *PSI = nullptr;
+
+  /// Optimization Remark Emitter used to emit diagnostic remarks.
+  OptimizationRemarkEmitter *ORE = nullptr;
+};
+
+/// Sample profile pass.
+///
+/// This pass reads profile data from the file specified by
+/// -sample-profile-file and annotates every affected function with the
+/// profile information found in that file.
+class SampleProfileLoader : public SampleProfileLoaderBaseImpl {
+public:
+  SampleProfileLoader(
+      StringRef Name, StringRef RemapName, ThinOrFullLTOPhase LTOPhase,
+      std::function<AssumptionCache &(Function &)> GetAssumptionCache,
+      std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo,
+      std::function<const TargetLibraryInfo &(Function &)> GetTLI)
+      : SampleProfileLoaderBaseImpl(std::string(Name)),
+        GetAC(std::move(GetAssumptionCache)),
+        GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)),
+        RemappingFilename(std::string(RemapName)), LTOPhase(LTOPhase) {}
+
+  bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr);
+  bool runOnModule(Module &M, ModuleAnalysisManager *AM,
+                   ProfileSummaryInfo *_PSI, CallGraph *CG);
+
+protected:
+  bool runOnFunction(Function &F, ModuleAnalysisManager *AM);
+  bool emitAnnotations(Function &F);
+  ErrorOr<uint64_t> getInstWeight(const Instruction &I) override;
+  ErrorOr<uint64_t> getProbeWeight(const Instruction &I);
+  const FunctionSamples *findCalleeFunctionSamples(const CallBase &I) const;
+  const FunctionSamples *
+  findFunctionSamples(const Instruction &I) const override;
+  std::vector<const FunctionSamples *>
+  findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const;
+  // Attempt to promote indirect call and also inline the promoted call
+  bool tryPromoteAndInlineCandidate(
+      Function &F, InlineCandidate &Candidate, uint64_t SumOrigin,
+      uint64_t &Sum, DenseSet<Instruction *> &PromotedInsns,
+      SmallVector<CallBase *, 8> *InlinedCallSites = nullptr);
+  bool inlineHotFunctions(Function &F,
+                          DenseSet<GlobalValue::GUID> &InlinedGUIDs);
+  InlineCost shouldInlineCandidate(InlineCandidate &Candidate);
+  bool getInlineCandidate(InlineCandidate *NewCandidate, CallBase *CB);
+  bool
+  tryInlineCandidate(InlineCandidate &Candidate,
+                     SmallVector<CallBase *, 8> *InlinedCallSites = nullptr);
+  bool
+  inlineHotFunctionsWithPriority(Function &F,
+                                 DenseSet<GlobalValue::GUID> &InlinedGUIDs);
+  // Inline cold/small functions in addition to hot ones
+  bool shouldInlineColdCallee(CallBase &CallInst);
+  void emitOptimizationRemarksForInlineCandidates(
+      const SmallVectorImpl<CallBase *> &Candidates, const Function &F,
+      bool Hot);
+  std::vector<Function *> buildFunctionOrder(Module &M, CallGraph *CG);
+  void generateMDProfMetadata(Function &F);
+
+  /// Map from function name to Function *. Used to find the function from
+  /// the function name. If the function name contains suffix, additional
+  /// entry is added to map from the stripped name to the function if there
+  /// is one-to-one mapping.
+  StringMap<Function *> SymbolMap;
+
+  std::function<AssumptionCache &(Function &)> GetAC;
+  std::function<TargetTransformInfo &(Function &)> GetTTI;
+  std::function<const TargetLibraryInfo &(Function &)> GetTLI;
+
+  /// Profile tracker for different context.
+  std::unique_ptr<SampleContextTracker> ContextTracker;
+
   /// Name of the profile remapping file to load.
   std::string RemappingFilename;
 
@@ -540,9 +561,6 @@ protected:
   /// we will mark GUIDs that needs to be annotated to the function.
   ThinOrFullLTOPhase LTOPhase;
 
-  /// Profile Summary Info computed from sample profile.
-  ProfileSummaryInfo *PSI = nullptr;
-
   /// Profle Symbol list tells whether a function name appears in the binary
   /// used to generate the current profile.
   std::unique_ptr<ProfileSymbolList> PSL;
@@ -553,9 +571,6 @@ protected:
   /// at runtime.
   uint64_t TotalCollectedSamples = 0;
 
-  /// Optimization Remark Emitter used to emit diagnostic remarks.
-  OptimizationRemarkEmitter *ORE = nullptr;
-
   // Information recorded when we declined to inline a call site
   // because we have determined it is too cold is accumulated for
   // each callee function. Initially this is just the entry count.
@@ -758,7 +773,7 @@ unsigned SampleCoverageTracker::computeCoverage(unsigned Used,
 }
 
 /// Clear all the per-function data used to load samples and propagate weights.
-void SampleProfileLoader::clearFunctionData() {
+void SampleProfileLoaderBaseImpl::clearFunctionData() {
   BlockWeights.clear();
   EdgeWeights.clear();
   VisitedBlocks.clear();
@@ -777,7 +792,7 @@ void SampleProfileLoader::clearFunctionData() {
 ///
 /// \param OS  Stream to emit the output to.
 /// \param E  Edge to print.
-void SampleProfileLoader::printEdgeWeight(raw_ostream &OS, Edge E) {
+void SampleProfileLoaderBaseImpl::printEdgeWeight(raw_ostream &OS, Edge E) {
   OS << "weight[" << E.first->getName() << "->" << E.second->getName()
      << "]: " << EdgeWeights[E] << "\n";
 }
@@ -786,8 +801,8 @@ void SampleProfileLoader::printEdgeWeight(raw_ostream &OS, Edge E) {
 ///
 /// \param OS  Stream to emit the output to.
 /// \param BB  Block to print.
-void SampleProfileLoader::printBlockEquivalence(raw_ostream &OS,
-                                                const BasicBlock *BB) {
+void SampleProfileLoaderBaseImpl::printBlockEquivalence(raw_ostream &OS,
+                                                        const BasicBlock *BB) {
   const BasicBlock *Equiv = EquivalenceClass[BB];
   OS << "equivalence[" << BB->getName()
      << "]: " << ((Equiv) ? EquivalenceClass[BB]->getName() : "NONE") << "\n";
@@ -797,8 +812,8 @@ void SampleProfileLoader::printBlockEquivalence(raw_ostream &OS,
 ///
 /// \param OS  Stream to emit the output to.
 /// \param BB  Block to print.
-void SampleProfileLoader::printBlockWeight(raw_ostream &OS,
-                                           const BasicBlock *BB) const {
+void SampleProfileLoaderBaseImpl::printBlockWeight(raw_ostream &OS,
+                                                   const BasicBlock *BB) const {
   const auto &I = BlockWeights.find(BB);
   uint64_t W = (I == BlockWeights.end() ? 0 : I->second);
   OS << "weight[" << BB->getName() << "]: " << W << "\n";
@@ -816,6 +831,11 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS,
 /// \param Inst Instruction to query.
 ///
 /// \returns the weight of \p Inst.
+ErrorOr<uint64_t>
+SampleProfileLoaderBaseImpl::getInstWeight(const Instruction &Inst) {
+  return getInstWeightImpl(Inst);
+}
+
 ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
   if (FunctionSamples::ProfileIsProbeBased)
     return getProbeWeight(Inst);
@@ -824,13 +844,9 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
   if (!DLoc)
     return std::error_code();
 
-  const FunctionSamples *FS = findFunctionSamples(Inst);
-  if (!FS)
-    return std::error_code();
-
   // Ignore all intrinsics, phinodes and branch instructions.
-  // Branch and phinodes instruction usually contains debug info from sources outside of
-  // the residing basic block, thus we ignore them during annotation.
+  // Branch and phinodes instruction usually contains debug info from sources
+  // outside of the residing basic block, thus we ignore them during annotation.
   if (isa<BranchInst>(Inst) || isa<IntrinsicInst>(Inst) || isa<PHINode>(Inst))
     return std::error_code();
 
@@ -843,6 +859,19 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
       if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB))
         return 0;
 
+  return getInstWeightImpl(Inst);
+}
+
+ErrorOr<uint64_t>
+SampleProfileLoaderBaseImpl::getInstWeightImpl(const Instruction &Inst) {
+  const FunctionSamples *FS = findFunctionSamples(Inst);
+  if (!FS)
+    return std::error_code();
+
+  const DebugLoc &DLoc = Inst.getDebugLoc();
+  if (!DLoc)
+    return std::error_code();
+
   const DILocation *DIL = DLoc;
   uint32_t LineOffset = FunctionSamples::getOffset(DIL);
   uint32_t Discriminator = DIL->getBaseDiscriminator();
@@ -926,7 +955,8 @@ ErrorOr<uint64_t> SampleProfileLoader::getProbeWeight(const Instruction &Inst) {
 /// \param BB The basic block to query.
 ///
 /// \returns the weight for \p BB.
-ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) {
+ErrorOr<uint64_t>
+SampleProfileLoaderBaseImpl::getBlockWeight(const BasicBlock *BB) {
   uint64_t Max = 0;
   bool HasWeight = false;
   for (auto &I : BB->getInstList()) {
@@ -945,7 +975,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) {
 /// the weights of every basic block in the CFG.
 ///
 /// \param F The function to query.
-bool SampleProfileLoader::computeBlockWeights(Function &F) {
+bool SampleProfileLoaderBaseImpl::computeBlockWeights(Function &F) {
   bool Changed = false;
   LLVM_DEBUG(dbgs() << "Block weights\n");
   for (const auto &BB : F) {
@@ -1064,6 +1094,19 @@ SampleProfileLoader::findIndirectCallFunctionSamples(
 /// \param Inst Instruction to query.
 ///
 /// \returns the FunctionSamples pointer to the inlined instance.
+const FunctionSamples *SampleProfileLoaderBaseImpl::findFunctionSamples(
+    const Instruction &Inst) const {
+  const DILocation *DIL = Inst.getDebugLoc();
+  if (!DIL)
+    return Samples;
+
+  auto it = DILocation2SampleMap.try_emplace(DIL, nullptr);
+  if (it.second) {
+    it.first->second = Samples->findFunctionSamples(DIL, Reader->getRemapper());
+  }
+  return it.first->second;
+}
+
 const FunctionSamples *
 SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const {
   if (FunctionSamples::ProfileIsProbeBased) {
@@ -1623,7 +1666,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority(
 ///                 with blocks from \p BB1's dominator tree, then
 ///                 this is the post-dominator tree, and vice versa.
 template <bool IsPostDom>
-void SampleProfileLoader::findEquivalencesFor(
+void SampleProfileLoaderBaseImpl::findEquivalencesFor(
     BasicBlock *BB1, ArrayRef<BasicBlock *> Descendants,
     DominatorTreeBase<BasicBlock, IsPostDom> *DomTree) {
   const BasicBlock *EC = EquivalenceClass[BB1];
@@ -1665,7 +1708,7 @@ void SampleProfileLoader::findEquivalencesFor(
 /// dominates B2, B2 post-dominates B1 and both are in the same loop.
 ///
 /// \param F The function to query.
-void SampleProfileLoader::findEquivalenceClasses(Function &F) {
+void SampleProfileLoaderBaseImpl::findEquivalenceClasses(Function &F) {
   SmallVector<BasicBlock *, 8> DominatedBBs;
   LLVM_DEBUG(dbgs() << "\nBlock equivalence classes\n");
   // Find equivalence sets based on dominance and post-dominance information.
@@ -1725,8 +1768,9 @@ void SampleProfileLoader::findEquivalenceClasses(Function &F) {
 /// \param UnknownEdge  Set if E has not been visited before.
 ///
 /// \returns E's weight, if known. Otherwise, return 0.
-uint64_t SampleProfileLoader::visitEdge(Edge E, unsigned *NumUnknownEdges,
-                                        Edge *UnknownEdge) {
+uint64_t SampleProfileLoaderBaseImpl::visitEdge(Edge E,
+                                                unsigned *NumUnknownEdges,
+                                                Edge *UnknownEdge) {
   if (!VisitedEdges.count(E)) {
     (*NumUnknownEdges)++;
     *UnknownEdge = E;
@@ -1749,8 +1793,8 @@ uint64_t SampleProfileLoader::visitEdge(Edge E, unsigned *NumUnknownEdges,
 ///                          has already been annotated.
 ///
 /// \returns  True if new weights were assigned to edges or blocks.
-bool SampleProfileLoader::propagateThroughEdges(Function &F,
-                                                bool UpdateBlockCount) {
+bool SampleProfileLoaderBaseImpl::propagateThroughEdges(Function &F,
+                                                        bool UpdateBlockCount) {
   bool Changed = false;
   LLVM_DEBUG(dbgs() << "\nPropagation through edges\n");
   for (const auto &BI : F) {
@@ -1898,7 +1942,7 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F,
 ///
 /// We are interested in unique edges. If a block B1 has multiple
 /// edges to another block B2, we only add a single B1->B2 edge.
-void SampleProfileLoader::buildEdges(Function &F) {
+void SampleProfileLoaderBaseImpl::buildEdges(Function &F) {
   for (auto &BI : F) {
     BasicBlock *B1 = &BI;
 
@@ -1947,7 +1991,7 @@ static SmallVector<InstrProfValueData, 2> GetSortedValueDataFromCallTargets(
 ///   known, the weight for that edge is set to the weight of the block
 ///   minus the weight of the other incoming edges to that block (if
 ///   known).
-void SampleProfileLoader::propagateWeights(Function &F) {
+void SampleProfileLoaderBaseImpl::propagateWeights(Function &F) {
   bool Changed = true;
   unsigned I = 0;
 
@@ -1992,7 +2036,116 @@ void SampleProfileLoader::propagateWeights(Function &F) {
   while (Changed && I++ < SampleProfileMaxPropagateIterations) {
     Changed = propagateThroughEdges(F, true);
   }
+}
 
+/// Generate branch weight metadata for all branches in \p F.
+///
+/// Branch weights are computed out of instruction samples using a
+/// propagation heuristic. Propagation proceeds in 3 phases:
+///
+/// 1- Assignment of block weights. All the basic blocks in the function
+///    are initial assigned the same weight as their most frequently
+///    executed instruction.
+///
+/// 2- Creation of equivalence classes. Since samples may be missing from
+///    blocks, we can fill in the gaps by setting the weights of all the
+///    blocks in the same equivalence class to the same weight. To compute
+///    the concept of equivalence, we use dominance and loop information.
+///    Two blocks B1 and B2 are in the same equivalence class if B1
+///    dominates B2, B2 post-dominates B1 and both are in the same loop.
+///
+/// 3- Propagation of block weights into edges. This uses a simple
+///    propagation heuristic. The following rules are applied to every
+///    block BB in the CFG:
+///
+///    - If BB has a single predecessor/successor, then the weight
+///      of that edge is the weight of the block.
+///
+///    - If all the edges are known except one, and the weight of the
+///      block is already known, the weight of the unknown edge will
+///      be the weight of the block minus the sum of all the known
+///      edges. If the sum of all the known edges is larger than BB's weight,
+///      we set the unknown edge weight to zero.
+///
+///    - If there is a self-referential edge, and the weight of the block is
+///      known, the weight for that edge is set to the weight of the block
+///      minus the weight of the other incoming edges to that block (if
+///      known).
+///
+/// Since this propagation is not guaranteed to finalize for every CFG, we
+/// only allow it to proceed for a limited number of iterations (controlled
+/// by -sample-profile-max-propagate-iterations).
+///
+/// FIXME: Try to replace this propagation heuristic with a scheme
+/// that is guaranteed to finalize. A work-list approach similar to
+/// the standard value propagation algorithm used by SSA-CCP might
+/// work here.
+///
+/// \param F The function to query.
+///
+/// \returns true if \p F was modified. Returns false, otherwise.
+bool SampleProfileLoaderBaseImpl::computeAndPropagateWeights(
+    Function &F, const DenseSet<GlobalValue::GUID> &InlinedGUIDs) {
+  bool Changed = (InlinedGUIDs.size() != 0);
+
+  // Compute basic block weights.
+  Changed |= computeBlockWeights(F);
+
+  if (Changed) {
+    // Add an entry count to the function using the samples gathered at the
+    // function entry.
+    // Sets the GUIDs that are inlined in the profiled binary. This is used
+    // for ThinLink to make correct liveness analysis, and also make the IR
+    // match the profiled binary before annotation.
+    F.setEntryCount(
+        ProfileCount(Samples->getHeadSamples() + 1, Function::PCT_Real),
+        &InlinedGUIDs);
+
+    // Compute dominance and loop info needed for propagation.
+    computeDominanceAndLoopInfo(F);
+
+    // Find equivalence classes.
+    findEquivalenceClasses(F);
+
+    // Propagate weights to all edges.
+    propagateWeights(F);
+  }
+
+  return Changed;
+}
+
+void SampleProfileLoaderBaseImpl::emitCoverageRemarks(Function &F) {
+  // If coverage checking was requested, compute it now.
+  if (SampleProfileRecordCoverage) {
+    unsigned Used = CoverageTracker.countUsedRecords(Samples, PSI);
+    unsigned Total = CoverageTracker.countBodyRecords(Samples, PSI);
+    unsigned Coverage = CoverageTracker.computeCoverage(Used, Total);
+    if (Coverage < SampleProfileRecordCoverage) {
+      F.getContext().diagnose(DiagnosticInfoSampleProfile(
+          F.getSubprogram()->getFilename(), getFunctionLoc(F),
+          Twine(Used) + " of " + Twine(Total) + " available profile records (" +
+              Twine(Coverage) + "%) were applied",
+          DS_Warning));
+    }
+  }
+
+  if (SampleProfileSampleCoverage) {
+    uint64_t Used = CoverageTracker.getTotalUsedSamples();
+    uint64_t Total = CoverageTracker.countBodySamples(Samples, PSI);
+    unsigned Coverage = CoverageTracker.computeCoverage(Used, Total);
+    if (Coverage < SampleProfileSampleCoverage) {
+      F.getContext().diagnose(DiagnosticInfoSampleProfile(
+          F.getSubprogram()->getFilename(), getFunctionLoc(F),
+          Twine(Used) + " of " + Twine(Total) + " available profile samples (" +
+              Twine(Coverage) + "%) were applied",
+          DS_Warning));
+    }
+  }
+}
+
+// Generate MD_prof metadata for every branch instruction using the
+// edge weights computed during propagation.
+void SampleProfileLoader::generateMDProfMetadata(Function &F) {
   // Generate MD_prof metadata for every branch instruction using the
   // edge weights computed during propagation.
   LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch weights\n");
@@ -2108,7 +2261,7 @@ void SampleProfileLoader::propagateWeights(Function &F) {
 ///
 /// \returns the line number where \p F is defined. If it returns 0,
 ///          it means that there is no debug information available for \p F.
-unsigned SampleProfileLoader::getFunctionLoc(Function &F) {
+unsigned SampleProfileLoaderBaseImpl::getFunctionLoc(Function &F) {
   if (DISubprogram *S = F.getSubprogram())
     return S->getLine();
 
@@ -2124,7 +2277,7 @@ unsigned SampleProfileLoader::getFunctionLoc(Function &F) {
   return 0;
 }
 
-void SampleProfileLoader::computeDominanceAndLoopInfo(Function &F) {
+void SampleProfileLoaderBaseImpl::computeDominanceAndLoopInfo(Function &F) {
   DT.reset(new DominatorTree);
   DT->recalculate(F);
 
@@ -2134,49 +2287,6 @@ void SampleProfileLoader::computeDominanceAndLoopInfo(Function &F) {
   LI->analyze(*DT);
 }
 
-/// Generate branch weight metadata for all branches in \p F.
-///
-/// Branch weights are computed out of instruction samples using a
-/// propagation heuristic. Propagation proceeds in 3 phases:
-///
-/// 1- Assignment of block weights. All the basic blocks in the function
-///    are initial assigned the same weight as their most frequently
-///    executed instruction.
-///
-/// 2- Creation of equivalence classes. Since samples may be missing from
-///    blocks, we can fill in the gaps by setting the weights of all the
-///    blocks in the same equivalence class to the same weight. To compute
-///    the concept of equivalence, we use dominance and loop information.
-///    Two blocks B1 and B2 are in the same equivalence class if B1
-///    dominates B2, B2 post-dominates B1 and both are in the same loop.
-///
-/// 3- Propagation of block weights into edges. This uses a simple
-///    propagation heuristic. The following rules are applied to every
-///    block BB in the CFG:
-///
-///    - If BB has a single predecessor/successor, then the weight
-///      of that edge is the weight of the block.
-///
-///    - If all the edges are known except one, and the weight of the
-///      block is already known, the weight of the unknown edge will
-///      be the weight of the block minus the sum of all the known
-///      edges. If the sum of all the known edges is larger than BB's weight,
-///      we set the unknown edge weight to zero.
-///
-///    - If there is a self-referential edge, and the weight of the block is
-///      known, the weight for that edge is set to the weight of the block
-///      minus the weight of the other incoming edges to that block (if
-///      known).
-///
-/// Since this propagation is not guaranteed to finalize for every CFG, we
-/// only allow it to proceed for a limited number of iterations (controlled
-/// by -sample-profile-max-propagate-iterations).
-///
-/// FIXME: Try to replace this propagation heuristic with a scheme
-/// that is guaranteed to finalize. A work-list approach similar to
-/// the standard value propagation algorithm used by SSA-CCP might
-/// work here.
-///
 /// Once all the branch weights are computed, we emit the MD_prof
 /// metadata on BB using the computed values for each of its branches.
 ///
@@ -2209,55 +2319,12 @@ bool SampleProfileLoader::emitAnnotations(Function &F) {
   else
     Changed |= inlineHotFunctions(F, InlinedGUIDs);
 
-  // Compute basic block weights.
-  Changed |= computeBlockWeights(F);
+  Changed |= computeAndPropagateWeights(F, InlinedGUIDs);
 
-  if (Changed) {
-    // Add an entry count to the function using the samples gathered at the
-    // function entry.
-    // Sets the GUIDs that are inlined in the profiled binary. This is used
-    // for ThinLink to make correct liveness analysis, and also make the IR
-    // match the profiled binary before annotation.
-    F.setEntryCount(
-        ProfileCount(Samples->getHeadSamples() + 1, Function::PCT_Real),
-        &InlinedGUIDs);
-
-    // Compute dominance and loop info needed for propagation.
-    computeDominanceAndLoopInfo(F);
-
-    // Find equivalence classes.
-    findEquivalenceClasses(F);
-
-    // Propagate weights to all edges.
-    propagateWeights(F);
-  }
+  if (Changed)
+    generateMDProfMetadata(F);
 
-  // If coverage checking was requested, compute it now.
-  if (SampleProfileRecordCoverage) {
-    unsigned Used = CoverageTracker.countUsedRecords(Samples, PSI);
-    unsigned Total = CoverageTracker.countBodyRecords(Samples, PSI);
-    unsigned Coverage = CoverageTracker.computeCoverage(Used, Total);
-    if (Coverage < SampleProfileRecordCoverage) {
-      F.getContext().diagnose(DiagnosticInfoSampleProfile(
-          F.getSubprogram()->getFilename(), getFunctionLoc(F),
-          Twine(Used) + " of " + Twine(Total) + " available profile records (" +
-              Twine(Coverage) + "%) were applied",
-          DS_Warning));
-    }
-  }
-
-  if (SampleProfileSampleCoverage) {
-    uint64_t Used = CoverageTracker.getTotalUsedSamples();
-    uint64_t Total = CoverageTracker.countBodySamples(Samples, PSI);
-    unsigned Coverage = CoverageTracker.computeCoverage(Used, Total);
-    if (Coverage < SampleProfileSampleCoverage) {
-      F.getContext().diagnose(DiagnosticInfoSampleProfile(
-          F.getSubprogram()->getFilename(), getFunctionLoc(F),
-          Twine(Used) + " of " + Twine(Total) + " available profile samples (" +
-              Twine(Coverage) + "%) were applied",
-          DS_Warning));
-    }
-  }
+  emitCoverageRemarks(F);
   return Changed;
 }