From f8eeeffadad33585027a489aaac79ff64d1e3464 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Sat, 21 Jan 2023 23:23:59 +0300 Subject: [PATCH] [NFC][SCEV] Reflow `computeSCEVAtScope()` into an exhaustive switch Otherwise instead of a compile-time error that you forgot to modify it, you'd get a run-time error, which happened every time i've added new expr. This is completely NFC, there are no other changes here. --- llvm/lib/Analysis/ScalarEvolution.cpp | 201 ++++++++++++++++++---------------- 1 file changed, 106 insertions(+), 95 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 6c71d69..8dcf11f 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9798,12 +9798,112 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { - if (isa(V)) + switch (V->getSCEVType()) { + case scConstant: return V; + case scTruncate: + case scZeroExtend: + case scSignExtend: + case scPtrToInt: { + const SCEVCastExpr *Cast = cast(V); + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getCastExpr(Cast->getSCEVType(), Op, Cast->getType()); + } + case scUDivExpr: { + const SCEVUDivExpr *Div = cast(V); + const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); + const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); + if (LHS == Div->getLHS() && RHS == Div->getRHS()) + return Div; // must be loop invariant + return getUDivExpr(LHS, RHS); + } + case scAddRecExpr: { + // If this is a loop recurrence for a loop that does not contain L, then we + // are dealing with the final value computed by the loop. + const SCEVAddRecExpr *AddRec = cast(V); + // First, attempt to evaluate each operand. + // Avoid performing the look-up in the common case where the specified + // expression has no loop-variant portions. + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { + const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); + if (OpAtScope == AddRec->getOperand(i)) + continue; + + // Okay, at least one of these operands is loop variant but might be + // foldable. Build a new instance of the folded commutative expression. + SmallVector NewOps(AddRec->operands().take_front(i)); + NewOps.push_back(OpAtScope); + for (++i; i != e; ++i) + NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); + + const SCEV *FoldedRec = getAddRecExpr( + NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); + AddRec = dyn_cast(FoldedRec); + // The addrec may be folded to a nonrecurrence, for example, if the + // induction variable is multiplied by zero after constant folding. Go + // ahead and return the folded value. + if (!AddRec) + return FoldedRec; + break; + } + + // If the scope is outside the addrec's loop, evaluate it by using the + // loop exit value of the addrec. + if (!AddRec->getLoop()->contains(L)) { + // To evaluate this recurrence, we need to know how many times the AddRec + // loop iterates. Compute this now. + const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + if (BackedgeTakenCount == getCouldNotCompute()) + return AddRec; + + // Then, evaluate the AddRec. + return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); + } + + return AddRec; + } + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: + case scSequentialUMinExpr: { + const auto *Comm = cast(V); + // Avoid performing the look-up in the common case where the specified + // expression has no loop-variant portions. + for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { + const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + if (OpAtScope != Comm->getOperand(i)) { + // Okay, at least one of these operands is loop variant but might be + // foldable. Build a new instance of the folded commutative expression. + SmallVector NewOps(Comm->operands().take_front(i)); + NewOps.push_back(OpAtScope); - // If this instruction is evolved from a constant-evolving PHI, compute the - // exit value from the loop without using SCEVs. - if (const SCEVUnknown *SU = dyn_cast(V)) { + for (++i; i != e; ++i) { + OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + NewOps.push_back(OpAtScope); + } + if (isa(Comm)) + return getAddExpr(NewOps, Comm->getNoWrapFlags()); + if (isa(Comm)) + return getMulExpr(NewOps, Comm->getNoWrapFlags()); + if (isa(Comm)) + return getMinMaxExpr(Comm->getSCEVType(), NewOps); + if (isa(Comm)) + return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps); + llvm_unreachable("Unknown commutative / sequential min/max SCEV type!"); + } + } + // If we got here, all operands are loop invariant. + return Comm; + } + case scUnknown: { + // If this instruction is evolved from a constant-evolving PHI, compute the + // exit value from the loop without using SCEVs. + const SCEVUnknown *SU = cast(V); if (Instruction *I = dyn_cast(SU->getValue())) { if (PHINode *PN = dyn_cast(I)) { const Loop *CurrLoop = this->LI[I->getParent()]; @@ -9916,98 +10016,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // This is some other type of SCEVUnknown, just return it. return V; } - - if (isa(V) || isa(V)) { - const auto *Comm = cast(V); - // Avoid performing the look-up in the common case where the specified - // expression has no loop-variant portions. - for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { - const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); - if (OpAtScope != Comm->getOperand(i)) { - // Okay, at least one of these operands is loop variant but might be - // foldable. Build a new instance of the folded commutative expression. - SmallVector NewOps(Comm->operands().take_front(i)); - NewOps.push_back(OpAtScope); - - for (++i; i != e; ++i) { - OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); - NewOps.push_back(OpAtScope); - } - if (isa(Comm)) - return getAddExpr(NewOps, Comm->getNoWrapFlags()); - if (isa(Comm)) - return getMulExpr(NewOps, Comm->getNoWrapFlags()); - if (isa(Comm)) - return getMinMaxExpr(Comm->getSCEVType(), NewOps); - if (isa(Comm)) - return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps); - llvm_unreachable("Unknown commutative / sequential min/max SCEV type!"); - } - } - // If we got here, all operands are loop invariant. - return Comm; - } - - if (const SCEVUDivExpr *Div = dyn_cast(V)) { - const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); - const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); - if (LHS == Div->getLHS() && RHS == Div->getRHS()) - return Div; // must be loop invariant - return getUDivExpr(LHS, RHS); - } - - // If this is a loop recurrence for a loop that does not contain L, then we - // are dealing with the final value computed by the loop. - if (const SCEVAddRecExpr *AddRec = dyn_cast(V)) { - // First, attempt to evaluate each operand. - // Avoid performing the look-up in the common case where the specified - // expression has no loop-variant portions. - for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); - if (OpAtScope == AddRec->getOperand(i)) - continue; - - // Okay, at least one of these operands is loop variant but might be - // foldable. Build a new instance of the folded commutative expression. - SmallVector NewOps(AddRec->operands().take_front(i)); - NewOps.push_back(OpAtScope); - for (++i; i != e; ++i) - NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); - - const SCEV *FoldedRec = getAddRecExpr( - NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); - AddRec = dyn_cast(FoldedRec); - // The addrec may be folded to a nonrecurrence, for example, if the - // induction variable is multiplied by zero after constant folding. Go - // ahead and return the folded value. - if (!AddRec) - return FoldedRec; - break; - } - - // If the scope is outside the addrec's loop, evaluate it by using the - // loop exit value of the addrec. - if (!AddRec->getLoop()->contains(L)) { - // To evaluate this recurrence, we need to know how many times the AddRec - // loop iterates. Compute this now. - const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); - if (BackedgeTakenCount == getCouldNotCompute()) - return AddRec; - - // Then, evaluate the AddRec. - return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); - } - - return AddRec; - } - - if (const SCEVCastExpr *Cast = dyn_cast(V)) { - const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); - if (Op == Cast->getOperand()) - return Cast; // must be loop invariant - return getCastExpr(Cast->getSCEVType(), Op, Cast->getType()); + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); } - llvm_unreachable("Unknown SCEV type!"); } -- 2.7.4