From 650fc40b6d8d9a5869b4fca525d5f237b0ee2803 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Sat, 15 Jan 2022 00:33:43 +0300 Subject: [PATCH] [NFC][SCEV] Introduce `getCastExpr()` QoL helper --- llvm/include/llvm/Analysis/ScalarEvolution.h | 1 + llvm/lib/Analysis/ScalarEvolution.cpp | 41 ++++++++++++---------------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 43a4fbb..fd23ba7 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -568,6 +568,7 @@ public: const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); + const SCEV *getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty); const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getAddExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap, diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index beba8c3..e2d5df8 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -2117,6 +2117,22 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { return S; } +const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op, + Type *Ty) { + switch (Kind) { + case scTruncate: + return getTruncateExpr(Op, Ty); + case scZeroExtend: + return getZeroExtendExpr(Op, Ty); + case scSignExtend: + return getSignExtendExpr(Op, Ty); + case scPtrToInt: + return getPtrToIntExpr(Op, Ty); + default: + llvm_unreachable("Not a SCEV cast expression!"); + } +} + /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, @@ -9390,32 +9406,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { return AddRec; } - if (const SCEVZeroExtendExpr *Cast = dyn_cast(V)) { + if (const SCEVCastExpr *Cast = dyn_cast(V)) { const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant - return getZeroExtendExpr(Op, Cast->getType()); - } - - if (const SCEVSignExtendExpr *Cast = dyn_cast(V)) { - const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); - if (Op == Cast->getOperand()) - return Cast; // must be loop invariant - return getSignExtendExpr(Op, Cast->getType()); - } - - if (const SCEVTruncateExpr *Cast = dyn_cast(V)) { - const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); - if (Op == Cast->getOperand()) - return Cast; // must be loop invariant - return getTruncateExpr(Op, Cast->getType()); - } - - if (const SCEVPtrToIntExpr *Cast = dyn_cast(V)) { - const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); - if (Op == Cast->getOperand()) - return Cast; // must be loop invariant - return getPtrToIntExpr(Op, Cast->getType()); + return getCastExpr(Cast->getSCEVType(), Op, Cast->getType()); } llvm_unreachable("Unknown SCEV type!"); -- 2.7.4