[TTI] add wrapper for matching vector reduction to reduce code duplication; NFC
authorSanjay Patel <spatel@rotateright.com>
Wed, 23 Sep 2020 17:46:26 +0000 (13:46 -0400)
committerSanjay Patel <spatel@rotateright.com>
Wed, 23 Sep 2020 17:48:57 +0000 (13:48 -0400)
I'm not sure what this means, but the order in which we try
the matches makes a difference on at least 1 regression test...

llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp

index 297cc8e..b7bcc34 100644 (file)
@@ -876,6 +876,10 @@ public:
   static ReductionKind matchVectorSplittingReduction(
     const ExtractElementInst *ReduxRoot, unsigned &Opcode, VectorType *&Ty);
 
+  static ReductionKind matchVectorReduction(const ExtractElementInst *ReduxRoot,
+                                            unsigned &Opcode, VectorType *&Ty,
+                                            bool &IsPairwise);
+
   /// Additional information about an operand's possible values.
   enum OperandValueKind {
     OK_AnyValue,               // Operand can have any value.
index ebd1beb..22708d0 100644 (file)
@@ -1004,41 +1004,23 @@ public:
       if (CI)
         Idx = CI->getZExtValue();
 
-      // Try to match a reduction sequence (series of shufflevector and
-      // vector  adds followed by a extractelement).
-      unsigned ReduxOpCode;
-      VectorType *ReduxType;
-
-      switch (TTI::matchVectorSplittingReduction(EEI, ReduxOpCode,
-                                                 ReduxType)) {
-      case TTI::RK_Arithmetic:
-        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
-                                          /*IsPairwiseForm=*/false,
-                                          CostKind);
-      case TTI::RK_MinMax:
-        return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/false, /*IsUnsigned=*/false, CostKind);
-      case TTI::RK_UnsignedMinMax:
-        return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/false, /*IsUnsigned=*/true, CostKind);
-      case TTI::RK_None:
-        break;
-      }
-
-      switch (TTI::matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) {
+      // Try to match a reduction (a series of shufflevector and vector ops
+      // followed by an extractelement).
+      unsigned RdxOpcode;
+      VectorType *RdxType;
+      bool IsPairwise;
+      switch (TTI::matchVectorReduction(EEI, RdxOpcode, RdxType, IsPairwise)) {
       case TTI::RK_Arithmetic:
-        return TargetTTI->getArithmeticReductionCost(ReduxOpCode, ReduxType,
-                                          /*IsPairwiseForm=*/true, CostKind);
+        return TargetTTI->getArithmeticReductionCost(RdxOpcode, RdxType,
+                                                     IsPairwise, CostKind);
       case TTI::RK_MinMax:
         return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/true, /*IsUnsigned=*/false, CostKind);
+            RdxType, cast<VectorType>(CmpInst::makeCmpResultType(RdxType)),
+            IsPairwise, /*IsUnsigned=*/false, CostKind);
       case TTI::RK_UnsignedMinMax:
         return TargetTTI->getMinMaxReductionCost(
-            ReduxType, cast<VectorType>(CmpInst::makeCmpResultType(ReduxType)),
-            /*IsPairwiseForm=*/true, /*IsUnsigned=*/true, CostKind);
+            RdxType, cast<VectorType>(CmpInst::makeCmpResultType(RdxType)),
+            IsPairwise, /*IsUnsigned=*/true, CostKind);
       case TTI::RK_None:
         break;
       }
index efecb45..4836d80 100644 (file)
@@ -1308,6 +1308,18 @@ TTI::ReductionKind TTI::matchVectorSplittingReduction(
   return RD->Kind;
 }
 
+TTI::ReductionKind
+TTI::matchVectorReduction(const ExtractElementInst *Root, unsigned &Opcode,
+                          VectorType *&Ty, bool &IsPairwise) {
+  TTI::ReductionKind RdxKind = matchVectorSplittingReduction(Root, Opcode, Ty);
+  if (RdxKind != TTI::ReductionKind::RK_None) {
+    IsPairwise = false;
+    return RdxKind;
+  }
+  IsPairwise = true;
+  return matchPairwiseReduction(Root, Opcode, Ty);
+}
+
 int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const {
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;