[SCEV] Make exact taken count calculation more optimistic
authorMax Kazantsev <max.kazantsev@azul.com>
Tue, 27 Mar 2018 07:30:38 +0000 (07:30 +0000)
committerMax Kazantsev <max.kazantsev@azul.com>
Tue, 27 Mar 2018 07:30:38 +0000 (07:30 +0000)
Currently, `getExact` fails if it sees two exit counts in different blocks. There is
no solid reason to do so, given that we only calculate exact non-taken count
for exiting blocks that dominate latch. Using this fact, we can simply take min
out of all exits of all blocks to get the exact taken count.

This patch makes the calculation more optimistic with enforcing our assumption
with asserts. It allows us to calculate exact backedge taken count in trivial loops
like

  for (int i = 0; i < 100; i++) {
    if (i > 50) break;
    . . .
  }

Differential Revision: https://reviews.llvm.org/D44676
Reviewed By: fhahn

llvm-svn: 328611

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/exact_iter_count.ll [new file with mode: 0644]
llvm/test/Analysis/ScalarEvolution/max-trip-count.ll
llvm/test/Analysis/ScalarEvolution/trip-count14.ll
llvm/test/Transforms/IndVarSimplify/loop_evaluate10.ll
llvm/test/Transforms/LoopSimplify/preserve-scev.ll

index 82dac93..7a43b81 100644 (file)
@@ -1288,7 +1288,7 @@ private:
     /// If we allowed SCEV predicates to be generated when populating this
     /// vector, this information can contain them and therefore a
     /// SCEVPredicate argument should be added to getExact.
-    const SCEV *getExact(ScalarEvolution *SE,
+    const SCEV *getExact(const Loop *L, ScalarEvolution *SE,
                          SCEVUnionPredicate *Predicates = nullptr) const;
 
     /// Return the number of times this loop exit may fall through to the back
index b44107d..3a1c6fb 100644 (file)
@@ -6413,11 +6413,11 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
 const SCEV *
 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
                                                  SCEVUnionPredicate &Preds) {
-  return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds);
+  return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
 }
 
 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
-  return getBackedgeTakenInfo(L).getExact(this);
+  return getBackedgeTakenInfo(L).getExact(L, this);
 }
 
 /// Similar to getBackedgeTakenCount, except return the least SCEV value that is
@@ -6474,8 +6474,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
   // must be cleared in this scope.
   BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
 
-  if (Result.getExact(this) != getCouldNotCompute()) {
-    assert(isLoopInvariant(Result.getExact(this), L) &&
+  if (Result.getExact(L, this) != getCouldNotCompute()) {
+    assert(isLoopInvariant(Result.getExact(L, this), L) &&
            isLoopInvariant(Result.getMax(this), L) &&
            "Computed backedge-taken count isn't loop invariant for loop!");
     ++NumTripCountsComputed;
@@ -6656,20 +6656,30 @@ void ScalarEvolution::forgetValue(Value *V) {
 /// caller's responsibility to specify the relevant loop exit using
 /// getExact(ExitingBlock, SE).
 const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
+ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
                                              SCEVUnionPredicate *Preds) const {
   // If any exits were not computable, the loop is not computable.
   if (!isComplete() || ExitNotTaken.empty())
     return SE->getCouldNotCompute();
 
   const SCEV *BECount = nullptr;
+  const BasicBlock *Latch = L->getLoopLatch();
+  // All exits we have collected must dominate the only latch.
+  if (!Latch)
+    return SE->getCouldNotCompute();
+
+  // All exits we have gathered dominate loop's latch, so exact trip count is
+  // simply a minimum out of all these calculated exit counts.
   for (auto &ENT : ExitNotTaken) {
     assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
+    assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
+           "We should only have known counts for exits that dominate latch!");
 
     if (!BECount)
       BECount = ENT.ExactNotTaken;
     else if (BECount != ENT.ExactNotTaken)
-      return SE->getCouldNotCompute();
+      BECount = SE->getUMinFromMismatchedTypes(BECount, ENT.ExactNotTaken);
+
     if (Preds && !ENT.hasAlwaysTruePredicate())
       Preds->add(ENT.Predicate.get());
 
diff --git a/llvm/test/Analysis/ScalarEvolution/exact_iter_count.ll b/llvm/test/Analysis/ScalarEvolution/exact_iter_count.ll
new file mode 100644 (file)
index 0000000..ba0bc1f
--- /dev/null
@@ -0,0 +1,27 @@
+; RUN: opt < %s -scalar-evolution -analyze | FileCheck %s
+
+; One side exit dominating the latch, exact backedge taken count is known.
+define void @test_01() {
+
+; CHECK-LABEL: Determining loop execution counts for: @test_01
+; CHECK-NEXT:  Loop %loop: <multiple exits> backedge-taken count is 50
+
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.next, %backedge ]
+  %side.cond = icmp slt i32 %iv, 50
+  br i1 %side.cond, label %backedge, label %side.exit
+
+backedge:
+  %iv.next = add i32 %iv, 1
+  %loop.cond = icmp slt i32 %iv, 100
+  br i1 %loop.cond, label %loop, label %exit
+
+exit:
+  ret void
+
+side.exit:
+  ret void
+}
index 240ff8d..53b882b 100644 (file)
@@ -186,7 +186,7 @@ bar.exit:                                         ; preds = %for.cond.i, %for.bo
 ; MaxBECount should be the minimum of them.
 ;
 ; CHECK-LABEL: @two_mustexit
-; CHECK: Loop %for.body.i: <multiple exits> Unpredictable backedge-taken count.
+; CHECK: Loop %for.body.i: <multiple exits> backedge-taken count is 1
 ; CHECK: Loop %for.body.i: max backedge-taken count is 1
 define i32 @two_mustexit() {
 entry:
index 0f935d7..5e6cfe8 100644 (file)
@@ -81,7 +81,7 @@ if.end:
   br i1 %cmp1, label %do.body, label %do.end ; taken either 0 or 2 times
 
 ; CHECK-LABEL: Determining loop execution counts for: @s32_max2_unpredictable_exit
-; CHECK-NEXT: Loop %do.body: <multiple exits> Unpredictable backedge-taken count.
+; CHECK-NEXT: Loop %do.body: <multiple exits> backedge-taken count is (-1 + (-1 * ((-1 + (-1 * ((2 + %n) smax %n)) + %n) umax (-1 + (-1 * %x) + %n))))
 ; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2{{$}}
 
 do.end:
@@ -169,7 +169,7 @@ if.end:
   br i1 %cmp1, label %do.body, label %do.end ; taken either 0 or 2 times
 
 ; CHECK-LABEL: Determining loop execution counts for: @u32_max2_unpredictable_exit
-; CHECK-NEXT: Loop %do.body: <multiple exits> Unpredictable backedge-taken count.
+; CHECK-NEXT: Loop %do.body: <multiple exits> backedge-taken count is (-1 + (-1 * ((-1 + (-1 * ((2 + %n) umax %n)) + %n) umax (-1 + (-1 * %x) + %n))))
 ; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2{{$}}
 
 do.end:
index 0739219..3ac9106 100644 (file)
@@ -3,11 +3,6 @@
 ; This loop has multiple exits, and the value of %b1 depends on which
 ; exit is taken. Indvars should correctly compute the exit values.
 ;
-; XFAIL: *
-; Indvars does not currently replace loop invariant values unless all
-; loop exits have the same exit value. We could handle some cases,
-; such as this, by making getSCEVAtScope() sensitive to a particular
-; loop exit.  See PR11388.
 
 target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128"
 target triple = "x86_64-pc-linux-gnu"
index fb15d84..2def885 100644 (file)
@@ -91,9 +91,9 @@ declare void @foo() nounwind
 ; After simplifying, the max backedge count is refined.
 ; Second SCEV print:
 ; CHECK-LABEL: Determining loop execution counts for: @mergeExit
-; CHECK: Loop %while.cond191: <multiple exits> Unpredictable backedge-taken count.
+; CHECK: Loop %while.cond191: <multiple exits> backedge-taken count is 0
 ; CHECK: Loop %while.cond191: max backedge-taken count is 0
-; CHECK: Loop %while.cond191: Unpredictable predicated backedge-taken count.
+; CHECK: Loop %while.cond191: Predicated backedge-taken count is 0
 ; CHECK: Loop %while.cond191.outer: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK: Loop %while.cond191.outer: Unpredictable max backedge-taken count.
 ; CHECK: Loop %while.cond191.outer: Unpredictable predicated backedge-taken count.