[LoopNest] Consider loop nest with inner loop guard using outer loop
authorWhitney Tsang <whitneyt@ca.ibm.com>
Fri, 7 May 2021 15:36:55 +0000 (15:36 +0000)
committerWhitney Tsang <whitneyt@ca.ibm.com>
Fri, 7 May 2021 16:04:18 +0000 (16:04 +0000)
induction variable to be perfect

This patch allow more conditional branches to be considered as loop
guard, and so more loop nests can be considered perfect.

Reviewed By: bmahjour, sidbav

Differential Revision: https://reviews.llvm.org/D94717

llvm/include/llvm/Analysis/LoopNestAnalysis.h
llvm/lib/Analysis/LoopInfo.cpp
llvm/lib/Analysis/LoopNestAnalysis.cpp
llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
llvm/unittests/Analysis/LoopInfoTest.cpp

index ace1754..e045419 100644 (file)
@@ -61,10 +61,12 @@ public:
   static unsigned getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE);
 
   /// Recursivelly traverse all empty 'single successor' basic blocks of \p From
-  /// (if there are any). Return the last basic block found or \p End if it was
-  /// reached during the search.
+  /// (if there are any). When \p CheckUniquePred is set to true, check if
+  /// each of the empty single successors has a unique predecessor. Return
+  /// the last basic block found or \p End if it was reached during the search.
   static const BasicBlock &skipEmptyBlockUntil(const BasicBlock *From,
-                                               const BasicBlock *End);
+                                               const BasicBlock *End,
+                                               bool CheckUniquePred = false);
 
   /// Return the outermost loop in the loop nest.
   Loop &getOutermostLoop() const { return *Loops.front(); }
index adb2bdb..b2d7edb 100644 (file)
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfoImpl.h"
 #include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/LoopNestAnalysis.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -380,10 +381,6 @@ BranchInst *Loop::getLoopGuardBranch() const {
   if (!ExitFromLatch)
     return nullptr;
 
-  BasicBlock *ExitFromLatchSucc = ExitFromLatch->getUniqueSuccessor();
-  if (!ExitFromLatchSucc)
-    return nullptr;
-
   BasicBlock *GuardBB = Preheader->getUniquePredecessor();
   if (!GuardBB)
     return nullptr;
@@ -397,7 +394,17 @@ BranchInst *Loop::getLoopGuardBranch() const {
   BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader)
                                    ? GuardBI->getSuccessor(1)
                                    : GuardBI->getSuccessor(0);
-  return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr;
+
+  // Check if ExitFromLatch (or any BasicBlock which is an empty unique
+  // successor of ExitFromLatch) is equal to GuardOtherSucc. If
+  // skipEmptyBlockUntil returns GuardOtherSucc, then the guard branch for the
+  // loop is GuardBI (return GuardBI), otherwise return nullptr.
+  if (&LoopNest::skipEmptyBlockUntil(ExitFromLatch, GuardOtherSucc,
+                                     /*CheckUniquePred=*/true) ==
+      GuardOtherSucc)
+    return GuardBI;
+  else
+    return nullptr;
 }
 
 bool Loop::isCanonical(ScalarEvolution &SE) const {
index ee74d4b..2649ed6 100644 (file)
@@ -206,7 +206,8 @@ unsigned LoopNest::getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE) {
 }
 
 const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From,
-                                                const BasicBlock *End) {
+                                                const BasicBlock *End,
+                                                bool CheckUniquePred) {
   assert(From && "Expecting valid From");
   assert(End && "Expecting valid End");
 
@@ -220,8 +221,9 @@ const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From,
   // Visited is used to avoid running into an infinite loop.
   SmallPtrSet<const BasicBlock *, 4> Visited;
   const BasicBlock *BB = From->getUniqueSuccessor();
-  const BasicBlock *PredBB = BB;
-  while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB)) {
+  const BasicBlock *PredBB = From;
+  while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB) &&
+         (!CheckUniquePred || BB->getUniquePredecessor())) {
     Visited.insert(BB);
     PredBB = BB;
     BB = BB->getUniqueSuccessor();
@@ -335,9 +337,11 @@ static bool checkLoopsStructure(const Loop &OuterLoop, const Loop &InnerLoop,
 
   // Ensure the inner loop exit block lead to the outer loop latch possibly
   // through empty blocks.
-  const BasicBlock &SuccInner =
-      LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(), OuterLoopLatch);
-  if (&SuccInner != OuterLoopLatch && &SuccInner != ExtraPhiBlock) {
+  if ((!ExtraPhiBlock ||
+       &LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(),
+                                      ExtraPhiBlock) != ExtraPhiBlock) &&
+      (&LoopNest::skipEmptyBlockUntil(InnerLoop.getExitBlock(),
+                                      OuterLoopLatch) != OuterLoopLatch)) {
     DEBUG_WITH_TYPE(
         VerboseDebug,
         dbgs() << "Inner loop exit block " << *InnerLoopExit
index 4c8066e..77b361b 100644 (file)
@@ -424,70 +424,3 @@ for.cond.for.end13_crit_edge:
 for.end13:                   
   ret void
 }
-
-; Test an imperfect loop nest of the form:
-;   for (int i = 0; i < nx; ++i)
-;     if (i > 5) { // user branch
-;       for (int j = 1; j <= 5; j+=2)
-;         y[j][i] = x[i][j] + j;
-;     }
-
-define void @imperf_nest_6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
-;    CHECK-LABEL: IsPerfect=false, Depth=2, OutermostLoop: imperf_nest_6_loop_i, Loops: ( imperf_nest_6_loop_i imperf_nest_6_loop_j )
-entry:
-  %cmp2 = icmp slt i32 0, %nx
-  br i1 %cmp2, label %imperf_nest_6_loop_i.lr.ph, label %for.end13
-
-imperf_nest_6_loop_i.lr.ph:
-  br label %imperf_nest_6_loop_i
-
-imperf_nest_6_loop_i:      
-  %i.0 = phi i32 [ 0, %imperf_nest_6_loop_i.lr.ph ], [ %inc12, %for.inc11 ]
-  %cmp1 = icmp sgt i32 %i.0, 5
-  br i1 %cmp1, label %imperf_nest_6_loop_j.lr.ph, label %if.end
-
-imperf_nest_6_loop_j.lr.ph:
-  br label %imperf_nest_6_loop_j
-
-imperf_nest_6_loop_j:      
-  %j.0 = phi i32 [ 1, %imperf_nest_6_loop_j.lr.ph ], [ %inc, %for.inc ]
-  %idxprom = sext i32 %i.0 to i64
-  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
-  %0 = load i32*, i32** %arrayidx, align 8
-  %idxprom5 = sext i32 %j.0 to i64
-  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
-  %1 = load i32, i32* %arrayidx6, align 4
-  %add = add nsw i32 %1, %j.0
-  %idxprom7 = sext i32 %j.0 to i64
-  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
-  %2 = load i32*, i32** %arrayidx8, align 8
-  %idxprom9 = sext i32 %i.0 to i64
-  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
-  store i32 %add, i32* %arrayidx10, align 4
-  br label %for.inc
-
-for.inc:
-  %inc = add nsw i32 %j.0, 2
-  %cmp3 = icmp sle i32 %inc, 5
-  br i1 %cmp3, label %imperf_nest_6_loop_j, label %for.cond2.for.end_crit_edge
-
-for.cond2.for.end_crit_edge:
-  br label %for.end
-
-for.end:                    
-  br label %if.end
-
-if.end:                     
-  br label %for.inc11
-
-for.inc11:                  
-  %inc12 = add nsw i32 %i.0, 1
-  %cmp = icmp slt i32 %inc12, %nx
-  br i1 %cmp, label %imperf_nest_6_loop_i, label %for.cond.for.end13_crit_edge
-
-for.cond.for.end13_crit_edge:
-  br label %for.end13
-
-for.end13:                   
-  ret void
-}
index 7593d6f..f8b0e6a 100644 (file)
@@ -322,3 +322,148 @@ for.end7:
   %x.addr.0.lcssa = phi i32 [ %split7, %for.cond.for.end7_crit_edge ], [ %x, %entry ]
   ret i32 %x.addr.0.lcssa
 }
+
+; Test a perfect loop nest of the form:
+;   for (int i = 0; i < nx; ++i)
+;     if (i < ny) { // guard branch for the j-loop
+;       for (int j=i; j < ny; j+=1)
+;         y[j][i] = x[i][j] + j;
+;     }
+define double @perf_nest_guard_branch(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
+; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: test6Loop2, Loops: ( test6Loop2 )
+; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: test6Loop1, Loops: ( test6Loop1 test6Loop2 )
+entry:
+  %cmp2 = icmp slt i32 0, %nx
+  br i1 %cmp2, label %test6Loop1.lr.ph, label %for.end13
+
+test6Loop1.lr.ph:                                   ; preds = %entry
+  br label %test6Loop1
+
+test6Loop1:                                         ; preds = %test6Loop1.lr.ph, %for.inc11
+  %i.0 = phi i32 [ 0, %test6Loop1.lr.ph ], [ %inc12, %for.inc11 ]
+  %cmp1 = icmp slt i32 %i.0, %ny
+  br i1 %cmp1, label %test6Loop2.lr.ph, label %if.end
+
+test6Loop2.lr.ph:                                  ; preds = %if.then
+  br label %test6Loop2
+
+test6Loop2:                                        ; preds = %test6Loop2.lr.ph, %for.inc
+  %j.0 = phi i32 [ %i.0, %test6Loop2.lr.ph ], [ %inc, %for.inc ]
+  %idxprom = sext i32 %i.0 to i64
+  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
+  %0 = load i32*, i32** %arrayidx, align 8
+  %idxprom5 = sext i32 %j.0 to i64
+  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
+  %1 = load i32, i32* %arrayidx6, align 4
+  %add = add nsw i32 %1, %j.0
+  %idxprom7 = sext i32 %j.0 to i64
+  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
+  %2 = load i32*, i32** %arrayidx8, align 8
+  %idxprom9 = sext i32 %i.0 to i64
+  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
+  store i32 %add, i32* %arrayidx10, align 4
+  br label %for.inc
+
+for.inc:                                          ; preds = %test6Loop2
+  %inc = add nsw i32 %j.0, 1
+  %cmp3 = icmp slt i32 %inc, %ny
+  br i1 %cmp3, label %test6Loop2, label %for.cond2.for.end_crit_edge
+
+for.cond2.for.end_crit_edge:                      ; preds = %for.inc
+  br label %for.end
+
+for.end:                                          ; preds = %for.cond2.for.end_crit_edge, %if.then
+  br label %if.end
+
+if.end:                                           ; preds = %for.end, %test6Loop1
+  br label %for.inc11
+
+for.inc11:                                        ; preds = %if.end
+  %inc12 = add nsw i32 %i.0, 1
+  %cmp = icmp slt i32 %inc12, %nx
+  br i1 %cmp, label %test6Loop1, label %for.cond.for.end13_crit_edge
+
+for.cond.for.end13_crit_edge:                     ; preds = %for.inc11
+  br label %for.end13
+
+for.end13:                                        ; preds = %for.cond.for.end13_crit_edge, %entry
+  %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0
+  %3 = load i32*, i32** %arrayidx14, align 8
+  %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0
+  %4 = load i32, i32* %arrayidx15, align 4
+  %conv = sitofp i32 %4 to double
+  ret double %conv
+}
+
+; Test a perfect loop nest of the form:
+;   for (int i = 0; i < nx; ++i)
+;     if (i < ny) { // guard branch for the j-loop
+;       for (int j=i; j < ny; j+=1)
+;         y[j][i] = x[i][j] + j;
+;     }
+
+define double @test6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
+; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: test6Loop2, Loops: ( test6Loop2 )
+; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: test6Loop1, Loops: ( test6Loop1 test6Loop2 )
+entry:
+  %cmp2 = icmp slt i32 0, %nx
+  br i1 %cmp2, label %test6Loop1.lr.ph, label %for.end13
+
+test6Loop1.lr.ph:                                   ; preds = %entry
+  br label %test6Loop1
+
+test6Loop1:                                         ; preds = %test6Loop1.lr.ph, %for.inc11
+  %i.0 = phi i32 [ 0, %test6Loop1.lr.ph ], [ %inc12, %for.inc11 ]
+  %cmp1 = icmp slt i32 %i.0, %ny
+  br i1 %cmp1, label %test6Loop2.lr.ph, label %if.end
+
+test6Loop2.lr.ph:                                  ; preds = %if.then
+  br label %test6Loop2
+
+test6Loop2:                                        ; preds = %test6Loop2.lr.ph, %for.inc
+  %j.0 = phi i32 [ %i.0, %test6Loop2.lr.ph ], [ %inc, %for.inc ]
+  %idxprom = sext i32 %i.0 to i64
+  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
+  %0 = load i32*, i32** %arrayidx, align 8
+  %idxprom5 = sext i32 %j.0 to i64
+  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
+  %1 = load i32, i32* %arrayidx6, align 4
+  %add = add nsw i32 %1, %j.0
+  %idxprom7 = sext i32 %j.0 to i64
+  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
+  %2 = load i32*, i32** %arrayidx8, align 8
+  %idxprom9 = sext i32 %i.0 to i64
+  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
+  store i32 %add, i32* %arrayidx10, align 4
+  br label %for.inc
+
+for.inc:                                          ; preds = %test6Loop2
+  %inc = add nsw i32 %j.0, 1
+  %cmp3 = icmp slt i32 %inc, %ny
+  br i1 %cmp3, label %test6Loop2, label %for.cond2.for.end_crit_edge
+
+for.cond2.for.end_crit_edge:                      ; preds = %for.inc
+  br label %for.end
+
+for.end:                                          ; preds = %for.cond2.for.end_crit_edge, %if.then
+  br label %if.end
+
+if.end:                                           ; preds = %for.end, %test6Loop1
+  br label %for.inc11
+
+for.inc11:                                        ; preds = %if.end
+  %inc12 = add nsw i32 %i.0, 1
+  %cmp = icmp slt i32 %inc12, %nx
+  br i1 %cmp, label %test6Loop1, label %for.cond.for.end13_crit_edge
+
+for.cond.for.end13_crit_edge:                     ; preds = %for.inc11
+  br label %for.end13
+
+for.end13:                                        ; preds = %for.cond.for.end13_crit_edge, %entry
+  %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0
+  %3 = load i32*, i32** %arrayidx14, align 8
+  %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0
+  %4 = load i32, i32* %arrayidx15, align 4
+  %conv = sitofp i32 %4 to double
+  ret double %conv
+}
index bb51890..db6484f 100644 (file)
@@ -1500,3 +1500,51 @@ TEST(LoopInfoTest, LoopNotRotated) {
     EXPECT_FALSE(L->isRotatedForm());
   });
 }
+
+TEST(LoopInfoTest, LoopUserBranch) {
+  const char *ModuleStr =
+      "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
+      "define void @foo(i32* %B, i64 signext %nx, i1 %cond) {\n"
+      "entry:\n"
+      "  br i1 %cond, label %bb, label %guard\n"
+      "guard:\n"
+      "  %cmp.guard = icmp slt i64 0, %nx\n"
+      "  br i1 %cmp.guard, label %for.i.preheader, label %for.end\n"
+      "for.i.preheader:\n"
+      "  br label %for.i\n"
+      "for.i:\n"
+      "  %i = phi i64 [ 0, %for.i.preheader ], [ %inc13, %for.i ]\n"
+      "  %Bi = getelementptr inbounds i32, i32* %B, i64 %i\n"
+      "  store i32 0, i32* %Bi, align 4\n"
+      "  %inc13 = add nsw i64 %i, 1\n"
+      "  %cmp = icmp slt i64 %inc13, %nx\n"
+      "  br i1 %cmp, label %for.i, label %for.i.exit\n"
+      "for.i.exit:\n"
+      "  br label %bb\n"
+      "bb:\n"
+      "  br label %for.end\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) {
+    Function::iterator FI = F.begin();
+    FI = ++FI;
+    BasicBlock *Guard = &*FI;
+    assert(Guard->getName() == "guard");
+
+    FI = ++FI;
+    BasicBlock *Header = &*(++FI);
+    assert(Header->getName() == "for.i");
+
+    Loop *L = LI.getLoopFor(Header);
+    EXPECT_NE(L, nullptr);
+
+    // L should not have a guard branch
+    EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
+  });
+}