[LoopNest] Add new utilites
authorWhitney Tsang <whitneyt@ca.ibm.com>
Thu, 13 Jan 2022 22:12:25 +0000 (17:12 -0500)
committerWhitney Tsang <whitneyt@ca.ibm.com>
Thu, 13 Jan 2022 22:19:19 +0000 (17:19 -0500)
getLoopIndex() is added to get the loop index of a given loop.
getLoopsAtDepth() is added to get the loops in the nest at a given
depth.

Reviewed By: Meinersbur

Differential Revision: https://reviews.llvm.org/D115590

llvm/include/llvm/Analysis/LoopNestAnalysis.h
llvm/unittests/Analysis/LoopNestTest.cpp

index 3d4a064..852a6c4 100644 (file)
@@ -102,12 +102,35 @@ public:
     return Loops[Index];
   }
 
+  /// Get the loop index of the given loop \p L.
+  unsigned getLoopIndex(const Loop &L) const {
+    for (unsigned I = 0; I < getNumLoops(); ++I)
+      if (getLoop(I) == &L)
+        return I;
+    llvm_unreachable("Loop not in the loop nest");
+  }
+
   /// Return the number of loops in the nest.
   size_t getNumLoops() const { return Loops.size(); }
 
   /// Get the loops in the nest.
   ArrayRef<Loop *> getLoops() const { return Loops; }
 
+  /// Get the loops in the nest at the given \p Depth.
+  LoopVectorTy getLoopsAtDepth(unsigned Depth) const {
+    assert(Depth >= Loops.front()->getLoopDepth() &&
+           Depth <= Loops.back()->getLoopDepth() && "Invalid depth");
+    LoopVectorTy Result;
+    for (unsigned I = 0; I < getNumLoops(); ++I) {
+      Loop *L = getLoop(I);
+      if (L->getLoopDepth() == Depth)
+        Result.push_back(L);
+      else if (L->getLoopDepth() > Depth)
+        break;
+    }
+    return Result;
+  }
+
   /// Retrieve a vector of perfect loop nests contained in the current loop
   /// nest. For example, given the following  nest containing 4 loops, this
   /// member function would return {{L1,L2},{L3,L4}}.
index a279632..596dbcf 100644 (file)
@@ -106,6 +106,19 @@ TEST(LoopNestTest, PerfectLoopNest) {
     const ArrayRef<Loop*> Loops = LN.getLoops();
     EXPECT_EQ(Loops.size(), 2ull);
 
+    // Ensure that we can obtain loops by depth.
+    LoopVectorTy LoopsAtDepth1 = LN.getLoopsAtDepth(1);
+    EXPECT_EQ(LoopsAtDepth1.size(), 1u);
+    EXPECT_EQ(LoopsAtDepth1[0], &OL);
+    LoopVectorTy LoopsAtDepth2 = LN.getLoopsAtDepth(2);
+    EXPECT_EQ(LoopsAtDepth2.size(), 1u);
+    EXPECT_EQ(LoopsAtDepth2[0], IL);
+
+    // Ensure that we can obtain the loop index of a given loop, and get back
+    // the loop with that index.
+    EXPECT_EQ(LN.getLoop(LN.getLoopIndex(OL)), &OL);
+    EXPECT_EQ(LN.getLoop(LN.getLoopIndex(*IL)), IL);
+
     // Ensure the loop nest is recognized as perfect in its entirety.
     const SmallVector<LoopVectorTy, 4> &PLV = LN.getPerfectLoops(SE);
     EXPECT_EQ(PLV.size(), 1ull);