From 09d879d060ed31b22a6e72f7f5e44fe9b5660aa3 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Tue, 25 Apr 2023 11:57:46 -0700 Subject: [PATCH] [SCEV] Common code for computing trip count in a fixed type [NFC-ish] This is a follow on to D147117 and D147355. In both cases, we were adding special cases to compute zext(BTC+1) instead of zext(BTC)+1 when the BTC+1 computation was known not to overflow. Differential Revision: https://reviews.llvm.org/D148661 --- llvm/include/llvm/Analysis/ScalarEvolution.h | 15 ++++---- llvm/lib/Analysis/ScalarEvolution.cpp | 44 ++++++++++++++++------- llvm/lib/Transforms/Scalar/LoopFlatten.cpp | 13 +++---- llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 31 ++-------------- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 30 +--------------- 5 files changed, 50 insertions(+), 83 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index f27cf22..0f281d0 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -791,16 +791,19 @@ public: bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); + /// A version of getTripCountFromExitCount below which always picks an + /// evaluation type which can not result in overflow. + const SCEV *getTripCountFromExitCount(const SCEV *ExitCount); + /// Convert from an "exit count" (i.e. "backedge taken count") to a "trip /// count". A "trip count" is the number of times the header of the loop /// will execute if an exit is taken after the specified number of backedges /// have been taken. (e.g. TripCount = ExitCount + 1). Note that the - /// expression can overflow if ExitCount = UINT_MAX. \p Extend controls - /// how potential overflow is handled. If true, a wider result type is - /// returned. ex: EC = 255 (i8), TC = 256 (i9). If false, result unsigned - /// wraps with 2s-complement semantics. ex: EC = 255 (i8), TC = 0 (i8) - const SCEV *getTripCountFromExitCount(const SCEV *ExitCount, - bool Extend = true); + /// expression can overflow if ExitCount = UINT_MAX. If EvalTy is not wide + /// enough to hold the result without overflow, result unsigned wraps with + /// 2s-complement semantics. ex: EC = 255 (i8), TC = 0 (i8) + const SCEV *getTripCountFromExitCount(const SCEV *ExitCount, Type *EvalTy, + const Loop *L); /// Returns the exact trip count of the loop if we can compute it, and /// the result is a small constant. '0' is used to represent an unknown diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 6a3e91a..15bb954 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8037,27 +8037,45 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // -const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, - bool Extend) { +const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { if (isa(ExitCount)) return getCouldNotCompute(); auto *ExitCountType = ExitCount->getType(); assert(ExitCountType->isIntegerTy()); + auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(), + 1 + ExitCountType->getScalarSizeInBits()); + return getTripCountFromExitCount(ExitCount, EvalTy, nullptr); +} + +const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, + Type *EvalTy, + const Loop *L) { + if (isa(ExitCount)) + return getCouldNotCompute(); - if (!Extend) - return getAddExpr(ExitCount, getOne(ExitCountType)); + unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType()); + unsigned EvalSize = EvalTy->getPrimitiveSizeInBits(); - ConstantRange ExitCountRange = + auto CanAddOneWithoutOverflow = [&]() { + ConstantRange ExitCountRange = getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED); - if (!ExitCountRange.contains( - APInt::getMaxValue(ExitCountRange.getBitWidth()))) - return getAddExpr(ExitCount, getOne(ExitCountType)); - - auto *WiderType = Type::getIntNTy(ExitCountType->getContext(), - 1 + ExitCountType->getScalarSizeInBits()); - return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType), - getOne(WiderType)); + if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize))) + return true; + + return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount, + getMinusOne(ExitCount->getType())); + }; + + // If we need to zero extend the backedge count, check if we can add one to + // it prior to zero extending without overflow. Provided this is safe, it + // allows better simplification of the +1. + if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow()) + return getZeroExtendExpr( + getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy); + + // Get the total trip count from the count by adding 1. This may wrap. + return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy)); } static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 591f30c..edc8a49 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -315,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L, return false; } - // The Extend=false flag is used for getTripCountFromExitCount as we want - // to verify and match it with the pattern matched tripcount. Please note - // that overflow checks are performed in checkOverflow, but are first tried - // to avoid by widening the IV. + // Evaluating in the trip count's type can not overflow here as the overflow + // checks are performed in checkOverflow, but are first tried to avoid by + // widening the IV. const SCEV *SCEVTripCount = - SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false); + SE->getTripCountFromExitCount(BackedgeTakenCount, + BackedgeTakenCount->getType(), L); const SCEV *SCEVRHS = SE->getSCEV(RHS); if (SCEVRHS == SCEVTripCount) @@ -333,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L, // Find the extended backedge taken count and extended trip count using // SCEV. One of these should now match the RHS of the compare. BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); - SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false); + SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, + RHS->getType(), L); if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); return false; diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 2c999d7..bb0099e 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -983,33 +983,6 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, return SE->getMinusSCEV(Start, Index); } -/// Compute trip count from the backedge taken count. -static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr, - Loop *CurLoop, const DataLayout *DL, - ScalarEvolution *SE) { - const SCEV *TripCountS = nullptr; - // The # stored bytes is (BECount+1). Expand the trip count out to - // pointer size if it isn't already. - // - // If we're going to need to zero extend the BE count, check if we can add - // one to it prior to zero extending without overflow. Provided this is safe, - // it allows better simplification of the +1. - if (DL->getTypeSizeInBits(BECount->getType()) < - DL->getTypeSizeInBits(IntPtr) && - SE->isLoopEntryGuardedByCond( - CurLoop, ICmpInst::ICMP_NE, BECount, - SE->getMinusOne(BECount->getType()))) { - TripCountS = SE->getZeroExtendExpr( - SE->getAddExpr(BECount, SE->getOne(BECount->getType())), - IntPtr); - } else { - TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), - SE->getOne(IntPtr)); - } - - return TripCountS; -} - /// Compute the number of bytes as a SCEV from the backedge taken count. /// /// This also maps the SCEV into the provided type and tries to handle the @@ -1017,8 +990,8 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr, static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, const SCEV *StoreSizeSCEV, Loop *CurLoop, const DataLayout *DL, ScalarEvolution *SE) { - const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE); - + const SCEV *TripCountSCEV = + SE->getTripCountFromExitCount(BECount, IntPtr, CurLoop); return SE->getMulExpr(TripCountSCEV, SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr), SCEV::FlagNUW); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 435cb9a..645b62d 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -987,35 +987,7 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE, assert(!isa(BackedgeTakenCount) && "Invalid loop count"); ScalarEvolution &SE = *PSE.getSE(); - - unsigned BackEdgeSize = SE.getTypeSizeInBits(BackedgeTakenCount->getType()); - unsigned IdxSize = IdxTy->getPrimitiveSizeInBits(); - - // If we need to need to zero extend the backedge count, check if we can - // add one to it prior to zero extending without overflow. Provided this is - // safe, it allows better simplification of the +1. - if (OrigLoop && BackEdgeSize < IdxSize && - SE.isLoopEntryGuardedByCond( - OrigLoop, ICmpInst::ICMP_NE, BackedgeTakenCount, - SE.getMinusOne(BackedgeTakenCount->getType()))) { - return SE.getZeroExtendExpr( - SE.getAddExpr(BackedgeTakenCount, - SE.getOne(BackedgeTakenCount->getType())), - IdxTy); - } - - // The exit count might have the type of i64 while the phi is i32. This can - // happen if we have an induction variable that is sign extended before the - // compare. The only way that we get a backedge taken count is that the - // induction variable was signed and as such will not overflow. In such a case - // truncation is legal. - if (BackEdgeSize > IdxSize) - BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy); - BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); - - // Get the total trip count from the count by adding 1. - return SE.getAddExpr(BackedgeTakenCount, - SE.getOne(BackedgeTakenCount->getType())); + return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop); } static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy, -- 2.7.4