[SCEV][NFC] Refactor range computation for AddRec to pass around APInt
authorJoshua Cao <cao.joshua@yahoo.com>
Wed, 31 May 2023 03:40:10 +0000 (20:40 -0700)
committerJoshua Cao <cao.joshua@yahoo.com>
Thu, 1 Jun 2023 04:03:20 +0000 (21:03 -0700)
llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp

index 58f821e..2db7126 100644 (file)
@@ -1674,7 +1674,7 @@ private:
   /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}.
   /// Helper for \c getRange.
   ConstantRange getRangeForAffineAR(const SCEV *Start, const SCEV *Step,
-                                    const SCEV *MaxBECount, unsigned BitWidth);
+                                    const APInt &MaxBECount);
 
   /// Determines the range for the affine non-self-wrapping SCEVAddRecExpr {\p
   /// Start,+,\p Step}<nw>.
@@ -1687,7 +1687,7 @@ private:
   /// Step} by "factoring out" a ternary expression from the add recurrence.
   /// Helper called by \c getRange.
   ConstantRange getRangeViaFactoring(const SCEV *Start, const SCEV *Step,
-                                     const SCEV *MaxBECount, unsigned BitWidth);
+                                     const APInt &MaxBECount);
 
   /// If the unknown expression U corresponds to a simple recurrence, return
   /// a constant range which represents the entire recurrence.  Note that
index db8ac4f..d59d0bf 100644 (file)
@@ -6699,21 +6699,23 @@ const ConstantRange &ScalarEvolution::getRangeRef(
 
     // TODO: non-affine addrec
     if (AddRec->isAffine()) {
-      const SCEV *MaxBECount =
+      const SCEV *MaxBEScev =
           getConstantMaxBackedgeTakenCount(AddRec->getLoop());
-      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
-          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
-        auto RangeFromAffine = getRangeForAffineAR(
-            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
-            BitWidth);
-        ConservativeResult =
-            ConservativeResult.intersectWith(RangeFromAffine, RangeType);
+      if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
+        APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
+        if (MaxBECount.getBitWidth() < BitWidth)
+          MaxBECount = MaxBECount.zext(BitWidth);
+        if (MaxBECount.getBitWidth() == BitWidth) {
+          auto RangeFromAffine = getRangeForAffineAR(
+              AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
+          ConservativeResult =
+              ConservativeResult.intersectWith(RangeFromAffine, RangeType);
 
-        auto RangeFromFactoring = getRangeViaFactoring(
-            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
-            BitWidth);
-        ConservativeResult =
-            ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
+          auto RangeFromFactoring = getRangeViaFactoring(
+              AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
+          ConservativeResult =
+              ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
+        }
       }
 
       // Now try symbolic BE count and more powerful methods.
@@ -6721,7 +6723,7 @@ const ConstantRange &ScalarEvolution::getRangeRef(
         const SCEV *SymbolicMaxBECount =
             getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
         if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
-            getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
+            getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
             AddRec->hasNoSelfWrap()) {
           auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
               AddRec, SymbolicMaxBECount, BitWidth, SignHint);
@@ -6885,7 +6887,10 @@ const ConstantRange &ScalarEvolution::getRangeRef(
 static ConstantRange getRangeForAffineARHelper(APInt Step,
                                                const ConstantRange &StartRange,
                                                const APInt &MaxBECount,
-                                               unsigned BitWidth, bool Signed) {
+                                               bool Signed) {
+  unsigned BitWidth = Step.getBitWidth();
+  assert(BitWidth == StartRange.getBitWidth() &&
+         BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
   // If either Step or MaxBECount is 0, then the expression won't change, and we
   // just need to return the initial range.
   if (Step == 0 || MaxBECount == 0)
@@ -6944,14 +6949,11 @@ static ConstantRange getRangeForAffineARHelper(APInt Step,
 
 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
                                                    const SCEV *Step,
-                                                   const SCEV *MaxBECount,
-                                                   unsigned BitWidth) {
-  assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
-         getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
-         "Precondition!");
-
-  MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
-  APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
+                                                   const APInt &MaxBECount) {
+  assert(getTypeSizeInBits(Start->getType()) ==
+             getTypeSizeInBits(Step->getType()) &&
+         getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
+         "mismatched bit widths");
 
   // First, consider step signed.
   ConstantRange StartSRange = getSignedRange(Start);
@@ -6959,17 +6961,16 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
 
   // If Step can be both positive and negative, we need to find ranges for the
   // maximum absolute step values in both directions and union them.
-  ConstantRange SR =
-      getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
-                                MaxBECountValue, BitWidth, /* Signed = */ true);
+  ConstantRange SR = getRangeForAffineARHelper(
+      StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
   SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
-                                              StartSRange, MaxBECountValue,
-                                              BitWidth, /* Signed = */ true));
+                                              StartSRange, MaxBECount,
+                                              /* Signed = */ true));
 
   // Next, consider step unsigned.
   ConstantRange UR = getRangeForAffineARHelper(
-      getUnsignedRangeMax(Step), getUnsignedRange(Start),
-      MaxBECountValue, BitWidth, /* Signed = */ false);
+      getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
+      /* Signed = */ false);
 
   // Finally, intersect signed and unsigned ranges.
   return SR.intersectWith(UR, ConstantRange::Smallest);
@@ -7045,11 +7046,15 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
 
 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
                                                     const SCEV *Step,
-                                                    const SCEV *MaxBECount,
-                                                    unsigned BitWidth) {
+                                                    const APInt &MaxBECount) {
   //    RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
   // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
 
+  unsigned BitWidth = MaxBECount.getBitWidth();
+  assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
+         getTypeSizeInBits(Step->getType()) == BitWidth &&
+         "mismatched bit widths");
+
   struct SelectPattern {
     Value *Condition = nullptr;
     APInt TrueValue;
@@ -7151,9 +7156,9 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
   const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
 
   ConstantRange TrueRange =
-      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
+      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
   ConstantRange FalseRange =
-      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
+      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
 
   return TrueRange.unionWith(FalseRange);
 }