[SCEV] Fix unsound reasoning in howManyLessThans
authorPhilip Reames <listmail@philipreames.com>
Thu, 15 Jul 2021 17:25:06 +0000 (10:25 -0700)
committerPhilip Reames <listmail@philipreames.com>
Thu, 15 Jul 2021 17:32:47 +0000 (10:32 -0700)
This is split from D105216, it handles only a subset of the cases in that patch.

Specifically, the issue being fixed is that the code incorrectly assumed that (Start-Stide) < End implied that the backedge was taken at least once. This is not true when e.g. Start = 4, Stride = 2, and End = 3. Note that we often do produce the right backedge taken count despite the flawed reasoning.

The fix chosen here is to use an alternate form of uceil (ceiling of unsigned divide) lowering which is safe when max(RHS,Start) > Start - Stride.  (Note that signedness of both max expression and comparison depend on the signedness of the comparison being analyzed, and that overflow in the Start - Stride expression is allowed.)  Note that this is weaker than proving the backedge is taken because it allows start - stride < end < start.  Some cases which can't be proven safe are sent down the generic path, and we do end up generating less optimal expressions in a few cases.

Credit for coming up with the approach goes entirely to Eli.  I just split it off, tweaked the comments a bit, and did some additional testing.

Differential Revision: https://reviews.llvm.org/D105942

llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
llvm/test/Transforms/LoopReroll/nonconst_lb.ll

index b486ec4..88858ba 100644 (file)
@@ -11761,17 +11761,42 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   // and End is the RHS.
   const SCEV *BECountIfBackedgeTaken =
     computeBECount(getMinusSCEV(End, Start), Stride);
-  // If the loop entry is guarded by the result of the backedge test of the
-  // first loop iteration, then we know the backedge will be taken at least
-  // once and so the backedge taken count is as above. If not then we use the
-  // expression (max(End,Start)-Start)/Stride to describe the backedge count,
-  // as if the backedge is taken at least once max(End,Start) is End and so the
-  // result is as above, and if not max(End,Start) is Start so we get a backedge
-  // count of zero.
-  const SCEV *BECount;
-  if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(OrigStart, Stride), OrigRHS))
-    BECount = BECountIfBackedgeTaken;
-  else {
+  
+  // We use the expression (max(End,Start)-Start)/Stride to describe the
+  // backedge count, as if the backedge is taken at least once max(End,Start)
+  // is End and so the result is as above, and if not max(End,Start) is Start
+  // so we get a backedge count of zero.
+  const SCEV *BECount = nullptr;
+  auto *StartMinusStride = getMinusSCEV(OrigStart, Stride);
+  // Can we prove (max(RHS,Start) > Start - Stride?
+  if (isLoopEntryGuardedByCond(L, Cond, StartMinusStride, Start) &&
+      isLoopEntryGuardedByCond(L, Cond, StartMinusStride, RHS)) {
+    // In this case, we can use a refined formula for computing backedge taken
+    // count.  The general formula remains:
+    //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
+    // We want to use the alternate formula:
+    //   "((End - 1) - (Start - Stride)) /u Stride"
+    // Let's do a quick case analysis to show these are equivalent under
+    // our precondition that max(RHS,Start) > Start - Stride.
+    // * For RHS <= Start, the backedge-taken count must be zero.
+    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
+    //   "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
+    //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
+    //     of Stride.  For 0 stride, we've use umin(1,Stride) above, reducing
+    //     this to the stride of 1 case.
+    // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
+    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
+    //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
+    //   "((RHS - (Start - Stride) - 1) /u Stride".
+    //   Our preconditions trivially imply no overflow in that form.
+    const SCEV *MinusOne = getMinusOne(Stride->getType());
+    const SCEV *Numerator =
+        getMinusSCEV(getAddExpr(RHS, MinusOne), StartMinusStride);
+    if (!isa<SCEVCouldNotCompute>(Numerator)) {
+      BECount = getUDivExpr(Numerator, Stride);
+    }
+  }
+  if (!BECount) {
     auto canProveRHSGreaterThanEqualStart = [&]() {
       auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
       if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
index cece093..390f974 100644 (file)
@@ -1,7 +1,7 @@
 ; RUN: opt < %s -analyze -enable-new-pm=0 -scalar-evolution 2>&1 | FileCheck %s
 ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 2>&1 | FileCheck %s
 
-; CHECK: Loop %bb: backedge-taken count is ((999 + (-1 * %x)) /u 3)
+; CHECK: Loop %bb: backedge-taken count is ((-1 + (-1 * %x) + (1000 umax (3 + %x))) /u 3)
 ; CHECK: Loop %bb: max backedge-taken count is 334
 
 
index 1af81a3..a085ee9 100644 (file)
@@ -5,7 +5,7 @@
 ; that this is not an infinite loop with side effects.
 
 ; CHECK-LABEL: Determining loop execution counts for: @foo1
-; CHECK: backedge-taken count is ((-1 + %n) /u %s)
+; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s)
 
 ; We should have a conservative estimate for the max backedge taken count for
 ; loops with unknown stride.
index 200a37b..31cc10e 100644 (file)
@@ -17,22 +17,24 @@ define void @foo(i32* nocapture %A, i32* nocapture readonly %B, i32 %m, i32 %n)
 ; CHECK-NEXT:    [[CMP34:%.*]] = icmp slt i32 [[M:%.*]], [[N:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP34]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_END:%.*]]
 ; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], -1
-; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 [[TMP0]], [[M]]
-; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 2
-; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw i32 [[TMP2]], 2
-; CHECK-NEXT:    [[TMP4:%.*]] = add nuw nsw i32 [[TMP3]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[M]], 4
+; CHECK-NEXT:    [[SMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[N]], i32 [[TMP0]])
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[SMAX]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i32 [[TMP1]], [[M]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i32 [[TMP2]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = shl nuw i32 [[TMP3]], 2
+; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i32 [[TMP4]], 3
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[INDVAR:%.*]] = phi i32 [ 0, [[FOR_BODY_PREHEADER]] ], [ [[INDVAR_NEXT:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[M]], [[INDVAR]]
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, i32* [[B:%.*]], i32 [[TMP5]]
-; CHECK-NEXT:    [[TMP6:%.*]] = load i32, i32* [[ARRAYIDX]], align 4
-; CHECK-NEXT:    [[MUL:%.*]] = shl nsw i32 [[TMP6]], 2
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i32 [[TMP5]]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i32 [[M]], [[INDVAR]]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, i32* [[B:%.*]], i32 [[TMP6]]
+; CHECK-NEXT:    [[TMP7:%.*]] = load i32, i32* [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[MUL:%.*]] = shl nsw i32 [[TMP7]], 2
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds i32, i32* [[A:%.*]], i32 [[TMP6]]
 ; CHECK-NEXT:    store i32 [[MUL]], i32* [[ARRAYIDX2]], align 4
 ; CHECK-NEXT:    [[INDVAR_NEXT]] = add i32 [[INDVAR]], 1
-; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i32 [[INDVAR]], [[TMP4]]
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i32 [[INDVAR]], [[TMP5]]
 ; CHECK-NEXT:    br i1 [[EXITCOND]], label [[FOR_END_LOOPEXIT:%.*]], label [[FOR_BODY]]
 ; CHECK:       for.end.loopexit:
 ; CHECK-NEXT:    br label [[FOR_END]]