From 46c59a55e747ed7c0c68e64b13621a5b5e243c83 Mon Sep 17 00:00:00 2001 From: Joshua Cao Date: Tue, 30 May 2023 20:40:10 -0700 Subject: [PATCH] [SCEV][NFC] Refactor range computation for AddRec to pass around APInt --- llvm/include/llvm/Analysis/ScalarEvolution.h | 4 +- llvm/lib/Analysis/ScalarEvolution.cpp | 73 +++++++++++++++------------- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 58f821e..2db7126 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1674,7 +1674,7 @@ private: /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}. /// Helper for \c getRange. ConstantRange getRangeForAffineAR(const SCEV *Start, const SCEV *Step, - const SCEV *MaxBECount, unsigned BitWidth); + const APInt &MaxBECount); /// Determines the range for the affine non-self-wrapping SCEVAddRecExpr {\p /// Start,+,\p Step}. @@ -1687,7 +1687,7 @@ private: /// Step} by "factoring out" a ternary expression from the add recurrence. /// Helper called by \c getRange. ConstantRange getRangeViaFactoring(const SCEV *Start, const SCEV *Step, - const SCEV *MaxBECount, unsigned BitWidth); + const APInt &MaxBECount); /// If the unknown expression U corresponds to a simple recurrence, return /// a constant range which represents the entire recurrence. Note that diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index db8ac4f..d59d0bf 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6699,21 +6699,23 @@ const ConstantRange &ScalarEvolution::getRangeRef( // TODO: non-affine addrec if (AddRec->isAffine()) { - const SCEV *MaxBECount = + const SCEV *MaxBEScev = getConstantMaxBackedgeTakenCount(AddRec->getLoop()); - if (!isa(MaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { - auto RangeFromAffine = getRangeForAffineAR( - AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, - BitWidth); - ConservativeResult = - ConservativeResult.intersectWith(RangeFromAffine, RangeType); + if (!isa(MaxBEScev)) { + APInt MaxBECount = cast(MaxBEScev)->getAPInt(); + if (MaxBECount.getBitWidth() < BitWidth) + MaxBECount = MaxBECount.zext(BitWidth); + if (MaxBECount.getBitWidth() == BitWidth) { + auto RangeFromAffine = getRangeForAffineAR( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount); + ConservativeResult = + ConservativeResult.intersectWith(RangeFromAffine, RangeType); - auto RangeFromFactoring = getRangeViaFactoring( - AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, - BitWidth); - ConservativeResult = - ConservativeResult.intersectWith(RangeFromFactoring, RangeType); + auto RangeFromFactoring = getRangeViaFactoring( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount); + ConservativeResult = + ConservativeResult.intersectWith(RangeFromFactoring, RangeType); + } } // Now try symbolic BE count and more powerful methods. @@ -6721,7 +6723,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( const SCEV *SymbolicMaxBECount = getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa(SymbolicMaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && + getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth && AddRec->hasNoSelfWrap()) { auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR( AddRec, SymbolicMaxBECount, BitWidth, SignHint); @@ -6885,7 +6887,10 @@ const ConstantRange &ScalarEvolution::getRangeRef( static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, - unsigned BitWidth, bool Signed) { + bool Signed) { + unsigned BitWidth = Step.getBitWidth(); + assert(BitWidth == StartRange.getBitWidth() && + BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths"); // If either Step or MaxBECount is 0, then the expression won't change, and we // just need to return the initial range. if (Step == 0 || MaxBECount == 0) @@ -6944,14 +6949,11 @@ static ConstantRange getRangeForAffineARHelper(APInt Step, ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, const SCEV *Step, - const SCEV *MaxBECount, - unsigned BitWidth) { - assert(!isa(MaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && - "Precondition!"); - - MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); - APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount); + const APInt &MaxBECount) { + assert(getTypeSizeInBits(Start->getType()) == + getTypeSizeInBits(Step->getType()) && + getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() && + "mismatched bit widths"); // First, consider step signed. ConstantRange StartSRange = getSignedRange(Start); @@ -6959,17 +6961,16 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, // If Step can be both positive and negative, we need to find ranges for the // maximum absolute step values in both directions and union them. - ConstantRange SR = - getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, - MaxBECountValue, BitWidth, /* Signed = */ true); + ConstantRange SR = getRangeForAffineARHelper( + StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true); SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), - StartSRange, MaxBECountValue, - BitWidth, /* Signed = */ true)); + StartSRange, MaxBECount, + /* Signed = */ true)); // Next, consider step unsigned. ConstantRange UR = getRangeForAffineARHelper( - getUnsignedRangeMax(Step), getUnsignedRange(Start), - MaxBECountValue, BitWidth, /* Signed = */ false); + getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount, + /* Signed = */ false); // Finally, intersect signed and unsigned ranges. return SR.intersectWith(UR, ConstantRange::Smallest); @@ -7045,11 +7046,15 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, const SCEV *Step, - const SCEV *MaxBECount, - unsigned BitWidth) { + const APInt &MaxBECount) { // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) + unsigned BitWidth = MaxBECount.getBitWidth(); + assert(getTypeSizeInBits(Start->getType()) == BitWidth && + getTypeSizeInBits(Step->getType()) == BitWidth && + "mismatched bit widths"); + struct SelectPattern { Value *Condition = nullptr; APInt TrueValue; @@ -7151,9 +7156,9 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); ConstantRange TrueRange = - this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); + this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount); ConstantRange FalseRange = - this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); + this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount); return TrueRange.unionWith(FalseRange); } -- 2.7.4