From 782231ac799fd9c58d317c7ba168495510995df8 Mon Sep 17 00:00:00 2001 From: Christopher Tetreault Date: Wed, 6 May 2020 09:53:57 -0700 Subject: [PATCH] [SVE] Fix invalid uses of VectorType::getNumElements() in ValueTracking Summary: Any function in this module that make use of DemandedElts laregely does not work with scalable vectors. DemandedElts is used to define which elements of the vector to look at. At best, for scalable vectors, we can express the first N elements of the vector. However, in practice, most code that uses these functions expect to be able to talk about the entire vector. In principle, this module should be able to be extended to work with scalable vectors. However, before we can do that, we should ensure that it does not cause code with scalable vectors to miscompile. All functions that use a DemandedElts will bail out if the vector is scalable. Usages of getNumElements() are updated to go through FixedVectorType pointers. Reviewers: rengolin, efriedma, sdesmalen, c-rhodes, spatel Reviewed By: efriedma Subscribers: david-arm, tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D79053 --- llvm/lib/Analysis/ValueTracking.cpp | 221 ++++++++++++++++++++++-------------- 1 file changed, 133 insertions(+), 88 deletions(-) diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index b02a4d2..b57a439 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -206,11 +206,16 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts, static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, const Query &Q) { - Type *Ty = V->getType(); + // FIXME: We currently have no way to represent the DemandedElts of a scalable + // vector + if (isa(V->getType())) { + Known.resetAll(); + return; + } + + auto *FVTy = dyn_cast(V->getType()); APInt DemandedElts = - Ty->isVectorTy() - ? APInt::getAllOnesValue(cast(Ty)->getNumElements()) - : APInt(1, 1); + FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1); computeKnownBits(V, DemandedElts, Known, Depth, Q); } @@ -374,11 +379,14 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts, static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { - Type *Ty = V->getType(); + // FIXME: We currently have no way to represent the DemandedElts of a scalable + // vector + if (isa(V->getType())) + return 1; + + auto *FVTy = dyn_cast(V->getType()); APInt DemandedElts = - Ty->isVectorTy() - ? APInt::getAllOnesValue(cast(Ty)->getNumElements()) - : APInt(1, 1); + FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1); return ComputeNumSignBits(V, DemandedElts, Depth, Q); } @@ -1808,7 +1816,12 @@ static void computeKnownBitsFromOperator(const Operator *I, const Value *Vec = I->getOperand(0); const Value *Idx = I->getOperand(1); auto *CIdx = dyn_cast(Idx); - unsigned NumElts = cast(Vec->getType())->getNumElements(); + if (isa(Vec->getType())) { + // FIXME: there's probably *something* we can do with scalable vectors + Known.resetAll(); + break; + } + unsigned NumElts = cast(Vec->getType())->getNumElements(); APInt DemandedVecElts = APInt::getAllOnesValue(NumElts); if (CIdx && CIdx->getValue().ult(NumElts)) DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue()); @@ -1880,30 +1893,41 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) { /// for all of the demanded elements in the vector specified by DemandedElts. void computeKnownBits(const Value *V, const APInt &DemandedElts, KnownBits &Known, unsigned Depth, const Query &Q) { + if (!DemandedElts || isa(V->getType())) { + // No demanded elts or V is a scalable vector, better to assume we don't + // know anything. + Known.resetAll(); + return; + } + assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); - unsigned BitWidth = Known.getBitWidth(); +#ifndef NDEBUG Type *Ty = V->getType(); + unsigned BitWidth = Known.getBitWidth(); + assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) && "Not integer or pointer type!"); - assert(((Ty->isVectorTy() && cast(Ty)->getNumElements() == - DemandedElts.getBitWidth()) || - (!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) && - "Unexpected vector size"); + + if (auto *FVTy = dyn_cast(Ty)) { + assert( + FVTy->getNumElements() == DemandedElts.getBitWidth() && + "DemandedElt width should equal the fixed vector number of elements"); + } else { + assert(DemandedElts == APInt(1, 1) && + "DemandedElt width should be 1 for scalars"); + } Type *ScalarTy = Ty->getScalarType(); - unsigned ExpectedWidth = ScalarTy->isPointerTy() ? - Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy); - assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth"); - (void)BitWidth; - (void)ExpectedWidth; - - if (!DemandedElts) { - // No demanded elts, better to assume we don't know anything. - Known.resetAll(); - return; + if (ScalarTy->isPointerTy()) { + assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) && + "V and Known should have same BitWidth"); + } else { + assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) && + "V and Known should have same BitWidth"); } +#endif const APInt *C; if (match(V, m_APInt(C))) { @@ -1919,17 +1943,14 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts, } // Handle a constant vector by taking the intersection of the known bits of // each element. - if (const ConstantDataSequential *CDS = dyn_cast(V)) { - assert((!Ty->isVectorTy() || - CDS->getNumElements() == DemandedElts.getBitWidth()) && - "Unexpected vector size"); - // We know that CDS must be a vector of integers. Take the intersection of + if (const ConstantDataVector *CDV = dyn_cast(V)) { + // We know that CDV must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits(); - for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { - if (Ty->isVectorTy() && !DemandedElts[i]) + for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) { + if (!DemandedElts[i]) continue; - APInt Elt = CDS->getElementAsAPInt(i); + APInt Elt = CDV->getElementAsAPInt(i); Known.Zero &= ~Elt; Known.One &= Elt; } @@ -1937,8 +1958,6 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts, } if (const auto *CV = dyn_cast(V)) { - assert(CV->getNumOperands() == DemandedElts.getBitWidth() && - "Unexpected vector size"); // We know that CV must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits(); @@ -1986,7 +2005,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts, computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q); // Aligned pointers have trailing zeros - refine Known.Zero set - if (Ty->isPointerTy()) { + if (isa(V->getType())) { const MaybeAlign Align = V->getPointerAlignment(Q.DL); if (Align) Known.Zero.setLowBits(countTrailingZeros(Align->value())); @@ -2274,6 +2293,11 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) /// Supports values with integer or pointer type and vectors of integers. bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, const Query &Q) { + // FIXME: We currently have no way to represent the DemandedElts of a scalable + // vector + if (isa(V->getType())) + return false; + if (auto *C = dyn_cast(V)) { if (C->isNullValue()) return false; @@ -2292,7 +2316,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, // For constant vectors, check that all elements are undefined or known // non-zero to determine that the whole vector is known non-zero. - if (auto *VecTy = dyn_cast(C->getType())) { + if (auto *VecTy = dyn_cast(C->getType())) { for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) { if (!DemandedElts[i]) continue; @@ -2527,7 +2551,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, const Value *Vec = EEI->getVectorOperand(); const Value *Idx = EEI->getIndexOperand(); auto *CIdx = dyn_cast(Idx); - unsigned NumElts = cast(Vec->getType())->getNumElements(); + unsigned NumElts = cast(Vec->getType())->getNumElements(); APInt DemandedVecElts = APInt::getAllOnesValue(NumElts); if (CIdx && CIdx->getValue().ult(NumElts)) DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue()); @@ -2540,11 +2564,14 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, } bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) { - Type *Ty = V->getType(); + // FIXME: We currently have no way to represent the DemandedElts of a scalable + // vector + if (isa(V->getType())) + return false; + + auto *FVTy = dyn_cast(V->getType()); APInt DemandedElts = - Ty->isVectorTy() - ? APInt::getAllOnesValue(cast(Ty)->getNumElements()) - : APInt(1, 1); + FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1); return isKnownNonZero(V, DemandedElts, Depth, Q); } @@ -2641,11 +2668,11 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V, const APInt &DemandedElts, unsigned TyBits) { const auto *CV = dyn_cast(V); - if (!CV || !CV->getType()->isVectorTy()) + if (!CV || !isa(CV->getType())) return 0; unsigned MinSignBits = TyBits; - unsigned NumElts = cast(CV->getType())->getNumElements(); + unsigned NumElts = cast(CV->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { if (!DemandedElts[i]) continue; @@ -2681,18 +2708,30 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts, static unsigned ComputeNumSignBitsImpl(const Value *V, const APInt &DemandedElts, unsigned Depth, const Query &Q) { + Type *Ty = V->getType(); + + // FIXME: We currently have no way to represent the DemandedElts of a scalable + // vector + if (isa(Ty)) + return 1; + +#ifndef NDEBUG assert(Depth <= MaxDepth && "Limit Search Depth"); + if (auto *FVTy = dyn_cast(Ty)) { + assert( + FVTy->getNumElements() == DemandedElts.getBitWidth() && + "DemandedElt width should equal the fixed vector number of elements"); + } else { + assert(DemandedElts == APInt(1, 1) && + "DemandedElt width should be 1 for scalars"); + } +#endif + // We return the minimum number of sign bits that are guaranteed to be present // in V, so for undef we have to conservatively return 1. We don't have the // same behavior for poison though -- that's a FIXME today. - Type *Ty = V->getType(); - assert(((Ty->isVectorTy() && cast(Ty)->getNumElements() == - DemandedElts.getBitWidth()) || - (!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) && - "Unexpected vector size"); - Type *ScalarTy = Ty->getScalarType(); unsigned TyBits = ScalarTy->isPointerTy() ? Q.DL.getPointerTypeSizeInBits(ScalarTy) : @@ -3266,8 +3305,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, // Handle vector of constants. if (auto *CV = dyn_cast(V)) { - if (auto *CVVTy = dyn_cast(CV->getType())) { - unsigned NumElts = CVVTy->getNumElements(); + if (auto *CVFVTy = dyn_cast(CV->getType())) { + unsigned NumElts = CVFVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { auto *CFP = dyn_cast_or_null(CV->getAggregateElement(i)); if (!CFP) @@ -3438,24 +3477,26 @@ bool llvm::isKnownNeverInfinity(const Value *V, const TargetLibraryInfo *TLI, } } - // Bail out for constant expressions, but try to handle vector constants. - if (!V->getType()->isVectorTy() || !isa(V)) - return false; - - // For vectors, verify that each element is not infinity. - unsigned NumElts = cast(V->getType())->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast(V)->getAggregateElement(i); - if (!Elt) - return false; - if (isa(Elt)) - continue; - auto *CElt = dyn_cast(Elt); - if (!CElt || CElt->isInfinity()) - return false; + // try to handle fixed width vector constants + if (isa(V->getType()) && isa(V)) { + // For vectors, verify that each element is not infinity. + unsigned NumElts = cast(V->getType())->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = cast(V)->getAggregateElement(i); + if (!Elt) + return false; + if (isa(Elt)) + continue; + auto *CElt = dyn_cast(Elt); + if (!CElt || CElt->isInfinity()) + return false; + } + // All elements were confirmed non-infinity or undefined. + return true; } - // All elements were confirmed non-infinity or undefined. - return true; + + // was not able to prove that V never contains infinity + return false; } bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, @@ -3539,24 +3580,26 @@ bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, } } - // Bail out for constant expressions, but try to handle vector constants. - if (!V->getType()->isVectorTy() || !isa(V)) - return false; - - // For vectors, verify that each element is not NaN. - unsigned NumElts = cast(V->getType())->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast(V)->getAggregateElement(i); - if (!Elt) - return false; - if (isa(Elt)) - continue; - auto *CElt = dyn_cast(Elt); - if (!CElt || CElt->isNaN()) - return false; + // Try to handle fixed width vector constants + if (isa(V->getType()) && isa(V)) { + // For vectors, verify that each element is not NaN. + unsigned NumElts = cast(V->getType())->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = cast(V)->getAggregateElement(i); + if (!Elt) + return false; + if (isa(Elt)) + continue; + auto *CElt = dyn_cast(Elt); + if (!CElt || CElt->isNaN()) + return false; + } + // All elements were confirmed not-NaN or undefined. + return true; } - // All elements were confirmed not-NaN or undefined. - return true; + + // Was not able to prove that V never contains NaN + return false; } Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { @@ -4634,11 +4677,13 @@ bool llvm::canCreatePoison(const Instruction *I) { // Shifts return poison if shiftwidth is larger than the bitwidth. if (auto *C = dyn_cast(I->getOperand(1))) { SmallVector ShiftAmounts; - if (C->getType()->isVectorTy()) { - unsigned NumElts = cast(C->getType())->getNumElements(); + if (auto *FVTy = dyn_cast(C->getType())) { + unsigned NumElts = FVTy->getNumElements(); for (unsigned i = 0; i < NumElts; ++i) ShiftAmounts.push_back(C->getAggregateElement(i)); - } else + } else if (isa(C->getType())) + return true; // Can't tell, just return true to be safe + else ShiftAmounts.push_back(C); bool Safe = llvm::all_of(ShiftAmounts, [](Constant *C) { -- 2.7.4