[SLP] fix insertion point for min/max reduction
authorSanjay Patel <spatel@rotateright.com>
Tue, 19 Nov 2019 15:47:07 +0000 (10:47 -0500)
committerSanjay Patel <spatel@rotateright.com>
Tue, 19 Nov 2019 15:50:10 +0000 (10:50 -0500)
As discussed in D70148 (and caused a revert of the original commit):
if we insert at the select, then we can produce invalid IR because
the replacement for the compare may have uses before the select.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/reduction.ll

index 373d7a7ffd13a6de84d805a192237e2f3aea0fc3..fc4f63e4a5b75568ef1f82be0d2f6ef2fd427719 100644 (file)
@@ -6741,8 +6741,23 @@ public:
       DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc();
       Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues);
 
-      // Emit a reduction.
-      Builder.SetInsertPoint(cast<Instruction>(ReductionRoot));
+      auto getCmpForMinMaxReduction = [](Instruction *RdxRootInst) {
+        assert(isa<SelectInst>(RdxRootInst) &&
+               "Expected min/max reduction to have select root instruction");
+        Value *ScalarCond = cast<SelectInst>(RdxRootInst)->getCondition();
+        assert(isa<Instruction>(ScalarCond) &&
+               "Expected min/max reduction to have compare condition");
+        return cast<Instruction>(ScalarCond);
+      };
+
+      // Emit a reduction. For min/max, the root is a select, but the insertion
+      // point is the compare condition of that select.
+      Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
+      if (ReductionData.isMinMax())
+        Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
+      else
+        Builder.SetInsertPoint(RdxRootInst);
+
       Value *ReducedSubTree =
           emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
       if (VectorizedTree) {
index f6e1d0ad2fea91643907c3b788db641fed12f51c..3a82ee5fa45c6362f5471a9fd5a161cfbe0ce220 100644 (file)
@@ -131,13 +131,13 @@ define i1 @bad_insertpoint_rdx([8 x i32]* %p) #0 {
 ; CHECK-NEXT:    [[ARRAYIDX22:%.*]] = getelementptr inbounds [8 x i32], [8 x i32]* [[P:%.*]], i64 0, i64 0
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i32* [[ARRAYIDX22]] to <2 x i32>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x i32>, <2 x i32>* [[TMP1]], align 16
-; CHECK-NEXT:    [[SPEC_STORE_SELECT87:%.*]] = zext i1 undef to i32
 ; CHECK-NEXT:    [[RDX_SHUF:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> undef, <2 x i32> <i32 1, i32 undef>
 ; CHECK-NEXT:    [[RDX_MINMAX_CMP:%.*]] = icmp sgt <2 x i32> [[TMP2]], [[RDX_SHUF]]
 ; CHECK-NEXT:    [[RDX_MINMAX_SELECT:%.*]] = select <2 x i1> [[RDX_MINMAX_CMP]], <2 x i32> [[TMP2]], <2 x i32> [[RDX_SHUF]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[RDX_MINMAX_SELECT]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = icmp sgt i32 [[TMP3]], 0
 ; CHECK-NEXT:    [[OP_EXTRA:%.*]] = select i1 [[TMP4]], i32 [[TMP3]], i32 0
+; CHECK-NEXT:    [[SPEC_STORE_SELECT87:%.*]] = zext i1 undef to i32
 ; CHECK-NEXT:    [[CMP23_2:%.*]] = icmp sgt i32 [[SPEC_STORE_SELECT87]], [[OP_EXTRA]]
 ; CHECK-NEXT:    ret i1 [[CMP23_2]]
 ;