From 03783f19dc78fc45fd987f892c314578b5e52d78 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 17 Sep 2020 08:39:23 -0400 Subject: [PATCH] [SLP] sort candidates to increase chance of optimal compare reduction This is one (small) part of improving PR41312: https://llvm.org/PR41312 As shown there and in the smaller tests here, if we have some member of the reduction values that does not match the others, we want to push it to the end (bring the matching members forward and together). In the regression tests, we have 5 candidates for the 4 slots of the reduction. If the one "wrong" compare is grouped with the others, it prevents forming the ideal v4i1 compare reduction. Differential Revision: https://reviews.llvm.org/D87772 --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 30 ++++++++- .../Transforms/SLPVectorizer/X86/compare-reduce.ll | 71 +++++++--------------- 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 3d19e86..c487301 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6838,9 +6838,37 @@ public: for (ReductionOpsType &RdxOp : ReductionOps) IgnoreList.append(RdxOp.begin(), RdxOp.end()); + unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); + if (NumReducedVals > ReduxWidth) { + // In the loop below, we are building a tree based on a window of + // 'ReduxWidth' values. + // If the operands of those values have common traits (compare predicate, + // constant operand, etc), then we want to group those together to + // minimize the cost of the reduction. + + // TODO: This should be extended to count common operands for + // compares and binops. + + // Step 1: Count the number of times each compare predicate occurs. + SmallDenseMap PredCountMap; + for (Value *RdxVal : ReducedVals) { + CmpInst::Predicate Pred; + if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value()))) + ++PredCountMap[Pred]; + } + // Step 2: Sort the values so the most common predicates come first. + stable_sort(ReducedVals, [&PredCountMap](Value *A, Value *B) { + CmpInst::Predicate PredA, PredB; + if (match(A, m_Cmp(PredA, m_Value(), m_Value())) && + match(B, m_Cmp(PredB, m_Value(), m_Value()))) { + return PredCountMap[PredA] > PredCountMap[PredB]; + } + return false; + }); + } + Value *VectorizedTree = nullptr; unsigned i = 0; - unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { ArrayRef VL = makeArrayRef(&ReducedVals[i], ReduxWidth); V.buildTree(VL, ExternallyUsedValues, IgnoreList); diff --git a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll index daa96bf..b0971dd 100644 --- a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll @@ -81,20 +81,12 @@ declare i32 @printf(i8* nocapture, ...) define float @merge_anyof_v4f32_wrong_first(<4 x float> %x) { ; CHECK-LABEL: @merge_anyof_v4f32_wrong_first( -; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0 -; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x float> [[X]], i32 1 -; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x float> [[X]], i32 2 -; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x float> [[X]], i32 3 -; CHECK-NEXT: [[CMP3WRONG:%.*]] = fcmp olt float [[X3]], 4.200000e+01 -; CHECK-NEXT: [[CMP0:%.*]] = fcmp ogt float [[X0]], 1.000000e+00 -; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[X1]], 1.000000e+00 -; CHECK-NEXT: [[CMP2:%.*]] = fcmp ogt float [[X2]], 1.000000e+00 -; CHECK-NEXT: [[CMP3:%.*]] = fcmp ogt float [[X3]], 1.000000e+00 -; CHECK-NEXT: [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3WRONG]] -; CHECK-NEXT: [[OR031:%.*]] = or i1 [[OR03]], [[CMP1]] -; CHECK-NEXT: [[OR0312:%.*]] = or i1 [[OR031]], [[CMP2]] -; CHECK-NEXT: [[OR03123:%.*]] = or i1 [[OR0312]], [[CMP3]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[OR03123]], float -1.000000e+00, float 1.000000e+00 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3 +; CHECK-NEXT: [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01 +; CHECK-NEXT: [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]], +; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]] +; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00 ; CHECK-NEXT: ret float [[R]] ; %x0 = extractelement <4 x float> %x, i32 0 @@ -143,20 +135,12 @@ define float @merge_anyof_v4f32_wrong_last(<4 x float> %x) { define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) { ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle( -; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0 -; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1 -; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2 -; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3 -; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[X3]], 42 -; CHECK-NEXT: [[CMP0:%.*]] = icmp sgt i32 [[X0]], 1 -; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X1]], 1 -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[X2]], 1 -; CHECK-NEXT: [[CMP3:%.*]] = icmp sgt i32 [[X3]], 1 -; CHECK-NEXT: [[OR03:%.*]] = or i1 [[CMP0]], [[CMP3]] -; CHECK-NEXT: [[OR033:%.*]] = or i1 [[OR03]], [[CMP3WRONG]] -; CHECK-NEXT: [[OR0332:%.*]] = or i1 [[OR033]], [[CMP2]] -; CHECK-NEXT: [[OR03321:%.*]] = or i1 [[OR0332]], [[CMP1]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[OR03321]], i32 -1, i32 1 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3 +; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42 +; CHECK-NEXT: [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], +; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]] +; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1 ; CHECK-NEXT: ret i32 [[R]] ; %x0 = extractelement <4 x i32> %x, i32 0 @@ -176,29 +160,18 @@ define i32 @merge_anyof_v4i32_wrong_middle(<4 x i32> %x) { ret i32 %r } +; Operand/predicate swapping allows forming a reduction, but the +; ideal reduction groups all of the original 'sgt' ops together. + define i32 @merge_anyof_v4i32_wrong_middle_better_rdx(<4 x i32> %x, <4 x i32> %y) { ; CHECK-LABEL: @merge_anyof_v4i32_wrong_middle_better_rdx( -; CHECK-NEXT: [[X0:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 0 -; CHECK-NEXT: [[X1:%.*]] = extractelement <4 x i32> [[X]], i32 1 -; CHECK-NEXT: [[X2:%.*]] = extractelement <4 x i32> [[X]], i32 2 -; CHECK-NEXT: [[X3:%.*]] = extractelement <4 x i32> [[X]], i32 3 -; CHECK-NEXT: [[Y0:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 0 -; CHECK-NEXT: [[Y1:%.*]] = extractelement <4 x i32> [[Y]], i32 1 -; CHECK-NEXT: [[Y2:%.*]] = extractelement <4 x i32> [[Y]], i32 2 -; CHECK-NEXT: [[Y3:%.*]] = extractelement <4 x i32> [[Y]], i32 3 -; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i32 [[X1]], [[Y1]] -; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32> undef, i32 [[X0]], i32 0 -; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[X3]], i32 1 -; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[Y3]], i32 2 -; CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[X2]], i32 3 -; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> undef, i32 [[Y0]], i32 0 -; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[Y3]], i32 1 -; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i32> [[TMP6]], i32 [[X3]], i32 2 -; CHECK-NEXT: [[TMP8:%.*]] = insertelement <4 x i32> [[TMP7]], i32 [[Y2]], i32 3 -; CHECK-NEXT: [[TMP9:%.*]] = icmp sgt <4 x i32> [[TMP4]], [[TMP8]] -; CHECK-NEXT: [[TMP10:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP9]]) -; CHECK-NEXT: [[TMP11:%.*]] = or i1 [[TMP10]], [[CMP1]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP11]], i32 -1, i32 1 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 3 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3 +; CHECK-NEXT: [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP2]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], [[Y]] +; CHECK-NEXT: [[TMP4:%.*]] = call i1 @llvm.experimental.vector.reduce.or.v4i1(<4 x i1> [[TMP3]]) +; CHECK-NEXT: [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]] +; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1 ; CHECK-NEXT: ret i32 [[R]] ; %x0 = extractelement <4 x i32> %x, i32 0 -- 2.7.4