[TTI] getMinMaxReductionCost - add FastMathFlag argument
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 13 Apr 2023 09:42:36 +0000 (10:42 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 13 Apr 2023 09:42:42 +0000 (10:42 +0100)
Similar to the getArithmeticReductionCost / getExtendedReductionCost calls (which really don't need to use std::optional<>).

This will be necessary to correct recognize fast/nnan fmax/fmul reductions which can avoid nan handling - which will allow us to remove the fmax/fmin special case in X86TTIImpl::getMinMaxCost and use getIntrinsicInstrCost like we do for integer reductions (63c3895327839ba5b57f5b99ec9e888abf976ac6).

Differential Revision: https://reviews.llvm.org/D148149

13 files changed:
llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/CodeGen/BasicTTIImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
llvm/lib/Target/X86/X86TargetTransformInfo.cpp
llvm/lib/Target/X86/X86TargetTransformInfo.h
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

index 9f074e5..7dcc8fb 100644 (file)
@@ -1392,6 +1392,7 @@ public:
 
   InstructionCost getMinMaxReductionCost(
       VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
+      FastMathFlags FMF = FastMathFlags(),
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
@@ -1934,7 +1935,7 @@ public:
                              TTI::TargetCostKind CostKind) = 0;
   virtual InstructionCost
   getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
-                         TTI::TargetCostKind CostKind) = 0;
+                         FastMathFlags FMF, TTI::TargetCostKind CostKind) = 0;
   virtual InstructionCost getExtendedReductionCost(
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
       std::optional<FastMathFlags> FMF,
@@ -2545,8 +2546,9 @@ public:
   }
   InstructionCost
   getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
+                         FastMathFlags FMF,
                          TTI::TargetCostKind CostKind) override {
-    return Impl.getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
+    return Impl.getMinMaxReductionCost(Ty, CondTy, IsUnsigned, FMF, CostKind);
   }
   InstructionCost getExtendedReductionCost(
       unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
index 721959f..8c15fb7 100644 (file)
@@ -721,6 +721,7 @@ public:
   }
 
   InstructionCost getMinMaxReductionCost(VectorType *, VectorType *, bool,
+                                         FastMathFlags,
                                          TTI::TargetCostKind) const {
     return 1;
   }
index 88d7061..4a3f38d 100644 (file)
@@ -1890,12 +1890,12 @@ public:
     case Intrinsic::vector_reduce_fmin:
       return thisT()->getMinMaxReductionCost(
           VecOpTy, cast<VectorType>(CmpInst::makeCmpResultType(VecOpTy)),
-          /*IsUnsigned=*/false, CostKind);
+          /*IsUnsigned=*/false, ICA.getFlags(), CostKind);
     case Intrinsic::vector_reduce_umax:
     case Intrinsic::vector_reduce_umin:
       return thisT()->getMinMaxReductionCost(
           VecOpTy, cast<VectorType>(CmpInst::makeCmpResultType(VecOpTy)),
-          /*IsUnsigned=*/true, CostKind);
+          /*IsUnsigned=*/true, ICA.getFlags(), CostKind);
     case Intrinsic::abs: {
       // abs(X) = select(icmp(X,0),X,sub(0,X))
       Type *CondTy = RetTy->getWithNewBitWidth(1);
@@ -2344,7 +2344,7 @@ public:
   /// Try to calculate op costs for min/max reduction operations.
   /// \param CondTy Conditional type for the Select instruction.
   InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                         bool IsUnsigned,
+                                         bool IsUnsigned, FastMathFlags FMF,
                                          TTI::TargetCostKind CostKind) {
     // Targets must implement a default value for the scalable case, since
     // we don't know how many lanes the vector has.
index 68674b9..8ba8e77 100644 (file)
@@ -1048,10 +1048,10 @@ InstructionCost TargetTransformInfo::getArithmeticReductionCost(
 }
 
 InstructionCost TargetTransformInfo::getMinMaxReductionCost(
-    VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
+    VectorType *Ty, VectorType *CondTy, bool IsUnsigned, FastMathFlags FMF,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Cost =
-      TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
+      TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsUnsigned, FMF, CostKind);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
index 97153b0..67eb400 100644 (file)
@@ -2957,12 +2957,12 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
 
 InstructionCost
 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                       bool IsUnsigned,
+                                       bool IsUnsigned, FastMathFlags FMF,
                                        TTI::TargetCostKind CostKind) {
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
 
   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
-    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
+    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, FMF, CostKind);
 
   assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
          "Both vector needs to be equally scalable");
@@ -2970,11 +2970,12 @@ AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
   InstructionCost LegalizationCost = 0;
   if (LT.first > 1) {
     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
-    unsigned MinMaxOpcode =
+    Intrinsic::ID MinMaxOpcode =
         Ty->isFPOrFPVectorTy()
             ? Intrinsic::maxnum
             : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
-    IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
+    IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy},
+                                  FMF);
     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
   }
 
index de8b2c6..d241b70 100644 (file)
@@ -179,7 +179,7 @@ public:
                                      unsigned Index);
 
   InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                         bool IsUnsigned,
+                                         bool IsUnsigned, FastMathFlags FMF,
                                          TTI::TargetCostKind CostKind);
 
   InstructionCost getArithmeticReductionCostSVE(unsigned Opcode,
index 6b1ffd2..d5f2a11 100644 (file)
@@ -774,14 +774,14 @@ GCNTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
 
 InstructionCost
 GCNTTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                   bool IsUnsigned,
+                                   bool IsUnsigned, FastMathFlags FMF,
                                    TTI::TargetCostKind CostKind) {
   EVT OrigTy = TLI->getValueType(DL, Ty);
 
   // Computes cost on targets that have packed math instructions(which support
   // 16-bit types only).
   if (!ST->hasVOP3PInsts() || OrigTy.getScalarSizeInBits() != 16)
-    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
+    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, FMF, CostKind);
 
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
   return LT.first * getHalfRateInstrCost(CostKind);
index f965b54..972ea8c 100644 (file)
@@ -241,9 +241,9 @@ public:
 
   InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                         TTI::TargetCostKind CostKind);
-  InstructionCost getMinMaxReductionCost(
-      VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
-      TTI::TargetCostKind CostKind);
+  InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
+                                         bool IsUnsigned, FastMathFlags FMF,
+                                         TTI::TargetCostKind CostKind);
 };
 
 } // end namespace llvm
index fcd7134..da2451e 100644 (file)
@@ -1167,14 +1167,14 @@ unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {
 
 InstructionCost
 RISCVTTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                     bool IsUnsigned,
+                                     bool IsUnsigned, FastMathFlags FMF,
                                      TTI::TargetCostKind CostKind) {
   if (isa<FixedVectorType>(Ty) && !ST->useRVVForFixedLengthVectors())
-    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
+    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, FMF, CostKind);
 
   // Skip if scalar size of Ty is bigger than ELEN.
   if (Ty->getScalarSizeInBits() > ST->getELEN())
-    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
+    return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, FMF, CostKind);
 
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
   if (Ty->getElementType()->isIntegerTy(1))
index 9df45ab..f1c6104 100644 (file)
@@ -143,7 +143,7 @@ public:
                                    const Instruction *I = nullptr);
 
   InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                         bool IsUnsigned,
+                                         bool IsUnsigned, FastMathFlags FMF,
                                          TTI::TargetCostKind CostKind);
 
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
index 6a6b955..1db80a4 100644 (file)
@@ -5192,10 +5192,10 @@ X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
 
 InstructionCost X86TTIImpl::getMinMaxCost(Type *Ty, Type *CondTy,
                                           TTI::TargetCostKind CostKind,
-                                          bool IsUnsigned) {
+                                          bool IsUnsigned, FastMathFlags FMF) {
   if (Ty->isIntOrIntVectorTy()) {
     Intrinsic::ID Id = IsUnsigned ? Intrinsic::umin : Intrinsic::smin;
-    IntrinsicCostAttributes ICA(Id, Ty, {Ty, Ty});
+    IntrinsicCostAttributes ICA(Id, Ty, {Ty, Ty}, FMF);
     return getIntrinsicInstrCost(ICA, CostKind);
   }
 
@@ -5253,7 +5253,7 @@ InstructionCost X86TTIImpl::getMinMaxCost(Type *Ty, Type *CondTy,
 
 InstructionCost
 X86TTIImpl::getMinMaxReductionCost(VectorType *ValTy, VectorType *CondTy,
-                                   bool IsUnsigned,
+                                   bool IsUnsigned, FastMathFlags FMF,
                                    TTI::TargetCostKind CostKind) {
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
 
@@ -5343,7 +5343,7 @@ X86TTIImpl::getMinMaxReductionCost(VectorType *ValTy, VectorType *CondTy,
                               MTy.getVectorNumElements());
     auto *SubCondTy = FixedVectorType::get(CondTy->getElementType(),
                                            MTy.getVectorNumElements());
-    MinMaxCost = getMinMaxCost(Ty, SubCondTy, CostKind, IsUnsigned);
+    MinMaxCost = getMinMaxCost(Ty, SubCondTy, CostKind, IsUnsigned, FMF);
     MinMaxCost *= LT.first - 1;
     NumVecElts = MTy.getVectorNumElements();
   }
@@ -5370,7 +5370,8 @@ X86TTIImpl::getMinMaxReductionCost(VectorType *ValTy, VectorType *CondTy,
   // by type legalization.
   if (!isPowerOf2_32(ValVTy->getNumElements()) ||
       ScalarSize != MTy.getScalarSizeInBits())
-    return BaseT::getMinMaxReductionCost(ValTy, CondTy, IsUnsigned, CostKind);
+    return BaseT::getMinMaxReductionCost(ValTy, CondTy, IsUnsigned, FMF,
+                                         CostKind);
 
   // Now handle reduction with the legal type, taking into account size changes
   // at each level.
@@ -5416,7 +5417,7 @@ X86TTIImpl::getMinMaxReductionCost(VectorType *ValTy, VectorType *CondTy,
     // Add the arithmetic op for this level.
     auto *SubCondTy =
         FixedVectorType::get(CondTy->getElementType(), Ty->getNumElements());
-    MinMaxCost += getMinMaxCost(Ty, SubCondTy, CostKind, IsUnsigned);
+    MinMaxCost += getMinMaxCost(Ty, SubCondTy, CostKind, IsUnsigned, FMF);
   }
 
   // Add the final extract element to the cost.
index 096983e..f182aa7 100644 (file)
@@ -205,11 +205,12 @@ public:
                                              std::optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
 
-  InstructionCost getMinMaxCost(Type *Ty, Type *CondTy, TTI::TargetCostKind CostKind,
-                                bool IsUnsigned);
+  InstructionCost getMinMaxCost(Type *Ty, Type *CondTy,
+                                TTI::TargetCostKind CostKind, bool IsUnsigned,
+                                FastMathFlags FMF);
 
   InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
-                                         bool IsUnsigned,
+                                         bool IsUnsigned, FastMathFlags FMF,
                                          TTI::TargetCostKind CostKind);
 
   InstructionCost getInterleavedMemoryOpCost(
index 3faa5f4..8d26cb2 100644 (file)
@@ -13628,7 +13628,7 @@ private:
             cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
         VectorCost =
             TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
-                                        /*IsUnsigned=*/false, CostKind);
+                                        /*IsUnsigned=*/false, FMF, CostKind);
       }
       CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
       ScalarCost = EvaluateScalarCost([&]() {
@@ -13650,7 +13650,7 @@ private:
         bool IsUnsigned =
             RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin;
         VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
-                                                 IsUnsigned, CostKind);
+                                                 IsUnsigned, FMF, CostKind);
       }
       CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
       ScalarCost = EvaluateScalarCost([&]() {