Revert "[SCEV] Preserve divisibility and min/max information in applyLoopGuards"
authorkomalon1 <alon.kom@mobileye.com>
Thu, 23 Feb 2023 12:40:50 +0000 (14:40 +0200)
committerkomalon1 <alon.kom@mobileye.com>
Thu, 23 Feb 2023 12:44:03 +0000 (14:44 +0200)
This reverts commit 219ba2fb7b0ae89101f3c81a47fe4fc4aa80dea4.

llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
llvm/unittests/Analysis/ScalarEvolutionTest.cpp

index 4a5680d..5c2c12c 100644 (file)
@@ -15034,91 +15034,6 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     if (MatchRangeCheckIdiom())
       return;
 
-    // Return true if \p Expr is a MinMax SCEV expression with a constant
-    // operand. If so, return in \p SCTy the SCEV type and in \p RHS the
-    // non-constant operand and in \p LHS the constant operand.
-    auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy,
-                                        const SCEV *&LHS, const SCEV *&RHS) {
-      if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
-        if (MinMax->getNumOperands() != 2)
-          return false;
-        SCTy = MinMax->getSCEVType();
-        if (!isa<SCEVConstant>(MinMax->getOperand(0)))
-          return false;
-        LHS = MinMax->getOperand(0);
-        RHS = MinMax->getOperand(1);
-        return true;
-      }
-      return false;
-    };
-
-    // Checks whether Expr is a non-negative constant, and Divisor is a positive
-    // constant, and returns their APInt in ExprVal and in DivisorVal.
-    auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
-                                          APInt &ExprVal, APInt &DivisorVal) {
-      if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor))
-        return false;
-      auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
-      auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
-      if (!ConstExpr || !ConstDivisor)
-        return false;
-      ExprVal = ConstExpr->getAPInt();
-      DivisorVal = ConstDivisor->getAPInt();
-      return true;
-    };
-
-    // Return a new SCEV that modifies \p Expr to the closest number divides by
-    // \p Divisor and greater or equal than Expr.
-    // For now, only handle constant Expr and Divisor.
-    auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
-                                           const SCEV *Divisor) {
-      APInt ExprVal;
-      APInt DivisorVal;
-      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
-        return Expr;
-      APInt Rem = ExprVal.urem(DivisorVal);
-      if (!Rem.isZero())
-        // return the SCEV: Expr + Divisor - Expr % Divisor
-        return getConstant(ExprVal + DivisorVal - Rem);
-      return Expr;
-    };
-
-    // Return a new SCEV that modifies \p Expr to the closest number divides by
-    // \p Divisor and less or equal than Expr.
-    // For now, only handle constant Expr and Divisor.
-    auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
-                                               const SCEV *Divisor) {
-      APInt ExprVal;
-      APInt DivisorVal;
-      if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
-        return Expr;
-      APInt Rem = ExprVal.urem(DivisorVal);
-      // return the SCEV: Expr - Expr % Divisor
-      return getConstant(ExprVal - Rem);
-    };
-
-    // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
-    // recursively. This is done by aligning up/down the constant value to the
-    // Divisor.
-    std::function<const SCEV *(const SCEV *, const SCEV *)>
-        ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
-                                           const SCEV *Divisor) {
-          const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
-          SCEVTypes SCTy;
-          if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS))
-            return MinMaxExpr;
-          auto IsMin =
-              isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
-          assert(isKnownNonNegative(MinMaxLHS) &&
-                 "Expected non-negative operand!");
-          auto *DivisibleExpr =
-              IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
-                    : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
-          SmallVector<const SCEV *> Ops = {
-              ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
-          return getMinMaxExpr(SCTy, Ops);
-        };
-
     // If we have LHS == 0, check if LHS is computing a property of some unknown
     // SCEV %v which we can rewrite %v to express explicitly.
     const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
@@ -15130,12 +15045,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       const SCEV *URemRHS = nullptr;
       if (matchURem(LHS, URemLHS, URemRHS)) {
         if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
-          auto I = RewriteMap.find(LHSUnknown);
-          const SCEV *RewrittenLHS =
-              I != RewriteMap.end() ? I->second : LHSUnknown;
-          RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
-          const auto *Multiple =
-              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+          const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
           RewriteMap[LHSUnknown] = Multiple;
           ExprsToRewrite.push_back(LHSUnknown);
           return;
@@ -15158,128 +15068,48 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     auto I = RewriteMap.find(LHS);
     const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
 
-    // Check for the SCEV expression (A /u B) * B while B is a constant, inside
-    // \p Expr. The check is done recuresively on \p Expr, which is assumed to
-    // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
-    // /u B) * B was found, and return the divisor B in \p DividesBy. For
-    // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
-    // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
-    // DividesBy.
-    std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
-        [&](const SCEV *Expr, const SCEV *&DividesBy) {
-          if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
-            if (Mul->getNumOperands() != 2)
-              return false;
-            auto *MulLHS = Mul->getOperand(0);
-            auto *MulRHS = Mul->getOperand(1);
-            if (isa<SCEVConstant>(MulLHS))
-              std::swap(MulLHS, MulRHS);
-            if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS)) {
-              if (Div->getOperand(1) == MulRHS) {
-                DividesBy = MulRHS;
-                return true;
-              }
-            }
-          }
-          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
-            return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
-                   HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
-          }
-          return false;
-        };
-
-    // Return true if Expr known to divide by \p DividesBy.
-    std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
-        [&](const SCEV *Expr, const SCEV *DividesBy) {
-          if (getURemExpr(Expr, DividesBy)->isZero())
-            return true;
-          if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
-            return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
-                   IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
-          }
-          return false;
-        };
-
-    const SCEV *DividesBy = nullptr;
-    if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
-      // Check that the whole expression is divided by DividesBy
-      DividesBy =
-          IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
-
     const SCEV *RewrittenRHS = nullptr;
     switch (Predicate) {
     case CmpInst::ICMP_ULT: {
       if (RHS->getType()->isPointerTy())
         break;
       const SCEV *One = getOne(RHS->getType());
-      auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One);
-      ModifiedRHS =
-          DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
-                    : ModifiedRHS;
-      RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
+      RewrittenRHS =
+          getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One));
       break;
     }
-    case CmpInst::ICMP_SLT: {
-      auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType()));
-      ModifiedRHS =
-          DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
-                    : ModifiedRHS;
-      RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_SLT:
+      RewrittenRHS =
+          getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
       break;
-    }
-    case CmpInst::ICMP_ULE: {
-      auto *ModifiedRHS =
-          DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-      RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_ULE:
+      RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
       break;
-    }
-    case CmpInst::ICMP_SLE: {
-      auto *ModifiedRHS =
-          DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-      RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_SLE:
+      RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
       break;
-    }
-    case CmpInst::ICMP_UGT: {
-      auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
-      ModifiedRHS = DividesBy
-                        ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
-                        : ModifiedRHS;
-      RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_UGT:
+      RewrittenRHS =
+          getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
       break;
-    }
-    case CmpInst::ICMP_SGT: {
-      auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
-      ModifiedRHS = DividesBy
-                        ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
-                        : ModifiedRHS;
-      RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_SGT:
+      RewrittenRHS =
+          getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
       break;
-    }
-    case CmpInst::ICMP_UGE: {
-      auto *ModifiedRHS =
-          DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-      RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_UGE:
+      RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
       break;
-    }
-    case CmpInst::ICMP_SGE: {
-      auto *ModifiedRHS =
-          DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
-      RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
+    case CmpInst::ICMP_SGE:
+      RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
       break;
-    }
     case CmpInst::ICMP_EQ:
       if (isa<SCEVConstant>(RHS))
         RewrittenRHS = RHS;
       break;
     case CmpInst::ICMP_NE:
       if (isa<SCEVConstant>(RHS) &&
-          cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
-        auto *ModifiedRHS = getOne(RHS->getType());
-        ModifiedRHS = DividesBy
-                          ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
-                          : ModifiedRHS;
-        RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
-      }
+          cast<SCEVConstant>(RHS)->getValue()->isNullValue())
+        RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
       break;
     default:
       break;
index 492ed9c..cfa91e3 100644 (file)
@@ -125,7 +125,7 @@ define void @test_trip_multiple_4_ugt_5_order_swapped(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 2
 ;
 entry:
   %u = urem i32 %num, 4
@@ -196,7 +196,7 @@ define void @test_trip_multiple_4_sgt_5_order_swapped(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 2
 ;
 entry:
   %u = urem i32 %num, 4
@@ -267,7 +267,7 @@ define void @test_trip_multiple_4_uge_5_order_swapped(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 1
 ;
 entry:
   %u = urem i32 %num, 4
@@ -338,7 +338,7 @@ define void @test_trip_multiple_4_sge_5_order_swapped(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 1
 ;
 entry:
   %u = urem i32 %num, 4
@@ -409,7 +409,7 @@ define void @test_trip_multiple_4_upper_lower_bounds(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 1
 ;
 entry:
   %cmp.1 = icmp uge i32 %num, 5
@@ -446,7 +446,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped1(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 1
 ;
 entry:
   %cmp.1 = icmp uge i32 %num, 5
@@ -483,7 +483,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped2(i32 %num) {
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT:   Predicates:
-; CHECK:       Loop %for.body: Trip multiple is 4
+; CHECK:       Loop %for.body: Trip multiple is 1
 ;
 entry:
   %cmp.1 = icmp uge i32 %num, 5
index d0fec1a..8756e2c 100644 (file)
@@ -1744,42 +1744,4 @@ TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromMultiDemArray) {
   });
 }
 
-TEST_F(ScalarEvolutionsTest, ApplyLoopGuards) {
-  LLVMContext C;
-  SMDiagnostic Err;
-  std::unique_ptr<Module> M = parseAssemblyString(
-      "declare void @llvm.assume(i1)\n"
-      "define void @test(i32 %num) {\n"
-      "entry:\n"
-      "  %u = urem i32 %num, 4\n"
-      "  %cmp = icmp eq i32 %u, 0\n"
-      "  tail call void @llvm.assume(i1 %cmp)\n"
-      "  %cmp.1 = icmp ugt i32 %num, 0\n"
-      "  tail call void @llvm.assume(i1 %cmp.1)\n"
-      "  br label %for.body\n"
-      "for.body:\n"
-      "  %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]\n"
-      "  %inc = add nuw nsw i32 %i.010, 1\n"
-      "  %cmp2 = icmp ult i32 %inc, %num\n"
-      "  br i1 %cmp2, label %for.body, label %exit\n"
-      "exit:\n"
-      "  ret void\n"
-      "}\n",
-      Err, C);
-
-  ASSERT_TRUE(M && "Could not parse module?");
-  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
-
-  runWithSE(*M, "test", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
-    auto *TCScev = SE.getSCEV(getArgByName(F, "num"));
-    auto *ApplyLoopGuardsTC = SE.applyLoopGuards(TCScev, *LI.begin());
-    // Assert that the new TC is (4 * ((4 umax %num) /u 4))
-    APInt Four(32, 4);
-    auto *Constant4 = SE.getConstant(Four);
-    auto *Max = SE.getUMaxExpr(TCScev, Constant4);
-    auto *Mul = SE.getMulExpr(SE.getUDivExpr(Max, Constant4), Constant4);
-    ASSERT_TRUE(Mul == ApplyLoopGuardsTC);
-  });
-}
-
 }  // end namespace llvm