[SCEV] Common code for computing trip count in a fixed type [NFC-ish]
authorPhilip Reames <preames@rivosinc.com>
Tue, 25 Apr 2023 18:57:46 +0000 (11:57 -0700)
committerPhilip Reames <listmail@philipreames.com>
Tue, 25 Apr 2023 19:04:42 +0000 (12:04 -0700)
This is a follow on to D147117 and D147355. In both cases, we were adding special cases to compute zext(BTC+1) instead of zext(BTC)+1 when the BTC+1 computation was known not to overflow.

Differential Revision: https://reviews.llvm.org/D148661

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Transforms/Scalar/LoopFlatten.cpp
llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

index f27cf22..0f281d0 100644 (file)
@@ -791,16 +791,19 @@ public:
   bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred,
                                    const SCEV *LHS, const SCEV *RHS);
 
+  /// A version of getTripCountFromExitCount below which always picks an
+  /// evaluation type which can not result in overflow.
+  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount);
+
   /// Convert from an "exit count" (i.e. "backedge taken count") to a "trip
   /// count".  A "trip count" is the number of times the header of the loop
   /// will execute if an exit is taken after the specified number of backedges
   /// have been taken.  (e.g. TripCount = ExitCount + 1).  Note that the
-  /// expression can overflow if ExitCount = UINT_MAX.  \p Extend controls
-  /// how potential overflow is handled.  If true, a wider result type is
-  /// returned. ex: EC = 255 (i8), TC = 256 (i9).  If false, result unsigned
-  /// wraps with 2s-complement semantics.  ex: EC = 255 (i8), TC = 0 (i8)
-  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount,
-                                        bool Extend = true);
+  /// expression can overflow if ExitCount = UINT_MAX.  If EvalTy is not wide
+  /// enough to hold the result without overflow, result unsigned wraps with
+  /// 2s-complement semantics.  ex: EC = 255 (i8), TC = 0 (i8)
+  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount, Type *EvalTy,
+                                        const Loop *L);
 
   /// Returns the exact trip count of the loop if we can compute it, and
   /// the result is a small constant.  '0' is used to represent an unknown
index 6a3e91a..15bb954 100644 (file)
@@ -8037,27 +8037,45 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 //                   Iteration Count Computation Code
 //
 
-const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
-                                                       bool Extend) {
+const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
   if (isa<SCEVCouldNotCompute>(ExitCount))
     return getCouldNotCompute();
 
   auto *ExitCountType = ExitCount->getType();
   assert(ExitCountType->isIntegerTy());
+  auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
+                                 1 + ExitCountType->getScalarSizeInBits());
+  return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
+}
+
+const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
+                                                       Type *EvalTy,
+                                                       const Loop *L) {
+  if (isa<SCEVCouldNotCompute>(ExitCount))
+    return getCouldNotCompute();
 
-  if (!Extend)
-    return getAddExpr(ExitCount, getOne(ExitCountType));
+  unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
+  unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
 
-  ConstantRange ExitCountRange =
+  auto CanAddOneWithoutOverflow = [&]() {
+    ConstantRange ExitCountRange =
       getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
-  if (!ExitCountRange.contains(
-          APInt::getMaxValue(ExitCountRange.getBitWidth())))
-    return getAddExpr(ExitCount, getOne(ExitCountType));
-
-  auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
-                                    1 + ExitCountType->getScalarSizeInBits());
-  return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
-                    getOne(WiderType));
+    if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
+      return true;
+
+    return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
+                                         getMinusOne(ExitCount->getType()));
+  };
+
+  // If we need to zero extend the backedge count, check if we can add one to
+  // it prior to zero extending without overflow. Provided this is safe, it
+  // allows better simplification of the +1.
+  if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
+    return getZeroExtendExpr(
+        getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
+
+  // Get the total trip count from the count by adding 1.  This may wrap.
+  return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
 }
 
 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
index 591f30c..edc8a49 100644 (file)
@@ -315,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L,
     return false;
   }
 
-  // The Extend=false flag is used for getTripCountFromExitCount as we want
-  // to verify and match it with the pattern matched tripcount. Please note
-  // that overflow checks are performed in checkOverflow, but are first tried
-  // to avoid by widening the IV.
+  // Evaluating in the trip count's type can not overflow here as the overflow
+  // checks are performed in checkOverflow, but are first tried to avoid by
+  // widening the IV.
   const SCEV *SCEVTripCount =
-      SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
+    SE->getTripCountFromExitCount(BackedgeTakenCount,
+                                  BackedgeTakenCount->getType(), L);
 
   const SCEV *SCEVRHS = SE->getSCEV(RHS);
   if (SCEVRHS == SCEVTripCount)
@@ -333,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
       // Find the extended backedge taken count and extended trip count using
       // SCEV. One of these should now match the RHS of the compare.
       BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
-      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
+      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt,
+                                                       RHS->getType(), L);
       if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
         LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
         return false;
index 2c999d7..bb0099e 100644 (file)
@@ -983,33 +983,6 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
   return SE->getMinusSCEV(Start, Index);
 }
 
-/// Compute trip count from the backedge taken count.
-static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
-                                Loop *CurLoop, const DataLayout *DL,
-                                ScalarEvolution *SE) {
-  const SCEV *TripCountS = nullptr;
-  // The # stored bytes is (BECount+1).  Expand the trip count out to
-  // pointer size if it isn't already.
-  //
-  // If we're going to need to zero extend the BE count, check if we can add
-  // one to it prior to zero extending without overflow. Provided this is safe,
-  // it allows better simplification of the +1.
-  if (DL->getTypeSizeInBits(BECount->getType()) <
-          DL->getTypeSizeInBits(IntPtr) &&
-      SE->isLoopEntryGuardedByCond(
-          CurLoop, ICmpInst::ICMP_NE, BECount,
-          SE->getMinusOne(BECount->getType()))) {
-    TripCountS = SE->getZeroExtendExpr(
-        SE->getAddExpr(BECount, SE->getOne(BECount->getType())),
-        IntPtr);
-  } else {
-    TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
-                                SE->getOne(IntPtr));
-  }
-
-  return TripCountS;
-}
-
 /// Compute the number of bytes as a SCEV from the backedge taken count.
 ///
 /// This also maps the SCEV into the provided type and tries to handle the
@@ -1017,8 +990,8 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
 static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
                                const SCEV *StoreSizeSCEV, Loop *CurLoop,
                                const DataLayout *DL, ScalarEvolution *SE) {
-  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
-
+  const SCEV *TripCountSCEV =
+      SE->getTripCountFromExitCount(BECount, IntPtr, CurLoop);
   return SE->getMulExpr(TripCountSCEV,
                         SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
                         SCEV::FlagNUW);
index 435cb9a..645b62d 100644 (file)
@@ -987,35 +987,7 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
   assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
 
   ScalarEvolution &SE = *PSE.getSE();
-
-  unsigned BackEdgeSize = SE.getTypeSizeInBits(BackedgeTakenCount->getType());
-  unsigned IdxSize = IdxTy->getPrimitiveSizeInBits();
-
-  // If we need to need to zero extend the backedge count, check if we can
-  // add one to it prior to zero extending without overflow. Provided this is
-  // safe, it allows better simplification of the +1.
-  if (OrigLoop && BackEdgeSize < IdxSize &&
-      SE.isLoopEntryGuardedByCond(
-          OrigLoop, ICmpInst::ICMP_NE, BackedgeTakenCount,
-          SE.getMinusOne(BackedgeTakenCount->getType()))) {
-    return SE.getZeroExtendExpr(
-        SE.getAddExpr(BackedgeTakenCount,
-                      SE.getOne(BackedgeTakenCount->getType())),
-        IdxTy);
-  }
-
-  // The exit count might have the type of i64 while the phi is i32. This can
-  // happen if we have an induction variable that is sign extended before the
-  // compare. The only way that we get a backedge taken count is that the
-  // induction variable was signed and as such will not overflow. In such a case
-  // truncation is legal.
-  if (BackEdgeSize > IdxSize)
-    BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy);
-  BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy);
-
-  // Get the total trip count from the count by adding 1.
-  return SE.getAddExpr(BackedgeTakenCount,
-                       SE.getOne(BackedgeTakenCount->getType()));
+  return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
 }
 
 static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy,