[BPI][NFC] Unify handling of normal and SCC based loops
authorEvgeniy Brevnov <ybrevnov@azul.com>
Wed, 29 Jul 2020 12:19:00 +0000 (19:19 +0700)
committerEvgeniy Brevnov <ybrevnov@azul.com>
Wed, 5 Aug 2020 04:19:24 +0000 (11:19 +0700)
This is one more NFC part extracted from D79485. Normal and SCC based loops have very different representation and have to be handled separatly each time we deal with loops. D79485 is going to introduce much more extensive use of loops what will be problematic with out this change.

Reviewed By: davidxl

Differential Revision: https://reviews.llvm.org/D84838

llvm/include/llvm/Analysis/BranchProbabilityInfo.h
llvm/lib/Analysis/BranchProbabilityInfo.cpp

index 7feb5b6..447f145 100644 (file)
@@ -32,6 +32,7 @@
 namespace llvm {
 
 class Function;
+class Loop;
 class LoopInfo;
 class raw_ostream;
 class PostDominatorTree;
@@ -230,6 +231,32 @@ private:
         : CallbackVH(const_cast<Value *>(V)), BPI(BPI) {}
   };
 
+  /// Pair of Loop and SCC ID number. Used to unify handling of normal and
+  /// SCC based loop representations.
+  using LoopData = std::pair<Loop *, int>;
+  /// Helper class to keep basic block along with its loop data information.
+  class LoopBlock {
+  public:
+    explicit LoopBlock(const BasicBlock *BB, const LoopInfo &LI,
+                       const SccInfo &SccI);
+
+    const BasicBlock *getBlock() const { return BB; }
+    Loop *getLoop() const { return LD.first; }
+    int getSccNum() const { return LD.second; }
+
+    bool belongsToLoop() const { return getLoop() || getSccNum() != -1; }
+    bool belongsToSameLoop(const LoopBlock &LB) const {
+      return (LB.getLoop() && getLoop() == LB.getLoop()) ||
+             (LB.getSccNum() != -1 && getSccNum() == LB.getSccNum());
+    }
+
+  private:
+    const BasicBlock *const BB = nullptr;
+    LoopData LD = {nullptr, -1};
+  };
+  // Pair of LoopBlocks representing an edge from first to second block.
+  using LoopEdge = std::pair<const LoopBlock &, const LoopBlock &>;
+
   DenseSet<BasicBlockCallbackVH, DenseMapInfo<Value*>> Handles;
 
   // Since we allow duplicate edges from one basic block to another, we use
@@ -258,6 +285,27 @@ private:
   /// Track the set of blocks that always lead to a cold call.
   SmallPtrSet<const BasicBlock *, 16> PostDominatedByColdCall;
 
+  /// Returns true if destination block belongs to some loop and source block is
+  /// either doesn't belong to any loop or belongs to a loop which is not inner
+  /// relative to the destination block.
+  bool isLoopEnteringEdge(const LoopEdge &Edge) const;
+  /// Returns true if source block belongs to some loop and destination block is
+  /// either doesn't belong to any loop or belongs to a loop which is not inner
+  /// relative to the source block.
+  bool isLoopExitingEdge(const LoopEdge &Edge) const;
+  /// Returns true if \p Edge is either enters to or exits from some loop, false
+  /// in all other cases.
+  bool isLoopEnteringExitingEdge(const LoopEdge &Edge) const;
+  /// Returns true if source and destination blocks belongs to the same loop and
+  /// destination block is loop header.
+  bool isLoopBackEdge(const LoopEdge &Edge) const;
+  // Fills in \p Enters vector with all "enter" blocks to a loop \LB belongs to.
+  void getLoopEnterBlocks(const LoopBlock &LB,
+                          SmallVectorImpl<BasicBlock *> &Enters) const;
+  // Fills in \p Exits vector with all "exit" blocks from a loop \LB belongs to.
+  void getLoopExitBlocks(const LoopBlock &LB,
+                         SmallVectorImpl<BasicBlock *> &Exits) const;
+
   void computePostDominatedByUnreachable(const Function &F,
                                          PostDominatorTree *PDT);
   void computePostDominatedByColdCall(const Function &F,
index 7b24fe9..0a14c8c 100644 (file)
@@ -247,6 +247,66 @@ void BranchProbabilityInfo::SccInfo::calculateSccBlockType(const BasicBlock *BB,
   }
 }
 
+BranchProbabilityInfo::LoopBlock::LoopBlock(const BasicBlock *BB,
+                                            const LoopInfo &LI,
+                                            const SccInfo &SccI)
+    : BB(BB) {
+  LD.first = LI.getLoopFor(BB);
+  if (!LD.first) {
+    LD.second = SccI.getSCCNum(BB);
+  }
+}
+
+bool BranchProbabilityInfo::isLoopEnteringEdge(const LoopEdge &Edge) const {
+  const auto &SrcBlock = Edge.first;
+  const auto &DstBlock = Edge.second;
+  return (DstBlock.getLoop() &&
+          !DstBlock.getLoop()->contains(SrcBlock.getLoop())) ||
+         // Assume that SCCs can't be nested.
+         (DstBlock.getSccNum() != -1 &&
+          SrcBlock.getSccNum() != DstBlock.getSccNum());
+}
+
+bool BranchProbabilityInfo::isLoopExitingEdge(const LoopEdge &Edge) const {
+  return isLoopEnteringEdge({Edge.second, Edge.first});
+}
+
+bool BranchProbabilityInfo::isLoopEnteringExitingEdge(
+    const LoopEdge &Edge) const {
+  return isLoopEnteringEdge(Edge) || isLoopExitingEdge(Edge);
+}
+
+bool BranchProbabilityInfo::isLoopBackEdge(const LoopEdge &Edge) const {
+  const auto &SrcBlock = Edge.first;
+  const auto &DstBlock = Edge.second;
+  return SrcBlock.belongsToSameLoop(DstBlock) &&
+         ((DstBlock.getLoop() &&
+           DstBlock.getLoop()->getHeader() == DstBlock.getBlock()) ||
+          (DstBlock.getSccNum() != -1 &&
+           SccI->isSCCHeader(DstBlock.getBlock(), DstBlock.getSccNum())));
+}
+
+void BranchProbabilityInfo::getLoopEnterBlocks(
+    const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Enters) const {
+  if (LB.getLoop()) {
+    auto *Header = LB.getLoop()->getHeader();
+    Enters.append(pred_begin(Header), pred_end(Header));
+  } else {
+    assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
+    SccI->getSccEnterBlocks(LB.getSccNum(), Enters);
+  }
+}
+
+void BranchProbabilityInfo::getLoopExitBlocks(
+    const LoopBlock &LB, SmallVectorImpl<BasicBlock *> &Exits) const {
+  if (LB.getLoop()) {
+    LB.getLoop()->getExitBlocks(Exits);
+  } else {
+    assert(LB.getSccNum() != -1 && "LB doesn't belong to any loop?");
+    SccI->getSccExitBlocks(LB.getSccNum(), Exits);
+  }
+}
+
 static void UpdatePDTWorklist(const BasicBlock *BB, PostDominatorTree *PDT,
                               SmallVectorImpl<const BasicBlock *> &WorkList,
                               SmallPtrSetImpl<const BasicBlock *> &TargetSet) {
@@ -720,17 +780,13 @@ computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
 // as taken, exiting edges as not-taken.
 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
                                                      const LoopInfo &LI) {
-  int SccNum;
-  Loop *L = LI.getLoopFor(BB);
-  if (!L) {
-    SccNum = SccI->getSCCNum(BB);
-    if (SccNum < 0)
-      return false;
-  }
+  LoopBlock LB(BB, LI, *SccI.get());
+  if (!LB.belongsToLoop())
+    return false;
 
   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
-  if (L)
-    computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
+  if (LB.getLoop())
+    computeUnlikelySuccessors(BB, LB.getLoop(), UnlikelyBlocks);
 
   SmallVector<unsigned, 8> BackEdges;
   SmallVector<unsigned, 8> ExitingEdges;
@@ -738,24 +794,19 @@ bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
   SmallVector<unsigned, 8> UnlikelyEdges;
 
   for (const_succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
-    // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
-    // irreducible loops.
-    if (L) {
-      if (UnlikelyBlocks.count(*I) != 0)
-        UnlikelyEdges.push_back(I.getSuccessorIndex());
-      else if (!L->contains(*I))
-        ExitingEdges.push_back(I.getSuccessorIndex());
-      else if (L->getHeader() == *I)
-        BackEdges.push_back(I.getSuccessorIndex());
-      else
-        InEdges.push_back(I.getSuccessorIndex());
-    } else {
-      if (SccI->getSCCNum(*I) != SccNum)
-        ExitingEdges.push_back(I.getSuccessorIndex());
-      else if (SccI->isSCCHeader(*I, SccNum))
-        BackEdges.push_back(I.getSuccessorIndex());
-      else
-        InEdges.push_back(I.getSuccessorIndex());
+    LoopBlock SuccLB(*I, LI, *SccI.get());
+    LoopEdge Edge(LB, SuccLB);
+    bool IsUnlikelyEdge =
+        LB.getLoop() && (UnlikelyBlocks.find(*I) != UnlikelyBlocks.end());
+
+    if (IsUnlikelyEdge)
+      UnlikelyEdges.push_back(I.getSuccessorIndex());
+    else if (isLoopExitingEdge(Edge))
+      ExitingEdges.push_back(I.getSuccessorIndex());
+    else if (isLoopBackEdge(Edge))
+      BackEdges.push_back(I.getSuccessorIndex());
+    else {
+      InEdges.push_back(I.getSuccessorIndex());
     }
   }