[SLP] sort candidates to increase chance of optimal compare reduction
authorSanjay Patel <spatel@rotateright.com>
Thu, 17 Sep 2020 12:39:23 +0000 (08:39 -0400)
committerSanjay Patel <spatel@rotateright.com>
Thu, 17 Sep 2020 12:49:27 +0000 (08:49 -0400)
This is one (small) part of improving PR41312:
https://llvm.org/PR41312

As shown there and in the smaller tests here, if we have some member of the
reduction values that does not match the others, we want to push it to the
end (bring the matching members forward and together).

In the regression tests, we have 5 candidates for the 4 slots of the reduction.
If the one "wrong" compare is grouped with the others, it prevents forming the
ideal v4i1 compare reduction.

Differential Revision: https://reviews.llvm.org/D87772

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll

index 3d19e86..c487301 100644 (file)
@@ -6838,9 +6838,37 @@ public:
     for (ReductionOpsType &RdxOp : ReductionOps)
       IgnoreList.append(RdxOp.begin(), RdxOp.end());
 
+    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
+    if (NumReducedVals > ReduxWidth) {
+      // In the loop below, we are building a tree based on a window of
+      // 'ReduxWidth' values.
+      // If the operands of those values have common traits (compare predicate,
+      // constant operand, etc), then we want to group those together to
+      // minimize the cost of the reduction.
+
+      // TODO: This should be extended to count common operands for
+      //       compares and binops.
+
+      // Step 1: Count the number of times each compare predicate occurs.
+      SmallDenseMap<unsigned, unsigned> PredCountMap;
+      for (Value *RdxVal : ReducedVals) {
+        CmpInst::Predicate Pred;
+        if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value())))
+          ++PredCountMap[Pred];
+      }
+      // Step 2: Sort the values so the most common predicates come first.
+      stable_sort(ReducedVals, [&PredCountMap](Value *A, Value *B) {
+        CmpInst::Predicate PredA, PredB;
+        if (match(A, m_Cmp(PredA, m_Value(), m_Value())) &&
+            match(B, m_Cmp(PredB, m_Value(), m_Value()))) {
+          return PredCountMap[PredA] > PredCountMap[PredB];
+        }
+        return false;
+      });
+    }
+
     Value *VectorizedTree = nullptr;
     unsigned i = 0;
-    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
     while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
       ArrayRef<Value *> VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
       V.buildTree(VL, ExternallyUsedValues, IgnoreList);
index daa96bf..b0971dd 100644 (file)
@@ -81,20 +81,12 @@ declare i32 @printf(i8* nocapture, ...)
 
 define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) {
 ; CHECK-LABEL: @merge_anyof_v4f32_wrong_first(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x float> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x float> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x float> [[X]], i32 3
-; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[X3]], 4.200000e+01
-; CHECK-NEXT:    [[CMP0:%.*]] = fcmp ogt float [[X0]], 1.000000e+00
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp ogt float [[X1]], 1.000000e+00
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp ogt float [[X2]], 1.000000e+00
-; CHECK-NEXT:    [[CMP3:%.*]] = fcmp ogt float [[X3]], 1.000000e+00
-; CHECK-NEXT:    [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[OR031:%.*]] = or i1 [[OR03]], [[CMP1]]
-; CHECK-NEXT:    [[OR0312:%.*]] = or i1 [[OR031]], [[CMP2]]
-; CHECK-NEXT:    [[OR03123:%.*]] = or i1 [[OR0312]], [[CMP3]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[OR03123]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01
+; CHECK-NEXT:    [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]], <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %x0 = extractelement <4 x float> %x, i32 0
@@ -143,20 +135,12 @@ define float @merge_anyof_v4f32_wrong_last(<4 x float> %x) {
 
 define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
 ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3
-; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[X3]], 42
-; CHECK-NEXT:    [[CMP0:%.*]] = icmp sgt i32 [[X0]], 1
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[X1]], 1
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[X2]], 1
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp sgt i32 [[X3]], 1
-; CHECK-NEXT:    [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3]]
-; CHECK-NEXT:    [[OR033:%.*]] = or i1 [[OR03]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[OR0332:%.*]] = or i1 [[OR033]], [[CMP2]]
-; CHECK-NEXT:    [[OR03321:%.*]] = or i1 [[OR0332]], [[CMP1]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[OR03321]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0
@@ -176,29 +160,18 @@ define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) {
   ret i32 %r
 }
 
+; Operand/predicate swapping allows forming a reduction, but the
+; ideal reduction groups all of the original 'sgt' ops together.
+
 define i32 @merge_anyof_v4i32_wrong_middle_better_rdx(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle_better_rdx(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1
-; CHECK-NEXT:    [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2
-; CHECK-NEXT:    [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3
-; CHECK-NEXT:    [[Y0:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractelement <4 x i32> [[Y]], i32 1
-; CHECK-NEXT:    [[Y2:%.*]] = extractelement <4 x i32> [[Y]], i32 2
-; CHECK-NEXT:    [[Y3:%.*]] = extractelement <4 x i32> [[Y]], i32 3
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[X1]], [[Y1]]
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32> undef, i32 [[X0]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X3]], i32 1
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[Y3]], i32 2
-; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[X2]], i32 3
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x i32> undef, i32 [[Y0]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[Y3]], i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x i32> [[TMP6]], i32 [[X3]], i32 2
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x i32> [[TMP7]], i32 [[Y2]], i32 3
-; CHECK-NEXT:    [[TMP9:%.*]] = icmp sgt <4 x i32> [[TMP4]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP9]])
-; CHECK-NEXT:    [[TMP11:%.*]] = or i1 [[TMP10]], [[CMP1]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP11]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 3
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
+; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], [[Y]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0