From 7b26dcae9eaf8cdcba7fef032fd83d060dffd4b4 Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Thu, 2 Mar 2023 11:59:50 +0000 Subject: [PATCH] Revert "[SCEV] Add SCEVType to represent `vscale`." This reverts commit 7912f5cc92f65ad0d3c705f3683a0b69dbedcc57. --- llvm/include/llvm/Analysis/ScalarEvolution.h | 1 - .../llvm/Analysis/ScalarEvolutionDivision.h | 2 - .../llvm/Analysis/ScalarEvolutionExpressions.h | 26 +-------- .../Transforms/Utils/ScalarEvolutionExpander.h | 2 - llvm/lib/Analysis/ScalarEvolution.cpp | 66 ++++++++-------------- llvm/lib/Analysis/ScalarEvolutionDivision.cpp | 4 -- llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 4 +- .../Transforms/Utils/ScalarEvolutionExpander.cpp | 7 --- 8 files changed, 27 insertions(+), 85 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 6cb8fec..0201942 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -566,7 +566,6 @@ public: const SCEV *getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth = 0); const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty); const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); - const SCEV *getVScale(Type *Ty); const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth = 0); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h index 3283d43..7d5902d 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h @@ -48,8 +48,6 @@ public: void visitConstant(const SCEVConstant *Numerator); - void visitVScale(const SCEVVScale *Numerator); - void visitAddRecExpr(const SCEVAddRecExpr *Numerator); void visitAddExpr(const SCEVAddExpr *Numerator); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 0a1c900..1b14d74 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -39,7 +39,6 @@ enum SCEVTypes : unsigned short { // These should be ordered in terms of increasing complexity to make the // folders simpler. scConstant, - scVScale, scTruncate, scZeroExtend, scSignExtend, @@ -76,23 +75,6 @@ public: static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } }; -/// This class represents the value of vscale, as used when defining the length -/// of a scalable vector or returned by the llvm.vscale() intrinsic. -class SCEVVScale : public SCEV { - friend class ScalarEvolution; - - SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty) - : SCEV(ID, scVScale, 0), Ty(ty) {} - - Type *Ty; - -public: - Type *getType() const { return Ty; } - - /// Methods for support type inquiry through isa, cast, and dyn_cast: - static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } -}; - inline unsigned short computeExpressionSize(ArrayRef Args) { APInt Size(16, 1); for (const auto *Arg : Args) @@ -597,6 +579,9 @@ class SCEVUnknown final : public SCEV, private CallbackVH { public: Value *getValue() const { return getValPtr(); } + /// Check whether this represents vscale. + bool isVScale() const; + Type *getType() const { return getValPtr()->getType(); } /// Methods for support type inquiry through isa, cast, and dyn_cast: @@ -610,8 +595,6 @@ template struct SCEVVisitor { switch (S->getSCEVType()) { case scConstant: return ((SC *)this)->visitConstant((const SCEVConstant *)S); - case scVScale: - return ((SC *)this)->visitVScale((const SCEVVScale *)S); case scPtrToInt: return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); case scTruncate: @@ -679,7 +662,6 @@ public: switch (S->getSCEVType()) { case scConstant: - case scVScale: case scUnknown: continue; case scPtrToInt: @@ -769,8 +751,6 @@ public: const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } - const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } - const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h index 5558970..131e24f 100644 --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -457,8 +457,6 @@ private: Value *visitConstant(const SCEVConstant *S) { return S->getValue(); } - Value *visitVScale(const SCEVVScale *S); - Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S); Value *visitTruncateExpr(const SCEVTruncateExpr *S); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index b074294..f997b19 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -271,9 +271,6 @@ void SCEV::print(raw_ostream &OS) const { case scConstant: cast(this)->getValue()->printAsOperand(OS, false); return; - case scVScale: - OS << "vscale"; - return; case scPtrToInt: { const SCEVPtrToIntExpr *PtrToInt = cast(this); const SCEV *Op = PtrToInt->getOperand(); @@ -369,9 +366,17 @@ void SCEV::print(raw_ostream &OS) const { OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; return; } - case scUnknown: - cast(this)->getValue()->printAsOperand(OS, false); + case scUnknown: { + const SCEVUnknown *U = cast(this); + if (U->isVScale()) { + OS << "vscale"; + return; + } + + // Otherwise just print it normally. + U->getValue()->printAsOperand(OS, false); return; + } case scCouldNotCompute: OS << "***COULDNOTCOMPUTE***"; return; @@ -383,8 +388,6 @@ Type *SCEV::getType() const { switch (getSCEVType()) { case scConstant: return cast(this)->getType(); - case scVScale: - return cast(this)->getType(); case scPtrToInt: case scTruncate: case scZeroExtend: @@ -416,7 +419,6 @@ Type *SCEV::getType() const { ArrayRef SCEV::operands() const { switch (getSCEVType()) { case scConstant: - case scVScale: case scUnknown: return {}; case scPtrToInt: @@ -499,18 +501,6 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { return getConstant(ConstantInt::get(ITy, V, isSigned)); } -const SCEV *ScalarEvolution::getVScale(Type *Ty) { - FoldingSetNodeID ID; - ID.AddInteger(scVScale); - ID.AddPointer(Ty); - void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) - return S; - SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty); - UniqueSCEVs.InsertNode(S, IP); - return S; -} - SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} @@ -570,6 +560,10 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) { setValPtr(New); } +bool SCEVUnknown::isVScale() const { + return match(getValue(), m_VScale()); +} + //===----------------------------------------------------------------------===// // SCEV Utilities //===----------------------------------------------------------------------===// @@ -720,12 +714,6 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, return LA.ult(RA) ? -1 : 1; } - case scVScale: { - const auto *LTy = cast(cast(LHS)->getType()); - const auto *RTy = cast(cast(RHS)->getType()); - return LTy->getBitWidth() - RTy->getBitWidth(); - } - case scAddRecExpr: { const SCEVAddRecExpr *LA = cast(LHS); const SCEVAddRecExpr *RA = cast(RHS); @@ -4027,8 +4015,6 @@ public: RetVal visitConstant(const SCEVConstant *Constant) { return Constant; } - RetVal visitVScale(const SCEVVScale *VScale) { return VScale; } - RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; } RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; } @@ -4075,7 +4061,6 @@ public: static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) { switch (Kind) { case scConstant: - case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4119,7 +4104,6 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) { switch (S->getSCEVType()) { case scConstant: - case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4331,8 +4315,15 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, const SCEV * ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); - if (Size.isScalable()) - Res = getMulExpr(Res, getVScale(IntTy)); + if (Size.isScalable()) { + // TODO: Why is there no ConstantExpr::getVScale()? + Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1); + Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo()); + Constant *One = ConstantInt::get(IntTy, 1); + Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One); + Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy); + Res = getMulExpr(Res, getUnknown(VScale)); + } return Res; } @@ -5896,7 +5887,6 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, bool follow(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: - case scVScale: case scPtrToInt: case scTruncate: case scZeroExtend: @@ -6284,8 +6274,6 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: return cast(S)->getAPInt().countr_zero(); - case scVScale: - return 0; case scTruncate: { const SCEVTruncateExpr *T = cast(S); return std::min(GetMinTrailingZeros(T->getOperand()), @@ -6516,7 +6504,6 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, break; [[fallthrough]]; case scConstant: - case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -6620,8 +6607,6 @@ const ConstantRange &ScalarEvolution::getRangeRef( switch (S->getSCEVType()) { case scConstant: llvm_unreachable("Already handled above."); - case scVScale: - return setRange(S, SignHint, std::move(ConservativeResult)); case scTruncate: { const SCEVTruncateExpr *Trunc = cast(S); ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1); @@ -9726,7 +9711,6 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { switch (V->getSCEVType()) { case scCouldNotCompute: case scAddRecExpr: - case scVScale: return nullptr; case scConstant: return cast(V)->getValue(); @@ -9810,7 +9794,6 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { switch (V->getSCEVType()) { case scConstant: - case scVScale: return V; case scAddRecExpr: { // If this is a loop recurrence for a loop that does not contain L, then we @@ -9909,7 +9892,6 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { case scSequentialUMinExpr: return getSequentialMinMaxExpr(V->getSCEVType(), NewOps); case scConstant: - case scVScale: case scAddRecExpr: case scUnknown: case scCouldNotCompute: @@ -13695,7 +13677,6 @@ ScalarEvolution::LoopDisposition ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { switch (S->getSCEVType()) { case scConstant: - case scVScale: return LoopInvariant; case scAddRecExpr: { const SCEVAddRecExpr *AR = cast(S); @@ -13794,7 +13775,6 @@ ScalarEvolution::BlockDisposition ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { switch (S->getSCEVType()) { case scConstant: - case scVScale: return ProperlyDominatesBlock; case scAddRecExpr: { // This uses a "dominates" query instead of "properly dominates" query diff --git a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp index e1dd834..0619569 100644 --- a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp @@ -126,10 +126,6 @@ void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { } } -void SCEVDivision::visitVScale(const SCEVVScale *Numerator) { - return cannotDivide(Numerator); -} - void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; if (!Numerator->isAffine()) diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index e5da065..67c404a 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -976,7 +976,6 @@ static bool isHighCostExpansion(const SCEV *S, switch (S->getSCEVType()) { case scUnknown: case scConstant: - case scVScale: return false; case scTruncate: return isHighCostExpansion(cast(S)->getOperand(), @@ -2813,10 +2812,9 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) { /// SCEVUnknown, we simply return the rightmost SCEV operand. static const SCEV *getExprBase(const SCEV *S) { switch (S->getSCEVType()) { - default: // including scUnknown. + default: // uncluding scUnknown. return S; case scConstant: - case scVScale: return nullptr; case scTruncate: return getExprBase(cast(S)->getOperand()); diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 902eee2..24f1966 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -680,7 +680,6 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: - case scVScale: return nullptr; // A constant has no relevant loops. case scTruncate: case scZeroExtend: @@ -1745,10 +1744,6 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true); } -Value *SCEVExpander::visitVScale(const SCEVVScale *S) { - return Builder.CreateVScale(ConstantInt::get(S->getType(), 1)); -} - Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *IP) { setInsertPoint(IP); @@ -2129,7 +2124,6 @@ template static InstructionCost costAndCollectOperands( llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); case scUnknown: case scConstant: - case scVScale: return 0; case scPtrToInt: Cost = CastCost(Instruction::PtrToInt); @@ -2266,7 +2260,6 @@ bool SCEVExpander::isHighCostExpansionHelper( case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); case scUnknown: - case scVScale: // Assume to be zero-cost. return false; case scConstant: { -- 2.7.4