[llvm][NFC] Refactor code to use ProfDataUtils
authorPaul Kirth <paulkirth@google.com>
Wed, 27 Jul 2022 21:44:24 +0000 (21:44 +0000)
committerPaul Kirth <paulkirth@google.com>
Wed, 3 Aug 2022 00:09:45 +0000 (00:09 +0000)
In this patch we replace common code patterns with the use of utility
functions for dealing with profiling metadata. There should be no change
in functionality, as the existing checks should be preserved in all
cases.

Reviewed By: bogner, davidxl

Differential Revision: https://reviews.llvm.org/D128860

15 files changed:
llvm/include/llvm/IR/Instruction.h
llvm/include/llvm/IR/ProfDataUtils.h
llvm/lib/Analysis/BranchProbabilityInfo.cpp
llvm/lib/CodeGen/CodeGenPrepare.cpp
llvm/lib/CodeGen/SelectOptimize.cpp
llvm/lib/IR/Metadata.cpp
llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp
llvm/lib/Transforms/IPO/PartialInlining.cpp
llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
llvm/lib/Transforms/Scalar/JumpThreading.cpp
llvm/lib/Transforms/Utils/LoopPeel.cpp
llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
llvm/lib/Transforms/Utils/LoopUtils.cpp
llvm/lib/Transforms/Utils/MisExpect.cpp
llvm/lib/Transforms/Utils/SimplifyCFG.cpp

index 15b0bdf..1044a63 100644 (file)
@@ -335,11 +335,6 @@ public:
   /// Sets the AA metadata on this instruction from the AAMDNodes structure.
   void setAAMetadata(const AAMDNodes &N);
 
-  /// Retrieve the raw weight values of a conditional branch or select.
-  /// Returns true on success with profile weights filled in.
-  /// Returns false if no metadata or invalid metadata was found.
-  bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const;
-
   /// Retrieve total raw weight values of a branch.
   /// Returns true on success with profile total weights filled in.
   /// Returns false if no metadata was found.
index b6c53a3..0051c41 100644 (file)
@@ -16,15 +16,15 @@ bool isBranchWeightMD(const MDNode *ProfileData);
 /// Checks if an instructions has Branch Weight Metadata
 ///
 /// \param I The instruction to check
-/// \return True if I has an MD_prof node containing Branch Weights. False
+/// \returns True if I has an MD_prof node containing Branch Weights. False
 /// otherwise.
 bool hasBranchWeightMD(const Instruction &I);
 
 /// Extract branch weights from MD_prof metadata
 ///
 /// \param ProfileData A pointer to an MDNode.
-/// \param Weights An output vector to fill with branch weights
-/// \return True if weights were extracted, False otherwise. When false Weights
+/// \param [out] Weights An output vector to fill with branch weights
+/// \returns True if weights were extracted, False otherwise. When false Weights
 /// will be cleared.
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights);
@@ -32,24 +32,28 @@ bool extractBranchWeights(const MDNode *ProfileData,
 /// Extract branch weights attatched to an Instruction
 ///
 /// \param I The Instruction to extract weights from.
-/// \param Weights An output vector to fill with branch weights
-/// \return True if weights were extracted, False otherwise. When false Weights
+/// \param [out] Weights An output vector to fill with branch weights
+/// \returns True if weights were extracted, False otherwise. When false Weights
 /// will be cleared.
 bool extractBranchWeights(const Instruction &I,
                           SmallVectorImpl<uint32_t> &Weights);
 
-/// Retrieve the raw weight values of a conditional branch or select.
-/// Returns true on success with profile weights filled in.
-/// Returns false if no metadata or invalid metadata was found.
+/// Extract branch weights from a conditional branch or select Instruction.
+///
+/// \param I The instruction to extract branch weights from.
+/// \param [out] TrueVal will contain the branch weight for the True branch
+/// \param [out] FalseVal will contain the branch weight for the False branch
+/// \returns True on success with profile weights filled in. False if no
+/// metadata or invalid metadata was found.
 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
                           uint64_t &FalseVal);
 
 /// Retrieve the total of all weights from MD_prof data.
 ///
 /// \param ProfileData The profile data to extract the total weight from
-/// \param TotalWeights input variable to fill with total weights
-/// \return true on success with profile total weights filled in.
-/// \return false if no metadata was found.
+/// \param [out] TotalWeights input variable to fill with total weights
+/// \returns True on success with profile total weights filled in. False if no
+/// metadata was found.
 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 
 } // namespace llvm
index f457287..8918fb9 100644 (file)
@@ -31,6 +31,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -401,24 +402,18 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
   SmallVector<uint32_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
-  Weights.reserve(TI->getNumSuccessors());
-  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
-    if (!Weight)
-      return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+
+  extractBranchWeights(*TI, Weights);
+  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+    WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I - 1);
+      UnreachableIdxs.push_back(I);
     else
-      ReachableIdxs.push_back(I - 1);
+      ReachableIdxs.push_back(I);
   }
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
 
index d5d08f0..b100fbe 100644 (file)
@@ -65,6 +65,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
@@ -6620,7 +6621,7 @@ static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI,
   // If metadata tells us that the select condition is obviously predictable,
   // then we want to replace the select with a branch.
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -8366,7 +8367,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // Another choice is to assume TrueProb for BB1 equals to TrueProb for
       // TmpBB, but the math is more complicated.
       uint64_t TrueWeight, FalseWeight;
-      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
+      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = TrueWeight;
         uint64_t NewFalseWeight = TrueWeight + 2 * FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);
@@ -8399,7 +8400,7 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, bool &ModifiedDT) {
       // assumes that
       //   FalseProb for BB1 == TrueProb for BB1 * FalseProb for TmpBB.
       uint64_t TrueWeight, FalseWeight;
-      if (Br1->extractProfMetadata(TrueWeight, FalseWeight)) {
+      if (extractBranchWeights(*Br1, TrueWeight, FalseWeight)) {
         uint64_t NewTrueWeight = 2 * TrueWeight + FalseWeight;
         uint64_t NewFalseWeight = FalseWeight;
         scaleWeights(NewTrueWeight, NewFalseWeight);
index 011f55e..a7a6987 100644 (file)
@@ -29,6 +29,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ScaledNumber.h"
@@ -655,7 +656,7 @@ bool SelectOptimize::hasExpensiveColdOperand(
     const SmallVector<SelectInst *, 2> &ASI) {
   bool ColdOperand = false;
   uint64_t TrueWeight, FalseWeight, TotalWeight;
-  if (ASI.front()->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) {
     uint64_t MinWeight = std::min(TrueWeight, FalseWeight);
     TotalWeight = TrueWeight + FalseWeight;
     // Is there a path with frequency <ColdOperandThreshold% (default:20%) ?
@@ -750,7 +751,7 @@ void SelectOptimize::getExclBackwardsSlice(Instruction *I,
 
 bool SelectOptimize::isSelectHighlyPredictable(const SelectInst *SI) {
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -959,7 +960,7 @@ SelectOptimize::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
                                      const SelectInst *SI) {
   Scaled64 PredPathCost;
   uint64_t TrueWeight, FalseWeight;
-  if (SI->extractProfMetadata(TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
     uint64_t SumWeight = TrueWeight + FalseWeight;
     if (SumWeight != 0) {
       PredPathCost = TrueCost * Scaled64::get(TrueWeight) +
index 2a1a514..a202743 100644 (file)
@@ -40,6 +40,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -1493,31 +1494,6 @@ void Instruction::getAllMetadataImpl(
   Value::getAllMetadata(Result);
 }
 
-bool Instruction::extractProfMetadata(uint64_t &TrueVal,
-                                      uint64_t &FalseVal) const {
-  assert(
-      (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) &&
-      "Looking for branch weights on something besides branch or select");
-
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return false;
-
-  auto *CITrue = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
-  auto *CIFalse = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
-  if (!CITrue || !CIFalse)
-    return false;
-
-  TrueVal = CITrue->getValue().getZExtValue();
-  FalseVal = CIFalse->getValue().getZExtValue();
-
-  return true;
-}
-
 bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
@@ -1526,32 +1502,7 @@ bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  TotalVal = 0;
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName)
-    return false;
-
-  if (ProfDataName->getString().equals("branch_weights")) {
-    TotalVal = 0;
-    for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
-      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
-      if (!V)
-        return false;
-      TotalVal += V->getValue().getZExtValue();
-    }
-    return true;
-  } else if (ProfDataName->getString().equals("VP") &&
-             ProfileData->getNumOperands() > 3) {
-    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
-                   ->getValue()
-                   .getZExtValue();
-    return true;
-  }
-  return false;
+  return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {
index cf72893..c894587 100644 (file)
@@ -15,6 +15,7 @@
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetSchedule.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/KnownBits.h"
@@ -757,7 +758,7 @@ bool PPCTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
     if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
       uint64_t TrueWeight = 0, FalseWeight = 0;
       if (!BI->isConditional() ||
-          !BI->extractProfMetadata(TrueWeight, FalseWeight))
+          !extractBranchWeights(*BI, TrueWeight, FalseWeight))
         continue;
 
       // If the exit path is more frequent than the loop path,
index 54c72bd..ec2e7fb 100644 (file)
@@ -40,6 +40,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/User.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -717,7 +718,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) {
     if (!BR || BR->isUnconditional())
       continue;
     uint64_t T, F;
-    if (BR->extractProfMetadata(T, F))
+    if (extractBranchWeights(*BR, T, F))
       return true;
   }
   return false;
index c4512d0..90cc61c 100644 (file)
@@ -91,6 +91,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -2067,7 +2068,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
       // Display scaled counts for SELECT instruction:
       OS << "SELECT : { T = ";
       uint64_t TC, FC;
-      bool HasProf = I.extractProfMetadata(TC, FC);
+      bool HasProf = extractBranchWeights(I, TC, FC);
       if (!HasProf)
         OS << "Unknown, F = Unknown }\\l";
       else
index b31eab5..0113e7b 100644 (file)
@@ -54,6 +54,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/Value.h"
@@ -216,7 +217,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     return;
 
   uint64_t TrueWeight, FalseWeight;
-  if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*CondBr, TrueWeight, FalseWeight))
     return;
 
   if (TrueWeight + FalseWeight == 0)
@@ -279,7 +280,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     // With PGO, this can be used to refine even existing profile data with
     // context information. This needs to be done after more performance
     // testing.
-    if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight))
+    if (extractBranchWeights(*PredBr, PredTrueWeight, PredFalseWeight))
       continue;
 
     // We can not infer anything useful when BP >= 50%, because BP is the
index f093fea..9a7f9df 100644 (file)
@@ -29,6 +29,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -532,7 +533,7 @@ static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
                               uint64_t &ExitWeight,
                               uint64_t &FallThroughWeight) {
   uint64_t TrueWeight, FalseWeight;
-  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
     return;
   unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
   ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;
index 023a0af..1c44ccb 100644 (file)
@@ -30,6 +30,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -471,7 +472,7 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop,
   uint64_t TrueWeight, FalseWeight;
   BranchInst *LatchBR =
       cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator());
-  if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
+  if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight))
     return;
   uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader()
                             ? FalseWeight
index 349063d..03272e5 100644 (file)
@@ -38,6 +38,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -789,7 +790,7 @@ getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L,
   // know the number of times the backedge was taken, vs. the number of times
   // we exited the loop.
   uint64_t LoopWeight, ExitWeight;
-  if (!ExitingBranch->extractProfMetadata(LoopWeight, ExitWeight))
+  if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
     return None;
 
   if (L->contains(ExitingBranch->getSuccessor(1)))
index d85e9dd..8f94902 100644 (file)
@@ -35,6 +35,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -118,34 +119,6 @@ void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx,
 namespace llvm {
 namespace misexpect {
 
-// Helper function to extract branch weights into a vector
-Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
-                                                  LLVMContext &Ctx) {
-  assert(I && "MisExpect::extractWeights given invalid pointer");
-
-  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return None;
-
-  unsigned NOps = ProfileData->getNumOperands();
-  if (NOps < 3)
-    return None;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return None;
-
-  SmallVector<uint32_t, 4> Weights(NOps - 1);
-  for (unsigned Idx = 1; Idx < NOps; Idx++) {
-    ConstantInt *Value =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    uint32_t V = Value->getZExtValue();
-    Weights[Idx - 1] = V;
-  }
-
-  return Weights;
-}
-
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -218,19 +191,17 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights,
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
-  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
-  if (!ExpectedWeightsOpt)
+  SmallVector<uint32_t> ExpectedWeights;
+  if (!extractBranchWeights(I, ExpectedWeights))
     return;
-  auto ExpectedWeights = ExpectedWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef<uint32_t> ExpectedWeights) {
-  auto RealWeightsOpt = extractWeights(&I, I.getContext());
-  if (!RealWeightsOpt)
+  SmallVector<uint32_t> RealWeights;
+  if (!extractBranchWeights(I, RealWeights))
     return;
-  auto RealWeights = RealWeightsOpt.value();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
index 1806081..bba8315 100644 (file)
@@ -57,6 +57,7 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1050,15 +1051,6 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
-static inline bool HasBranchWeights(const Instruction *I) {
-  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0))
-    if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
-      return MDS->getString().equals("branch_weights");
-
-  return false;
-}
-
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1177,8 +1169,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
 
   // Update the branch weight metadata along the way
   SmallVector<uint64_t, 8> Weights;
-  bool PredHasWeights = HasBranchWeights(PTI);
-  bool SuccHasWeights = HasBranchWeights(TI);
+  bool PredHasWeights = hasBranchWeightMD(*PTI);
+  bool SuccHasWeights = hasBranchWeightMD(*TI);
 
   if (PredHasWeights) {
     GetBranchWeights(PTI, Weights);
@@ -2752,7 +2744,8 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB,
   // the `then` block, then avoid speculating it.
   if (!BI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) {
+    if (extractBranchWeights(*BI, TWeight, FWeight) &&
+        (TWeight + FWeight) != 0) {
       uint64_t EndWeight = Invert ? TWeight : FWeight;
       BranchProbability BIEndProb =
           BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight);
@@ -3174,7 +3167,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI,
   // from the block that we know is predictably not entered.
   if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) {
     uint64_t TWeight, FWeight;
-    if (DomBI->extractProfMetadata(TWeight, FWeight) &&
+    if (extractBranchWeights(*DomBI, TWeight, FWeight) &&
         (TWeight + FWeight) != 0) {
       BranchProbability BITrueProb =
           BranchProbability::getBranchProbability(TWeight, TWeight + FWeight);
@@ -3354,9 +3347,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI,
                                    uint64_t &SuccTrueWeight,
                                    uint64_t &SuccFalseWeight) {
   bool PredHasWeights =
-      PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight);
+      extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight);
   bool SuccHasWeights =
-      BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight);
+      extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight);
   if (PredHasWeights || SuccHasWeights) {
     if (!PredHasWeights)
       PredTrueWeight = PredFalseWeight = 1;
@@ -3384,7 +3377,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI,
   uint64_t PTWeight, PFWeight;
   BranchProbability PBITrueProb, Likely;
   if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) &&
-      PBI->extractProfMetadata(PTWeight, PFWeight) &&
+      extractBranchWeights(*PBI, PTWeight, PFWeight) &&
       (PTWeight + PFWeight) != 0) {
     PBITrueProb =
         BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight);
@@ -4349,7 +4342,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI,
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
-  bool HasWeights = HasBranchWeights(SI);
+  bool HasWeights = hasBranchWeightMD(*SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5209,7 +5202,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI,
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (HasBranchWeights(SI)) {
+  if (hasBranchWeightMD(*SI)) {
     SmallVector<uint64_t, 8> Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {