[SVE] Fix invalid uses of VectorType::getNumElements() in ValueTracking
authorChristopher Tetreault <ctetreau@quicinc.com>
Wed, 6 May 2020 16:53:57 +0000 (09:53 -0700)
committerChristopher Tetreault <ctetreau@quicinc.com>
Wed, 6 May 2020 17:06:06 +0000 (10:06 -0700)
Summary:
Any function in this module that make use of DemandedElts laregely does
not work with scalable vectors. DemandedElts is used to define which
elements of the vector to look at. At best, for scalable vectors, we can
express the first N elements of the vector. However, in practice, most
code that uses these functions expect to be able to talk about the
entire vector. In principle, this module should be able to be extended
to work with scalable vectors. However, before we can do that, we should
ensure that it does not cause code with scalable vectors to miscompile.
All functions that use a DemandedElts will bail out if the vector is
scalable. Usages of getNumElements() are updated to go through
FixedVectorType pointers.

Reviewers: rengolin, efriedma, sdesmalen, c-rhodes, spatel

Reviewed By: efriedma

Subscribers: david-arm, tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79053

llvm/lib/Analysis/ValueTracking.cpp

index b02a4d2..b57a439 100644 (file)
@@ -206,11 +206,16 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts,
 
 static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
                              const Query &Q) {
-  Type *Ty = V->getType();
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType())) {
+    Known.resetAll();
+    return;
+  }
+
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
   APInt DemandedElts =
-      Ty->isVectorTy()
-          ? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements())
-          : APInt(1, 1);
+      FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
   computeKnownBits(V, DemandedElts, Known, Depth, Q);
 }
 
@@ -374,11 +379,14 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
 
 static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
                                    const Query &Q) {
-  Type *Ty = V->getType();
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType()))
+    return 1;
+
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
   APInt DemandedElts =
-      Ty->isVectorTy()
-          ? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements())
-          : APInt(1, 1);
+      FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
   return ComputeNumSignBits(V, DemandedElts, Depth, Q);
 }
 
@@ -1808,7 +1816,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
     const Value *Vec = I->getOperand(0);
     const Value *Idx = I->getOperand(1);
     auto *CIdx = dyn_cast<ConstantInt>(Idx);
-    unsigned NumElts = cast<VectorType>(Vec->getType())->getNumElements();
+    if (isa<ScalableVectorType>(Vec->getType())) {
+      // FIXME: there's probably *something* we can do with scalable vectors
+      Known.resetAll();
+      break;
+    }
+    unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
     APInt DemandedVecElts = APInt::getAllOnesValue(NumElts);
     if (CIdx && CIdx->getValue().ult(NumElts))
       DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
@@ -1880,30 +1893,41 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
 /// for all of the demanded elements in the vector specified by DemandedElts.
 void computeKnownBits(const Value *V, const APInt &DemandedElts,
                       KnownBits &Known, unsigned Depth, const Query &Q) {
+  if (!DemandedElts || isa<ScalableVectorType>(V->getType())) {
+    // No demanded elts or V is a scalable vector, better to assume we don't
+    // know anything.
+    Known.resetAll();
+    return;
+  }
+
   assert(V && "No Value?");
   assert(Depth <= MaxDepth && "Limit Search Depth");
-  unsigned BitWidth = Known.getBitWidth();
 
+#ifndef NDEBUG
   Type *Ty = V->getType();
+  unsigned BitWidth = Known.getBitWidth();
+
   assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
          "Not integer or pointer type!");
-  assert(((Ty->isVectorTy() && cast<VectorType>(Ty)->getNumElements() ==
-                                   DemandedElts.getBitWidth()) ||
-          (!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) &&
-         "Unexpected vector size");
+
+  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
+    assert(
+        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
+        "DemandedElt width should equal the fixed vector number of elements");
+  } else {
+    assert(DemandedElts == APInt(1, 1) &&
+           "DemandedElt width should be 1 for scalars");
+  }
 
   Type *ScalarTy = Ty->getScalarType();
-  unsigned ExpectedWidth = ScalarTy->isPointerTy() ?
-    Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy);
-  assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth");
-  (void)BitWidth;
-  (void)ExpectedWidth;
-
-  if (!DemandedElts) {
-    // No demanded elts, better to assume we don't know anything.
-    Known.resetAll();
-    return;
+  if (ScalarTy->isPointerTy()) {
+    assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
+           "V and Known should have same BitWidth");
+  } else {
+    assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
+           "V and Known should have same BitWidth");
   }
+#endif
 
   const APInt *C;
   if (match(V, m_APInt(C))) {
@@ -1919,17 +1943,14 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
   }
   // Handle a constant vector by taking the intersection of the known bits of
   // each element.
-  if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) {
-    assert((!Ty->isVectorTy() ||
-            CDS->getNumElements() == DemandedElts.getBitWidth()) &&
-           "Unexpected vector size");
-    // We know that CDS must be a vector of integers. Take the intersection of
+  if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
+    // We know that CDV must be a vector of integers. Take the intersection of
     // each element.
     Known.Zero.setAllBits(); Known.One.setAllBits();
-    for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) {
-      if (Ty->isVectorTy() && !DemandedElts[i])
+    for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
+      if (!DemandedElts[i])
         continue;
-      APInt Elt = CDS->getElementAsAPInt(i);
+      APInt Elt = CDV->getElementAsAPInt(i);
       Known.Zero &= ~Elt;
       Known.One &= Elt;
     }
@@ -1937,8 +1958,6 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
   }
 
   if (const auto *CV = dyn_cast<ConstantVector>(V)) {
-    assert(CV->getNumOperands() == DemandedElts.getBitWidth() &&
-           "Unexpected vector size");
     // We know that CV must be a vector of integers. Take the intersection of
     // each element.
     Known.Zero.setAllBits(); Known.One.setAllBits();
@@ -1986,7 +2005,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
     computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q);
 
   // Aligned pointers have trailing zeros - refine Known.Zero set
-  if (Ty->isPointerTy()) {
+  if (isa<PointerType>(V->getType())) {
     const MaybeAlign Align = V->getPointerAlignment(Q.DL);
     if (Align)
       Known.Zero.setLowBits(countTrailingZeros(Align->value()));
@@ -2274,6 +2293,11 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value)
 /// Supports values with integer or pointer type and vectors of integers.
 bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
                     const Query &Q) {
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType()))
+    return false;
+
   if (auto *C = dyn_cast<Constant>(V)) {
     if (C->isNullValue())
       return false;
@@ -2292,7 +2316,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
 
     // For constant vectors, check that all elements are undefined or known
     // non-zero to determine that the whole vector is known non-zero.
-    if (auto *VecTy = dyn_cast<VectorType>(C->getType())) {
+    if (auto *VecTy = dyn_cast<FixedVectorType>(C->getType())) {
       for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
         if (!DemandedElts[i])
           continue;
@@ -2527,7 +2551,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
     const Value *Vec = EEI->getVectorOperand();
     const Value *Idx = EEI->getIndexOperand();
     auto *CIdx = dyn_cast<ConstantInt>(Idx);
-    unsigned NumElts = cast<VectorType>(Vec->getType())->getNumElements();
+    unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
     APInt DemandedVecElts = APInt::getAllOnesValue(NumElts);
     if (CIdx && CIdx->getValue().ult(NumElts))
       DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
@@ -2540,11 +2564,14 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
 }
 
 bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) {
-  Type *Ty = V->getType();
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType()))
+    return false;
+
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
   APInt DemandedElts =
-      Ty->isVectorTy()
-          ? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements())
-          : APInt(1, 1);
+      FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
   return isKnownNonZero(V, DemandedElts, Depth, Q);
 }
 
@@ -2641,11 +2668,11 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V,
                                                  const APInt &DemandedElts,
                                                  unsigned TyBits) {
   const auto *CV = dyn_cast<Constant>(V);
-  if (!CV || !CV->getType()->isVectorTy())
+  if (!CV || !isa<FixedVectorType>(CV->getType()))
     return 0;
 
   unsigned MinSignBits = TyBits;
-  unsigned NumElts = cast<VectorType>(CV->getType())->getNumElements();
+  unsigned NumElts = cast<FixedVectorType>(CV->getType())->getNumElements();
   for (unsigned i = 0; i != NumElts; ++i) {
     if (!DemandedElts[i])
       continue;
@@ -2681,18 +2708,30 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
 static unsigned ComputeNumSignBitsImpl(const Value *V,
                                        const APInt &DemandedElts,
                                        unsigned Depth, const Query &Q) {
+  Type *Ty = V->getType();
+
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(Ty))
+    return 1;
+
+#ifndef NDEBUG
   assert(Depth <= MaxDepth && "Limit Search Depth");
 
+  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
+    assert(
+        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
+        "DemandedElt width should equal the fixed vector number of elements");
+  } else {
+    assert(DemandedElts == APInt(1, 1) &&
+           "DemandedElt width should be 1 for scalars");
+  }
+#endif
+
   // We return the minimum number of sign bits that are guaranteed to be present
   // in V, so for undef we have to conservatively return 1.  We don't have the
   // same behavior for poison though -- that's a FIXME today.
 
-  Type *Ty = V->getType();
-  assert(((Ty->isVectorTy() && cast<VectorType>(Ty)->getNumElements() ==
-                                   DemandedElts.getBitWidth()) ||
-          (!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) &&
-         "Unexpected vector size");
-
   Type *ScalarTy = Ty->getScalarType();
   unsigned TyBits = ScalarTy->isPointerTy() ?
     Q.DL.getPointerTypeSizeInBits(ScalarTy) :
@@ -3266,8 +3305,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V,
 
   // Handle vector of constants.
   if (auto *CV = dyn_cast<Constant>(V)) {
-    if (auto *CVVTy = dyn_cast<VectorType>(CV->getType())) {
-      unsigned NumElts = CVVTy->getNumElements();
+    if (auto *CVFVTy = dyn_cast<FixedVectorType>(CV->getType())) {
+      unsigned NumElts = CVFVTy->getNumElements();
       for (unsigned i = 0; i != NumElts; ++i) {
         auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
         if (!CFP)
@@ -3438,24 +3477,26 @@ bool llvm::isKnownNeverInfinity(const Value *V, const TargetLibraryInfo *TLI,
     }
   }
 
-  // Bail out for constant expressions, but try to handle vector constants.
-  if (!V->getType()->isVectorTy() || !isa<Constant>(V))
-    return false;
-
-  // For vectors, verify that each element is not infinity.
-  unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
-  for (unsigned i = 0; i != NumElts; ++i) {
-    Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
-    if (!Elt)
-      return false;
-    if (isa<UndefValue>(Elt))
-      continue;
-    auto *CElt = dyn_cast<ConstantFP>(Elt);
-    if (!CElt || CElt->isInfinity())
-      return false;
+  // try to handle fixed width vector constants
+  if (isa<FixedVectorType>(V->getType()) && isa<Constant>(V)) {
+    // For vectors, verify that each element is not infinity.
+    unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
+      if (!Elt)
+        return false;
+      if (isa<UndefValue>(Elt))
+        continue;
+      auto *CElt = dyn_cast<ConstantFP>(Elt);
+      if (!CElt || CElt->isInfinity())
+        return false;
+    }
+    // All elements were confirmed non-infinity or undefined.
+    return true;
   }
-  // All elements were confirmed non-infinity or undefined.
-  return true;
+
+  // was not able to prove that V never contains infinity
+  return false;
 }
 
 bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI,
@@ -3539,24 +3580,26 @@ bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI,
     }
   }
 
-  // Bail out for constant expressions, but try to handle vector constants.
-  if (!V->getType()->isVectorTy() || !isa<Constant>(V))
-    return false;
-
-  // For vectors, verify that each element is not NaN.
-  unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
-  for (unsigned i = 0; i != NumElts; ++i) {
-    Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
-    if (!Elt)
-      return false;
-    if (isa<UndefValue>(Elt))
-      continue;
-    auto *CElt = dyn_cast<ConstantFP>(Elt);
-    if (!CElt || CElt->isNaN())
-      return false;
+  // Try to handle fixed width vector constants
+  if (isa<FixedVectorType>(V->getType()) && isa<Constant>(V)) {
+    // For vectors, verify that each element is not NaN.
+    unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *Elt = cast<Constant>(V)->getAggregateElement(i);
+      if (!Elt)
+        return false;
+      if (isa<UndefValue>(Elt))
+        continue;
+      auto *CElt = dyn_cast<ConstantFP>(Elt);
+      if (!CElt || CElt->isNaN())
+        return false;
+    }
+    // All elements were confirmed not-NaN or undefined.
+    return true;
   }
-  // All elements were confirmed not-NaN or undefined.
-  return true;
+
+  // Was not able to prove that V never contains NaN
+  return false;
 }
 
 Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
@@ -4634,11 +4677,13 @@ bool llvm::canCreatePoison(const Instruction *I) {
     // Shifts return poison if shiftwidth is larger than the bitwidth.
     if (auto *C = dyn_cast<Constant>(I->getOperand(1))) {
       SmallVector<Constant *, 4> ShiftAmounts;
-      if (C->getType()->isVectorTy()) {
-        unsigned NumElts = cast<VectorType>(C->getType())->getNumElements();
+      if (auto *FVTy = dyn_cast<FixedVectorType>(C->getType())) {
+        unsigned NumElts = FVTy->getNumElements();
         for (unsigned i = 0; i < NumElts; ++i)
           ShiftAmounts.push_back(C->getAggregateElement(i));
-      } else
+      } else if (isa<ScalableVectorType>(C->getType()))
+        return true; // Can't tell, just return true to be safe
+      else
         ShiftAmounts.push_back(C);
 
       bool Safe = llvm::all_of(ShiftAmounts, [](Constant *C) {