From cbba71bfb50fb668b80ed430125a460279928272 Mon Sep 17 00:00:00 2001 From: Eli Friedman Date: Fri, 9 Jul 2021 14:10:44 -0700 Subject: [PATCH] [ScalarEvolution] Fix overflow in computeBECount. The current implementation of computeBECount doesn't account for the possibility that adding "Stride - 1" to Delta might overflow. For almost all loops, it doesn't, but it's not actually proven anywhere. To deal with this, use a variety of tricks to try to prove that the addition doesn't overflow. If the proof is impossible, use an alternate sequence which never overflows. Differential Revision: https://reviews.llvm.org/D105216 --- llvm/include/llvm/Analysis/ScalarEvolution.h | 7 -- llvm/lib/Analysis/ScalarEvolution.cpp | 129 +++++++++++++++++---- .../Analysis/ScalarEvolution/2008-11-18-Stride2.ll | 2 +- .../ScalarEvolution/trip-count-unknown-stride.ll | 6 +- 4 files changed, 112 insertions(+), 32 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 788e9ca..ae9c73f 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -2032,13 +2032,6 @@ private: Optional>> createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI); - /// Compute the backedge taken count knowing the interval difference, and - /// the stride for an inequality. Result takes the form: - /// (Delta + (Stride - 1)) udiv Stride. - /// Caller must ensure that this expression either does not overflow or - /// that the result is undefined if it does. - const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride); - /// Compute the maximum backedge count based on the range of values /// permitted by Start, End, and Stride. This is for loops of the form /// {Start, +, Stride} LT End. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 88858ba..32023e5 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -11528,13 +11528,6 @@ const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) { return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D)); } -const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, - const SCEV *Step) { - const SCEV *One = getOne(Step->getType()); - Delta = getAddExpr(Delta, getMinusSCEV(Step, One)); - return getUDivExpr(Delta, Step); -} - const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride, const SCEV *End, @@ -11743,7 +11736,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, return RHS; } - const SCEV *End = RHS; // When the RHS is not invariant, we do not know the end bound of the loop and // cannot calculate the ExactBECount needed by ExitLimit. However, we can // calculate the MaxBECount, given the start, stride and max value for the end @@ -11755,13 +11747,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount, false /*MaxOrZero*/, Predicates); } - // If the backedge is taken at least once, then it will be taken - // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start - // is the LHS value of the less-than comparison the first time it is evaluated - // and End is the RHS. - const SCEV *BECountIfBackedgeTaken = - computeBECount(getMinusSCEV(End, Start), Stride); - + // We use the expression (max(End,Start)-Start)/Stride to describe the // backedge count, as if the backedge is taken at least once max(End,Start) // is End and so the result is as above, and if not max(End,Start) is Start @@ -11796,6 +11782,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, BECount = getUDivExpr(Numerator, Stride); } } + + const SCEV *BECountIfBackedgeTaken = nullptr; if (!BECount) { auto canProveRHSGreaterThanEqualStart = [&]() { auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; @@ -11819,18 +11807,112 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // If we know that RHS >= Start in the context of loop, then we know that // max(RHS, Start) = RHS at this point. - if (canProveRHSGreaterThanEqualStart()) + const SCEV *End; + if (canProveRHSGreaterThanEqualStart()) { End = RHS; - else + } else { + // If RHS < Start, the backedge will be taken zero times. So in + // general, we can write the backedge-taken count as: + // + // RHS >= Start ? ceil(RHS - Start) / Stride : 0 + // + // We convert it to the following to make it more convenient for SCEV: + // + // ceil(max(RHS, Start) - Start) / Stride End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); - BECount = computeBECount(getMinusSCEV(End, Start), Stride); + + // See what would happen if we assume the backedge is taken. This is + // used to compute MaxBECount. + BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride); + } + + // At this point, we know: + // + // 1. If IsSigned, Start <=s End; otherwise, Start <=u End + // 2. The index variable doesn't overflow. + // + // Therefore, we know N exists such that + // (Start + Stride * N) >= End, and computing "(Start + Stride * N)" + // doesn't overflow. + // + // Using this information, try to prove whether the addition in + // "(Start - End) + (Stride - 1)" has unsigned overflow. + const SCEV *One = getOne(Stride->getType()); + bool MayAddOverflow = [&] { + if (auto *StrideC = dyn_cast(Stride)) { + if (StrideC->getAPInt().isPowerOf2()) { + // Suppose Stride is a power of two, and Start/End are unsigned + // integers. Let UMAX be the largest representable unsigned + // integer. + // + // By the preconditions of this function, we know + // "(Start + Stride * N) >= End", and this doesn't overflow. + // As a formula: + // + // End <= (Start + Stride * N) <= UMAX + // + // Subtracting Start from all the terms: + // + // End - Start <= Stride * N <= UMAX - Start + // + // Since Start is unsigned, UMAX - Start <= UMAX. Therefore: + // + // End - Start <= Stride * N <= UMAX + // + // Stride * N is a multiple of Stride. Therefore, + // + // End - Start <= Stride * N <= UMAX - (UMAX mod Stride) + // + // Since Stride is a power of two, UMAX + 1 is divisible by Stride. + // Therefore, UMAX mod Stride == Stride - 1. So we can write: + // + // End - Start <= Stride * N <= UMAX - Stride - 1 + // + // Dropping the middle term: + // + // End - Start <= UMAX - Stride - 1 + // + // Adding Stride - 1 to both sides: + // + // (End - Start) + (Stride - 1) <= UMAX + // + // In other words, the addition doesn't have unsigned overflow. + // + // A similar proof works if we treat Start/End as signed values. + // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to + // use signed max instead of unsigned max. Note that we're trying + // to prove a lack of unsigned overflow in either case. + return false; + } + } + if (Start == Stride || Start == getMinusSCEV(Stride, One)) { + // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1. + // If !IsSigned, 0 (BECount)) + if (isa(BECount)) { MaxBECount = BECount; - else if (isa(BECountIfBackedgeTaken)) { + } else if (BECountIfBackedgeTaken && + isa(BECountIfBackedgeTaken)) { // If we know exactly how many times the backedge will be taken if it's // taken at least once, then the backedge count will either be that or // zero. @@ -11909,7 +11991,12 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, return End; } - const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride); + // Compute ((Start - End) + (Stride - 1)) / Stride. + // FIXME: This can overflow. Holding off on fixing this for now; + // howManyGreaterThans will hopefully be gone soon. + const SCEV *One = getOne(Stride->getType()); + const SCEV *BECount = getUDivExpr( + getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride); APInt MaxStart = IsSigned ? getSignedRangeMax(Start) : getUnsignedRangeMax(Start); diff --git a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll index 390f974..79f3fa5 100644 --- a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll +++ b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll @@ -1,7 +1,7 @@ ; RUN: opt < %s -analyze -enable-new-pm=0 -scalar-evolution 2>&1 | FileCheck %s ; RUN: opt < %s -disable-output "-passes=print" 2>&1 2>&1 | FileCheck %s -; CHECK: Loop %bb: backedge-taken count is ((-1 + (-1 * %x) + (1000 umax (3 + %x))) /u 3) +; CHECK: Loop %bb: backedge-taken count is (((-3 + (-1 * (1 umin (-3 + (-1 * %x) + (1000 umax (3 + %x))))) + (-1 * %x) + (1000 umax (3 + %x))) /u 3) + (1 umin (-3 + (-1 * %x) + (1000 umax (3 + %x))))) ; CHECK: Loop %bb: max backedge-taken count is 334 diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll index a085ee9..ec03623 100644 --- a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll @@ -35,7 +35,7 @@ for.end: ; preds = %for.body, %entry ; Check that we are able to compute trip count of a loop without an entry guard. ; CHECK: Determining loop execution counts for: @foo2 -; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s)) +; CHECK: backedge-taken count is ((((-1 * (1 umin ((-1 * %s) + (%n smax %s)))) + (-1 * %s) + (%n smax %s)) /u (1 umax %s)) + (1 umin ((-1 * %s) + (%n smax %s)))) ; We should have a conservative estimate for the max backedge taken count for ; loops with unknown stride. @@ -85,7 +85,7 @@ for.end: ; preds = %for.body, %entry ; Same as foo2, but with mustprogress on loop, not function ; CHECK: Determining loop execution counts for: @foo4 -; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s)) +; CHECK: backedge-taken count is ((((-1 * (1 umin ((-1 * %s) + (%n smax %s)))) + (-1 * %s) + (%n smax %s)) /u (1 umax %s)) + (1 umin ((-1 * %s) + (%n smax %s)))) ; CHECK: max backedge-taken count is -1 define void @foo4(i32* nocapture %A, i32 %n, i32 %s) { @@ -108,7 +108,7 @@ for.end: ; preds = %for.body, %entry ; A more complex case with pre-increment compare instead of post-increment. ; CHECK-LABEL: Determining loop execution counts for: @foo5 -; CHECK: Loop %for.body: backedge-taken count is ((-1 + (-1 * %start) + (1 umax %s) + (%n smax %start)) /u (1 umax %s)) +; CHECK: Loop %for.body: backedge-taken count is ((((-1 * (1 umin ((-1 * %start) + (%n smax %start)))) + (-1 * %start) + (%n smax %start)) /u (1 umax %s)) + (1 umin ((-1 * %start) + (%n smax %start)))) ; We should have a conservative estimate for the max backedge taken count for ; loops with unknown stride. -- 2.7.4