bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
- // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
- // Return Y via OutY.
- auto MatchBinaryAddToConst =
- [this](const SCEV *Result, const SCEV *X, APInt &OutY,
- SCEV::NoWrapFlags ExpectedFlags) {
- const SCEV *NonConstOp, *ConstOp;
- SCEV::NoWrapFlags FlagsPresent;
-
- if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
- !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+ // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
+ // C1 and C2 are constant integers. If either X or Y are not add expressions,
+ // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
+ // OutC1 and OutC2.
+ auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
+ APInt &OutC1, APInt &OutC2,
+ SCEV::NoWrapFlags ExpectedFlags) {
+ const SCEV *XNonConstOp, *XConstOp;
+ const SCEV *YNonConstOp, *YConstOp;
+ SCEV::NoWrapFlags XFlagsPresent;
+ SCEV::NoWrapFlags YFlagsPresent;
+
+ if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
+ XConstOp = getZero(X->getType());
+ XNonConstOp = X;
+ XFlagsPresent = ExpectedFlags;
+ }
+ if (!isa<SCEVConstant>(XConstOp) ||
+ (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
return false;
- OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
- return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+ if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
+ YConstOp = getZero(Y->getType());
+ YNonConstOp = Y;
+ YFlagsPresent = ExpectedFlags;
+ }
+
+ if (!isa<SCEVConstant>(YConstOp) ||
+ (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
+ return false;
+
+ if (YNonConstOp != XNonConstOp)
+ return false;
+
+ OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
+ OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
+
+ return true;
};
- APInt C;
+ APInt C1;
+ APInt C2;
switch (Pred) {
default:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLE:
- // X s<= (X + C)<nsw> if C >= 0
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+ // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
+ if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
return true;
- // (X + C)<nsw> s<= X if C <= 0
- if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
- !C.isStrictlyPositive())
- return true;
break;
case ICmpInst::ICMP_SGT:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_SLT:
- // X s< (X + C)<nsw> if C > 0
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
- C.isStrictlyPositive())
+ // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
+ if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
return true;
- // (X + C)<nsw> s< X if C < 0
- if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
- return true;
break;
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULE:
- // X u<= (X + C)<nuw> for any C
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
+ // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
+ if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
return true;
+
break;
case ICmpInst::ICMP_UGT:
std::swap(LHS, RHS);
LLVM_FALLTHROUGH;
case ICmpInst::ICMP_ULT:
- // X u< (X + C)<nuw> if C != 0
- if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue())
+ // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
+ if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
return true;
break;
}