[X86] When removing sign extends from gather/scatter indices, make sure we handle...
authorCraig Topper <craig.topper@intel.com>
Fri, 27 Jul 2018 00:00:30 +0000 (00:00 +0000)
committerCraig Topper <craig.topper@intel.com>
Fri, 27 Jul 2018 00:00:30 +0000 (00:00 +0000)
If this happens the operands aren't updated and the existing node is returned. Make sure we pass this existing node up to the DAG combiner so that a proper replacement happens. Otherwise we get stuck in an infinite loop with an unoptimized node.

llvm-svn: 338090

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/masked_gather_scatter.ll

index c53c8d1..7572cf4 100644 (file)
@@ -38100,12 +38100,14 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
           Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
         SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
         NewOps[4] = Index.getOperand(0);
-        DAG.UpdateNodeOperands(N, NewOps);
-        // The original sign extend has less users, add back to worklist in case
-        // it needs to be removed
-        DCI.AddToWorklist(Index.getNode());
-        DCI.AddToWorklist(N);
-        return SDValue(N, 0);
+        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+        if (Res == N) {
+          // The original sign extend has less users, add back to worklist in
+          // case it needs to be removed
+          DCI.AddToWorklist(Index.getNode());
+          DCI.AddToWorklist(N);
+        }
+        return SDValue(Res, 0);
       }
     }
 
@@ -38118,9 +38120,10 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
       SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
       NewOps[4] = Index;
-      DAG.UpdateNodeOperands(N, NewOps);
-      DCI.AddToWorklist(N);
-      return SDValue(N, 0);
+      SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+      if (Res == N)
+        DCI.AddToWorklist(N);
+      return SDValue(Res, 0);
     }
 
     // Try to remove zero extends from 32->64 if we know the sign bit of
@@ -38131,12 +38134,14 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       if (DAG.SignBitIsZero(Index.getOperand(0))) {
         SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
         NewOps[4] = Index.getOperand(0);
-        DAG.UpdateNodeOperands(N, NewOps);
-        // The original zero extend has less users, add back to worklist in case
-        // it needs to be removed
-        DCI.AddToWorklist(Index.getNode());
-        DCI.AddToWorklist(N);
-        return SDValue(N, 0);
+        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+        if (Res == N) {
+          // The original sign extend has less users, add back to worklist in
+          // case it needs to be removed
+          DCI.AddToWorklist(Index.getNode());
+          DCI.AddToWorklist(N);
+        }
+        return SDValue(Res, 0);
       }
     }
   }
index 6bd8f92..bba1b39 100644 (file)
@@ -2928,3 +2928,54 @@ define void @test_scatter_setcc_split(double* %base, <16 x i32> %ind, <16 x i32>
   call void @llvm.masked.scatter.v16f64.v16p0f64(<16 x double> %src0, <16 x double*> %gep.random, i32 4, <16 x i1> %mask)
   ret void
 }
+
+; This test case previously triggered an infinite loop when the two gathers became identical after DAG combine removed the sign extend.
+define <16 x float> @test_sext_cse(float* %base, <16 x i32> %ind, <16 x i32>* %foo) {
+; KNL_64-LABEL: test_sext_cse:
+; KNL_64:       # %bb.0:
+; KNL_64-NEXT:    vmovaps %zmm0, (%rsi)
+; KNL_64-NEXT:    kxnorw %k0, %k0, %k1
+; KNL_64-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; KNL_64-NEXT:    vaddps %zmm1, %zmm1, %zmm0
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: test_sext_cse:
+; KNL_32:       # %bb.0:
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; KNL_32-NEXT:    vmovaps %zmm0, (%ecx)
+; KNL_32-NEXT:    kxnorw %k0, %k0, %k1
+; KNL_32-NEXT:    vgatherdps (%eax,%zmm0,4), %zmm1 {%k1}
+; KNL_32-NEXT:    vaddps %zmm1, %zmm1, %zmm0
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: test_sext_cse:
+; SKX:       # %bb.0:
+; SKX-NEXT:    vmovaps %zmm0, (%rsi)
+; SKX-NEXT:    kxnorw %k0, %k0, %k1
+; SKX-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; SKX-NEXT:    vaddps %zmm1, %zmm1, %zmm0
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: test_sext_cse:
+; SKX_32:       # %bb.0:
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; SKX_32-NEXT:    vmovaps %zmm0, (%ecx)
+; SKX_32-NEXT:    kxnorw %k0, %k0, %k1
+; SKX_32-NEXT:    vgatherdps (%eax,%zmm0,4), %zmm1 {%k1}
+; SKX_32-NEXT:    vaddps %zmm1, %zmm1, %zmm0
+; SKX_32-NEXT:    retl
+  %broadcast.splatinsert = insertelement <16 x float*> undef, float* %base, i32 0
+  %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer
+
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr float, <16 x float*> %broadcast.splat, <16 x i64> %sext_ind
+
+  store <16 x i32> %ind, <16 x i32>* %foo
+  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+  %gep.random2 = getelementptr float, <16 x float*> %broadcast.splat, <16 x i32> %ind
+  %res2 = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random2, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+  %res3 = fadd <16 x float> %res2, %res
+  ret <16 x float>%res3
+}