[LoopUnroll] Keep the loop test only on the first iteration of max-or-zero loops
authorJohn Brawn <john.brawn@arm.com>
Fri, 21 Oct 2016 11:08:48 +0000 (11:08 +0000)
committerJohn Brawn <john.brawn@arm.com>
Fri, 21 Oct 2016 11:08:48 +0000 (11:08 +0000)
When we have a loop with a known upper bound on the number of iterations, and
furthermore know that either the number of iterations will be either exactly
that upper bound or zero, then we can fully unroll up to that upper bound
keeping only the first loop test to check for the zero iteration case.

Most of the work here is in plumbing this 'max-or-zero' information from the
part of scalar evolution where it's detected through to loop unrolling. I've
also gone for the safe default of 'false' everywhere but howManyLessThans which
could probably be improved.

Differential Revision: https://reviews.llvm.org/D25682

llvm-svn: 284818

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/include/llvm/Transforms/Utils/UnrollLoop.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
llvm/lib/Transforms/Utils/LoopUnroll.cpp
llvm/test/Analysis/ScalarEvolution/trip-count13.ll
llvm/test/Analysis/ScalarEvolution/trip-count14.ll
llvm/test/Transforms/LoopUnroll/full-unroll-keep-first-exit.ll [new file with mode: 0644]

index 7a66d95..7284f3e 100644 (file)
@@ -548,8 +548,9 @@ private:
   /// pair of exact and max expressions that are eventually summarized in
   /// ExitNotTakenInfo and BackedgeTakenInfo.
   struct ExitLimit {
-    const SCEV *ExactNotTaken;
-    const SCEV *MaxNotTaken;
+    const SCEV *ExactNotTaken; //< The exit is not taken exactly this many times
+    const SCEV *MaxNotTaken; //< The exit is not taken at most this many times
+    bool MaxOrZero; //< Not taken either exactly MaxNotTaken or zero times
 
     /// A set of predicate guards for this ExitLimit. The result is only valid
     /// if all of the predicates in \c Predicates evaluate to 'true' at
@@ -561,12 +562,13 @@ private:
       Predicates.insert(P);
     }
 
-    /*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {}
+    /*implicit*/ ExitLimit(const SCEV *E)
+        : ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) {}
 
     ExitLimit(
-        const SCEV *E, const SCEV *M,
+        const SCEV *E, const SCEV *M, bool MaxOrZero,
         ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
-        : ExactNotTaken(E), MaxNotTaken(M) {
+        : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
       assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
               !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
              "Exact is not allowed to be less precise than Max");
@@ -575,11 +577,12 @@ private:
           addPredicate(P);
     }
 
-    ExitLimit(const SCEV *E, const SCEV *M,
+    ExitLimit(const SCEV *E, const SCEV *M, bool MaxOrZero,
               const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
-        : ExitLimit(E, M, {&PredSet}) {}
+        : ExitLimit(E, M, MaxOrZero, {&PredSet}) {}
 
-    ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {}
+    ExitLimit(const SCEV *E, const SCEV *M, bool MaxOrZero)
+        : ExitLimit(E, M, MaxOrZero, None) {}
 
     /// Test whether this ExitLimit contains any computed information, or
     /// whether it's all SCEVCouldNotCompute values.
@@ -628,6 +631,9 @@ private:
     /// ExitNotTaken has an element for every exiting block in the loop.
     PointerIntPair<const SCEV *, 1> MaxAndComplete;
 
+    /// True iff the backedge is taken either exactly Max or zero times.
+    bool MaxOrZero;
+
     /// \name Helper projection functions on \c MaxAndComplete.
     /// @{
     bool isComplete() const { return MaxAndComplete.getInt(); }
@@ -644,7 +650,7 @@ private:
 
     /// Initialize BackedgeTakenInfo from a list of exact exit counts.
     BackedgeTakenInfo(SmallVectorImpl<EdgeExitInfo> &&ExitCounts, bool Complete,
-                      const SCEV *MaxCount);
+                      const SCEV *MaxCount, bool MaxOrZero);
 
     /// Test whether this BackedgeTakenInfo contains any computed information,
     /// or whether it's all SCEVCouldNotCompute values.
@@ -683,6 +689,10 @@ private:
     /// Get the max backedge taken count for the loop.
     const SCEV *getMax(ScalarEvolution *SE) const;
 
+    /// Return true if the number of times this backedge is taken is either the
+    /// value returned by getMax or zero.
+    bool isMaxOrZero(ScalarEvolution *SE) const;
+
     /// Return true if any backedge taken count expressions refer to the given
     /// subexpression.
     bool hasOperand(const SCEV *S, ScalarEvolution *SE) const;
@@ -1354,6 +1364,10 @@ public:
   /// that is known never to be less than the actual backedge taken count.
   const SCEV *getMaxBackedgeTakenCount(const Loop *L);
 
+  /// Return true if the backedge taken count is either the value returned by
+  /// getMaxBackedgeTakenCount or zero.
+  bool isBackedgeTakenCountMaxOrZero(const Loop *L);
+
   /// Return true if the specified loop has an analyzable loop-invariant
   /// backedge-taken count.
   bool hasLoopInvariantBackedgeTakenCount(const Loop *L);
index 74dc2b7..6347908 100644 (file)
@@ -32,8 +32,9 @@ class ScalarEvolution;
 
 bool UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
                 bool AllowRuntime, bool AllowExpensiveTripCount,
-                bool UseUpperBound, unsigned TripMultiple, LoopInfo *LI,
-                ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
+                bool PreserveCondBr, bool PreserveOnlyFirst,
+                unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE,
+                DominatorTree *DT, AssumptionCache *AC,
                 OptimizationRemarkEmitter *ORE, bool PreserveLCSSA);
 
 bool UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
index 8c6ddff..f03051d 100644 (file)
@@ -5424,6 +5424,10 @@ const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
   return getBackedgeTakenInfo(L).getMax(this);
 }
 
+bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
+  return getBackedgeTakenInfo(L).isMaxOrZero(this);
+}
+
 /// Push PHI nodes in the header of the given loop onto the given Worklist.
 static void
 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
@@ -5656,6 +5660,13 @@ ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
   return getMax();
 }
 
+bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
+  auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
+    return !ENT.hasAlwaysTruePredicate();
+  };
+  return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
+}
+
 bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
                                                     ScalarEvolution *SE) const {
   if (getMax() && getMax() != SE->getCouldNotCompute() &&
@@ -5675,8 +5686,8 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
     SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
         &&ExitCounts,
-    bool Complete, const SCEV *MaxCount)
-    : MaxAndComplete(MaxCount, Complete) {
+    bool Complete, const SCEV *MaxCount, bool MaxOrZero)
+    : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
   typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
   ExitNotTaken.reserve(ExitCounts.size());
   std::transform(
@@ -5714,6 +5725,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
   const SCEV *MustExitMaxBECount = nullptr;
   const SCEV *MayExitMaxBECount = nullptr;
+  bool MustExitMaxOrZero = false;
 
   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
   // and compute maxBECount.
@@ -5746,9 +5758,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
     // computable EL.MaxNotTaken.
     if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
         DT.dominates(ExitBB, Latch)) {
-      if (!MustExitMaxBECount)
+      if (!MustExitMaxBECount) {
         MustExitMaxBECount = EL.MaxNotTaken;
-      else {
+        MustExitMaxOrZero = EL.MaxOrZero;
+      } else {
         MustExitMaxBECount =
             getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
       }
@@ -5763,8 +5776,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   }
   const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
     (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
+  // The loop backedge will be taken the maximum or zero times if there's
+  // a single exit that must be taken the maximum or zero times.
+  bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
   return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
-                           MaxBECount);
+                           MaxBECount, MaxOrZero);
 }
 
 ScalarEvolution::ExitLimit
@@ -5901,7 +5917,8 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
           !isa<SCEVCouldNotCompute>(BECount))
         MaxBECount = BECount;
 
-      return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
+      return ExitLimit(BECount, MaxBECount, false,
+                       {&EL0.Predicates, &EL1.Predicates});
     }
     if (BO->getOpcode() == Instruction::Or) {
       // Recurse on the operands of the or.
@@ -5940,7 +5957,8 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,
           BECount = EL0.ExactNotTaken;
       }
 
-      return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
+      return ExitLimit(BECount, MaxBECount, false,
+                       {&EL0.Predicates, &EL1.Predicates});
     }
   }
 
@@ -6325,7 +6343,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
     unsigned BitWidth = getTypeSizeInBits(RHS->getType());
     const SCEV *UpperBound =
         getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
-    return ExitLimit(getCouldNotCompute(), UpperBound);
+    return ExitLimit(getCouldNotCompute(), UpperBound, false);
   }
 
   return getCouldNotCompute();
@@ -7121,7 +7139,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
         // should not accept a root of 2.
         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
         if (Val->isZero())
-          return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
+          // We found a quadratic root!
+          return ExitLimit(R1, R1, false, Predicates);
       }
     }
     return getCouldNotCompute();
@@ -7178,7 +7197,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
     else
       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
                                          : -CR.getUnsignedMin());
-    return ExitLimit(Distance, MaxBECount, Predicates);
+    return ExitLimit(Distance, MaxBECount, false, Predicates);
   }
 
   // As a special case, handle the instance where Step is a positive power of
@@ -7233,7 +7252,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
 
       const SCEV *Limit =
           getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
-      return ExitLimit(Limit, Limit, Predicates);
+      return ExitLimit(Limit, Limit, false, Predicates);
     }
   }
 
@@ -7246,14 +7265,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
       loopHasNoAbnormalExits(AddRec->getLoop())) {
     const SCEV *Exact =
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
-    return ExitLimit(Exact, Exact, Predicates);
+    return ExitLimit(Exact, Exact, false, Predicates);
   }
 
   // Then, try to solve the above equation provided that Start is constant.
   if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
     const SCEV *E = SolveLinEquationWithOverflow(
         StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
-    return ExitLimit(E, E, Predicates);
+    return ExitLimit(E, E, false, Predicates);
   }
   return getCouldNotCompute();
 }
@@ -8695,14 +8714,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   }
 
   const SCEV *MaxBECount;
+  bool MaxOrZero = false;
   if (isa<SCEVConstant>(BECount))
     MaxBECount = BECount;
-  else if (isa<SCEVConstant>(BECountIfBackedgeTaken))
+  else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
     // If we know exactly how many times the backedge will be taken if it's
     // taken at least once, then the backedge count will either be that or
     // zero.
     MaxBECount = BECountIfBackedgeTaken;
-  else {
+    MaxOrZero = true;
+  } else {
     // Calculate the maximum backedge count based on the range of values
     // permitted by Start, End, and Stride.
     APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin()
@@ -8739,7 +8760,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount, Predicates);
+  return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
 }
 
 ScalarEvolution::ExitLimit
@@ -8816,7 +8837,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount, Predicates);
+  return ExitLimit(BECount, MaxBECount, false, Predicates);
 }
 
 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@@ -9598,6 +9619,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
 
   if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
     OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
+    if (SE->isBackedgeTakenCountMaxOrZero(L))
+      OS << ", actual taken count either this or zero.";
   } else {
     OS << "Unpredictable max backedge-taken count. ";
   }
index e4b181b..b81cf84 100644 (file)
@@ -1000,14 +1000,22 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
   if (Convergent)
     UP.AllowRemainder = false;
 
-  // Try to find the trip count upper bound if it is allowed and we cannot find
-  // exact trip count.
-  if (UP.UpperBound) {
-    if (!TripCount) {
-      MaxTripCount = SE->getSmallConstantMaxTripCount(L);
-      // Only unroll with small upper bound.
-      if (MaxTripCount > UnrollMaxUpperBound)
-        MaxTripCount = 0;
+  // Try to find the trip count upper bound if we cannot find the exact trip
+  // count.
+  bool MaxOrZero = false;
+  if (!TripCount) {
+    MaxTripCount = SE->getSmallConstantMaxTripCount(L);
+    MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
+    // We can unroll by the upper bound amount if it's generally allowed or if
+    // we know that the loop is executed either the upper bound or zero times.
+    // (MaxOrZero unrolling keeps only the first loop test, so the number of
+    // loop tests remains the same compared to the non-unrolled version, whereas
+    // the generic upper bound unrolling keeps all but the last loop test so the
+    // number of loop tests goes up which may end up being worse on targets with
+    // constriained branch predictor resources so is controlled by an option.)
+    // In addition we only unroll small upper bounds.
+    if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) {
+      MaxTripCount = 0;
     }
   }
 
@@ -1025,8 +1033,8 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
 
   // Unroll the loop.
   if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime,
-                  UP.AllowExpensiveTripCount, UseUpperBound, TripMultiple, LI,
-                  SE, &DT, &AC, &ORE, PreserveLCSSA))
+                  UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero,
+                  TripMultiple, LI, SE, &DT, &AC, &ORE, PreserveLCSSA))
     return false;
 
   // If loop has an unroll count pragma or unrolled by explicitly set count
index 8472247..912f9e0 100644 (file)
@@ -189,7 +189,8 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks,
 ///
 /// PreserveCondBr indicates whether the conditional branch of the LatchBlock
 /// needs to be preserved.  It is needed when we use trip count upper bound to
-/// fully unroll the loop.
+/// fully unroll the loop. If PreserveOnlyFirst is also set then only the first
+/// conditional branch needs to be preserved.
 ///
 /// Similarly, TripMultiple divides the number of times that the LatchBlock may
 /// execute without exiting the loop.
@@ -207,10 +208,10 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks,
 /// DominatorTree if they are non-null.
 bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
                       bool AllowRuntime, bool AllowExpensiveTripCount,
-                      bool PreserveCondBr, unsigned TripMultiple, LoopInfo *LI,
-                      ScalarEvolution *SE, DominatorTree *DT,
-                      AssumptionCache *AC, OptimizationRemarkEmitter *ORE,
-                      bool PreserveLCSSA) {
+                      bool PreserveCondBr, bool PreserveOnlyFirst,
+                      unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE,
+                      DominatorTree *DT, AssumptionCache *AC,
+                      OptimizationRemarkEmitter *ORE, bool PreserveLCSSA) {
   BasicBlock *Preheader = L->getLoopPreheader();
   if (!Preheader) {
     DEBUG(dbgs() << "  Can't unroll; loop preheader-insertion failed.\n");
@@ -550,7 +551,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force,
       assert(NeedConditional &&
              "NeedCondition cannot be modified by both complete "
              "unrolling and runtime unrolling");
-      NeedConditional = (PreserveCondBr && j);
+      NeedConditional = (PreserveCondBr && j && !(PreserveOnlyFirst && i != 0));
     } else if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
       // If we know the trip count or a multiple of it, we can safely use an
       // unconditional branch for some iterations.
index dfad4f5..3e10097 100644 (file)
@@ -14,7 +14,7 @@ loop:
 
 ; CHECK-LABEL: Determining loop execution counts for: @u_0
 ; CHECK-NEXT: Loop %loop: backedge-taken count is (-100 + (-1 * %rhs) + ((100 + %rhs) umax %rhs))
-; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
+; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.
 
 leave:
   ret void
@@ -34,7 +34,7 @@ loop:
 
 ; CHECK-LABEL: Determining loop execution counts for: @u_1
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-1 * %start) + ((-100 + %start) umax %start))
-; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
+; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.
 
 leave:
   ret void
@@ -54,7 +54,7 @@ loop:
 
 ; CHECK-LABEL: Determining loop execution counts for: @s_0
 ; CHECK-NEXT: Loop %loop: backedge-taken count is (-100 + (-1 * %rhs) + ((100 + %rhs) smax %rhs))
-; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
+; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.
 
 leave:
   ret void
@@ -74,7 +74,7 @@ loop:
 
 ; CHECK-LABEL: Determining loop execution counts for: @s_1
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-1 * %start) + ((-100 + %start) smax %start))
-; CHECK-NEXT: Loop %loop: max backedge-taken count is -100
+; CHECK-NEXT: Loop %loop: max backedge-taken count is -100, actual taken count either this or zero.
 
 leave:
   ret void
index b38c544..0f935d7 100644 (file)
@@ -15,7 +15,7 @@ do.body:
 
 ; CHECK-LABEL: Determining loop execution counts for: @s32_max1
 ; CHECK-NEXT: Loop %do.body: backedge-taken count is ((-1 * %n) + ((1 + %n) smax %n))
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is 1
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is 1, actual taken count either this or zero.
 
 do.end:
   ret void
@@ -36,7 +36,7 @@ do.body:
 
 ; CHECK-LABEL: Determining loop execution counts for: @s32_max2
 ; CHECK-NEXT: Loop %do.body: backedge-taken count is ((-1 * %n) + ((2 + %n) smax %n))
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2, actual taken count either this or zero.
 
 do.end:
   ret void
@@ -57,7 +57,7 @@ do.body:
 
 ; CHECK-LABEL: Determining loop execution counts for: @s32_maxx
 ; CHECK-NEXT: Loop %do.body: backedge-taken count is ((-1 * %n) + ((%n + %x) smax %n))
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is -1
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is -1{{$}}
 
 do.end:
   ret void
@@ -82,7 +82,7 @@ if.end:
 
 ; CHECK-LABEL: Determining loop execution counts for: @s32_max2_unpredictable_exit
 ; CHECK-NEXT: Loop %do.body: <multiple exits> Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2{{$}}
 
 do.end:
   ret void
@@ -103,7 +103,7 @@ do.body:
 
 ; CHECK-LABEL: Determining loop execution counts for: @u32_max1
 ; CHECK-NEXT: Loop %do.body: backedge-taken count is ((-1 * %n) + ((1 + %n) umax %n))
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is 1
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is 1, actual taken count either this or zero.
 
 do.end:
   ret void
@@ -124,7 +124,7 @@ do.body:
 
 ; CHECK-LABEL: Determining loop execution counts for: @u32_max2
 ; CHECK-NEXT: Loop %do.body: backedge-taken count is ((-1 * %n) + ((2 + %n) umax %n))
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2, actual taken count either this or zero.
 
 do.end:
   ret void
@@ -145,7 +145,7 @@ do.body:
 
 ; CHECK-LABEL: Determining loop execution counts for: @u32_maxx
 ; CHECK-NEXT: Loop %do.body: backedge-taken count is ((-1 * %n) + ((%n + %x) umax %n))
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is -1
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is -1{{$}}
 
 do.end:
   ret void
@@ -170,7 +170,7 @@ if.end:
 
 ; CHECK-LABEL: Determining loop execution counts for: @u32_max2_unpredictable_exit
 ; CHECK-NEXT: Loop %do.body: <multiple exits> Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2
+; CHECK-NEXT: Loop %do.body: max backedge-taken count is 2{{$}}
 
 do.end:
   ret void
diff --git a/llvm/test/Transforms/LoopUnroll/full-unroll-keep-first-exit.ll b/llvm/test/Transforms/LoopUnroll/full-unroll-keep-first-exit.ll
new file mode 100644 (file)
index 0000000..e70ff41
--- /dev/null
@@ -0,0 +1,207 @@
+; RUN: opt -S -loop-unroll < %s | FileCheck %s
+
+; Unroll twice, with first loop exit kept
+; CHECK-LABEL: @s32_max1
+; CHECK: do.body:
+; CHECK:  store
+; CHECK:  br i1 %cmp, label %do.body.1, label %do.end
+; CHECK: do.end:
+; CHECK:  ret void
+; CHECK: do.body.1:
+; CHECK:  store
+; CHECK:  br label %do.end
+define void @s32_max1(i32 %n, i32* %p) {
+entry:
+  %add = add i32 %n, 1
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %do.body ]
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp = icmp slt i32 %i.0, %add
+  br i1 %cmp, label %do.body, label %do.end ; taken either 0 or 1 times
+
+do.end:
+  ret void
+}
+
+; Unroll thrice, with first loop exit kept
+; CHECK-LABEL: @s32_max2
+; CHECK: do.body:
+; CHECK:  store
+; CHECK:  br i1 %cmp, label %do.body.1, label %do.end
+; CHECK: do.end:
+; CHECK:  ret void
+; CHECK: do.body.1:
+; CHECK:  store
+; CHECK:  store
+; CHECK:  br label %do.end
+define void @s32_max2(i32 %n, i32* %p) {
+entry:
+  %add = add i32 %n, 2
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %do.body ]
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp = icmp slt i32 %i.0, %add
+  br i1 %cmp, label %do.body, label %do.end ; taken either 0 or 2 times
+
+do.end:
+  ret void
+}
+
+; Should not be unrolled
+; CHECK-LABEL: @s32_maxx
+; CHECK: do.body:
+; CHECK: do.end:
+; CHECK-NOT: do.body.1:
+define void @s32_maxx(i32 %n, i32 %x, i32* %p) {
+entry:
+  %add = add i32 %x, %n
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %do.body ]
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp = icmp slt i32 %i.0, %add
+  br i1 %cmp, label %do.body, label %do.end ; taken either 0 or x times
+
+do.end:
+  ret void
+}
+
+; Should not be unrolled
+; CHECK-LABEL: @s32_max2_unpredictable_exit
+; CHECK: do.body:
+; CHECK: do.end:
+; CHECK-NOT: do.body.1:
+define void @s32_max2_unpredictable_exit(i32 %n, i32 %x, i32* %p) {
+entry:
+  %add = add i32 %n, 2
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %if.end ]
+  %cmp = icmp eq i32 %i.0, %x
+  br i1 %cmp, label %do.end, label %if.end ; unpredictable
+
+if.end:
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp1 = icmp slt i32 %i.0, %add
+  br i1 %cmp1, label %do.body, label %do.end ; taken either 0 or 2 times
+
+do.end:
+  ret void
+}
+
+; Unroll twice, with first loop exit kept
+; CHECK-LABEL: @u32_max1
+; CHECK: do.body:
+; CHECK:  store
+; CHECK:  br i1 %cmp, label %do.body.1, label %do.end
+; CHECK: do.end:
+; CHECK:  ret void
+; CHECK: do.body.1:
+; CHECK:  store
+; CHECK:  br label %do.end
+define void @u32_max1(i32 %n, i32* %p) {
+entry:
+  %add = add i32 %n, 1
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %do.body ]
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp = icmp ult i32 %i.0, %add
+  br i1 %cmp, label %do.body, label %do.end ; taken either 0 or 1 times
+
+do.end:
+  ret void
+}
+
+; Unroll thrice, with first loop exit kept
+; CHECK-LABEL: @u32_max2
+; CHECK: do.body:
+; CHECK:  store
+; CHECK:  br i1 %cmp, label %do.body.1, label %do.end
+; CHECK: do.end:
+; CHECK:  ret void
+; CHECK: do.body.1:
+; CHECK:  store
+; CHECK:  store
+; CHECK:  br label %do.end
+define void @u32_max2(i32 %n, i32* %p) {
+entry:
+  %add = add i32 %n, 2
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %do.body ]
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp = icmp ult i32 %i.0, %add
+  br i1 %cmp, label %do.body, label %do.end ; taken either 0 or 2 times
+
+do.end:
+  ret void
+}
+
+; Should not be unrolled
+; CHECK-LABEL: @u32_maxx
+; CHECK: do.body:
+; CHECK: do.end:
+; CHECK-NOT: do.body.1:
+define void @u32_maxx(i32 %n, i32 %x, i32* %p) {
+entry:
+  %add = add i32 %x, %n
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %do.body ]
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp = icmp ult i32 %i.0, %add
+  br i1 %cmp, label %do.body, label %do.end ; taken either 0 or x times
+
+do.end:
+  ret void
+}
+
+; Should not be unrolled
+; CHECK-LABEL: @u32_max2_unpredictable_exit
+; CHECK: do.body:
+; CHECK: do.end:
+; CHECK-NOT: do.body.1:
+define void @u32_max2_unpredictable_exit(i32 %n, i32 %x, i32* %p) {
+entry:
+  %add = add i32 %n, 2
+  br label %do.body
+
+do.body:
+  %i.0 = phi i32 [ %n, %entry ], [ %inc, %if.end ]
+  %cmp = icmp eq i32 %i.0, %x
+  br i1 %cmp, label %do.end, label %if.end ; unpredictable
+
+if.end:
+  %arrayidx = getelementptr i32, i32* %p, i32 %i.0
+  store i32 %i.0, i32* %arrayidx, align 4
+  %inc = add i32 %i.0, 1
+  %cmp1 = icmp ult i32 %i.0, %add
+  br i1 %cmp1, label %do.body, label %do.end ; taken either 0 or 2 times
+
+do.end:
+  ret void
+}