[SCEV] Generalize MatchBinaryAddToConst to support non-add expressions.
authorFlorian Hahn <flo@fhahn.com>
Thu, 24 Jun 2021 09:00:27 +0000 (10:00 +0100)
committerFlorian Hahn <flo@fhahn.com>
Thu, 24 Jun 2021 11:16:15 +0000 (12:16 +0100)
This patch generalizes MatchBinaryAddToConst to support matching
(A + C1), (A + C2), instead of just matching (A + C1), A.

The existing cases can be handled by treating non-add expressions A as
A + 0.

Reviewed By: mkazantsev

Differential Revision: https://reviews.llvm.org/D104634

llvm/lib/Analysis/ScalarEvolution.cpp

index 8bd8e28..8e13a2c 100644 (file)
@@ -10075,23 +10075,48 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS) {
-  // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
-  // Return Y via OutY.
-  auto MatchBinaryAddToConst =
-      [this](const SCEV *Result, const SCEV *X, APInt &OutY,
-             SCEV::NoWrapFlags ExpectedFlags) {
-    const SCEV *NonConstOp, *ConstOp;
-    SCEV::NoWrapFlags FlagsPresent;
-
-    if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
-        !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+  // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
+  // C1 and C2 are constant integers. If either X or Y are not add expressions,
+  // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
+  // OutC1 and OutC2.
+  auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
+                                      APInt &OutC1, APInt &OutC2,
+                                      SCEV::NoWrapFlags ExpectedFlags) {
+    const SCEV *XNonConstOp, *XConstOp;
+    const SCEV *YNonConstOp, *YConstOp;
+    SCEV::NoWrapFlags XFlagsPresent;
+    SCEV::NoWrapFlags YFlagsPresent;
+
+    if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
+      XConstOp = getZero(X->getType());
+      XNonConstOp = X;
+      XFlagsPresent = ExpectedFlags;
+    }
+    if (!isa<SCEVConstant>(XConstOp) ||
+        (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
       return false;
 
-    OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
-    return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+    if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
+      YConstOp = getZero(Y->getType());
+      YNonConstOp = Y;
+      YFlagsPresent = ExpectedFlags;
+    }
+
+    if (!isa<SCEVConstant>(YConstOp) ||
+        (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
+      return false;
+
+    if (YNonConstOp != XNonConstOp)
+      return false;
+
+    OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
+    OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
+
+    return true;
   };
 
-  APInt C;
+  APInt C1;
+  APInt C2;
 
   switch (Pred) {
   default:
@@ -10101,45 +10126,38 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_SLE:
-    // X s<= (X + C)<nsw> if C >= 0
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+    // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
+    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
       return true;
 
-    // (X + C)<nsw> s<= X if C <= 0
-    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
-        !C.isStrictlyPositive())
-      return true;
     break;
 
   case ICmpInst::ICMP_SGT:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_SLT:
-    // X s< (X + C)<nsw> if C > 0
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
-        C.isStrictlyPositive())
+    // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
+    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
       return true;
 
-    // (X + C)<nsw> s< X if C < 0
-    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
-      return true;
     break;
 
   case ICmpInst::ICMP_UGE:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_ULE:
-    // X u<= (X + C)<nuw> for any C
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
+    // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
+    if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
       return true;
+
     break;
 
   case ICmpInst::ICMP_UGT:
     std::swap(LHS, RHS);
     LLVM_FALLTHROUGH;
   case ICmpInst::ICMP_ULT:
-    // X u< (X + C)<nuw> if C != 0
-    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue())
+    // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
+    if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
       return true;
     break;
   }