[TTI] Change TargetTransformInfo::getMinimumVF to return ElementCount
authorSander de Smalen <sander.desmalen@arm.com>
Thu, 11 Feb 2021 08:50:08 +0000 (08:50 +0000)
committerSander de Smalen <sander.desmalen@arm.com>
Thu, 11 Feb 2021 09:08:48 +0000 (09:08 +0000)
This will be needed in the loop-vectorizer where the minimum VF
requested may be a scalable VF. getMinimumVF now takes an additional
operand 'IsScalableVF' that indicates whether a scalable VF is required.

Reviewed By: kparzysz, rampitec

Differential Revision: https://reviews.llvm.org/D96020

llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

index 992e155..b2eb986 100644 (file)
@@ -945,7 +945,8 @@ public:
   /// \return The minimum vectorization factor for types of given element
   /// bit width, or 0 if there is no minimum VF. The returned value only
   /// applies when shouldMaximizeVectorBandwidth returns true.
-  unsigned getMinimumVF(unsigned ElemWidth) const;
+  /// If IsScalable is true, the returned ElementCount must be a scalable VF.
+  ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalable) const;
 
   /// \return The maximum vectorization factor for types of given element
   /// bit width and opcode, or 0 if there is no maximum VF.
@@ -1523,7 +1524,8 @@ public:
   virtual unsigned getMinVectorRegisterBitWidth() = 0;
   virtual Optional<unsigned> getMaxVScale() const = 0;
   virtual bool shouldMaximizeVectorBandwidth(bool OptSize) const = 0;
-  virtual unsigned getMinimumVF(unsigned ElemWidth) const = 0;
+  virtual ElementCount getMinimumVF(unsigned ElemWidth,
+                                    bool IsScalable) const = 0;
   virtual unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const = 0;
   virtual bool shouldConsiderAddressTypePromotion(
       const Instruction &I, bool &AllowPromotionWithoutCommonHeader) = 0;
@@ -1951,8 +1953,9 @@ public:
   bool shouldMaximizeVectorBandwidth(bool OptSize) const override {
     return Impl.shouldMaximizeVectorBandwidth(OptSize);
   }
-  unsigned getMinimumVF(unsigned ElemWidth) const override {
-    return Impl.getMinimumVF(ElemWidth);
+  ElementCount getMinimumVF(unsigned ElemWidth,
+                            bool IsScalable) const override {
+    return Impl.getMinimumVF(ElemWidth, IsScalable);
   }
   unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const override {
     return Impl.getMaximumVF(ElemWidth, Opcode);
index a4f9e1d..f0443fb 100644 (file)
@@ -374,7 +374,9 @@ public:
 
   bool shouldMaximizeVectorBandwidth(bool OptSize) const { return false; }
 
-  unsigned getMinimumVF(unsigned ElemWidth) const { return 0; }
+  ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalable) const {
+    return ElementCount::get(0, IsScalable);
+  }
 
   unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const { return 0; }
 
index e376d42..ee15730 100644 (file)
@@ -640,8 +640,9 @@ bool TargetTransformInfo::shouldMaximizeVectorBandwidth(bool OptSize) const {
   return TTIImpl->shouldMaximizeVectorBandwidth(OptSize);
 }
 
-unsigned TargetTransformInfo::getMinimumVF(unsigned ElemWidth) const {
-  return TTIImpl->getMinimumVF(ElemWidth);
+ElementCount TargetTransformInfo::getMinimumVF(unsigned ElemWidth,
+                                               bool IsScalable) const {
+  return TTIImpl->getMinimumVF(ElemWidth, IsScalable);
 }
 
 unsigned TargetTransformInfo::getMaximumVF(unsigned ElemWidth,
index 1cefa6a..af7bc46 100644 (file)
@@ -104,8 +104,10 @@ unsigned HexagonTTIImpl::getMinVectorRegisterBitWidth() const {
   return useHVX() ? ST.getVectorLength()*8 : 32;
 }
 
-unsigned HexagonTTIImpl::getMinimumVF(unsigned ElemWidth) const {
-  return (8 * ST.getVectorLength()) / ElemWidth;
+ElementCount HexagonTTIImpl::getMinimumVF(unsigned ElemWidth,
+                                          bool IsScalable) const {
+  assert(!IsScalable && "Scalable VFs are not supported for Hexagon");
+  return ElementCount::getFixed((8 * ST.getVectorLength()) / ElemWidth);
 }
 
 unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty,
index 835358d..dc075d6 100644 (file)
@@ -82,7 +82,7 @@ public:
   unsigned getMaxInterleaveFactor(unsigned VF);
   unsigned getRegisterBitWidth(bool Vector) const;
   unsigned getMinVectorRegisterBitWidth() const;
-  unsigned getMinimumVF(unsigned ElemWidth) const;
+  ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalable) const;
 
   bool shouldMaximizeVectorBandwidth(bool OptSize) const {
     return true;
index 71b626e..cb7bc82 100644 (file)
@@ -5812,7 +5812,8 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(unsigned ConstTripCount,
         break;
       }
     }
-    if (auto MinVF = ElementCount::getFixed(TTI.getMinimumVF(SmallestType))) {
+    if (ElementCount MinVF =
+            TTI.getMinimumVF(SmallestType, /*IsScalable=*/false)) {
       if (ElementCount::isKnownLT(MaxVF, MinVF)) {
         LLVM_DEBUG(dbgs() << "LV: Overriding calculated MaxVF(" << MaxVF
                           << ") with target's minimum: " << MinVF << '\n');