[X86] Split combineGatherScatter into a version for generic ISD nodes and another...
authorCraig Topper <craig.topper@intel.com>
Sat, 28 Sep 2019 01:06:58 +0000 (01:06 +0000)
committerCraig Topper <craig.topper@intel.com>
Sat, 28 Sep 2019 01:06:58 +0000 (01:06 +0000)
The majority of the code doesn't run on the X86 nodes today since
its gated by isBeforeLegalizeOps and we don't formm X86 nodes
until after that. Except for a couple special case in type
legalization. But I think we would probably break those if
some of the transforms fire on them.

I want to remove the hardcoded operand numbers and the unusual
use of UpdateNodeOperands. Being able to know which ISD opcodes
are present should help with that.

llvm-svn: 373136

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/avx512-vec-cmp.ll

index 77a7489..5741b80 100644 (file)
@@ -43364,9 +43364,22 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
+                                       TargetLowering::DAGCombinerInfo &DCI) {
+  // With vector masks we only demand the upper bit of the mask.
+  SDValue Mask = cast<X86MaskedGatherScatterSDNode>(N)->getMask();
+  if (Mask.getScalarValueSizeInBits() != 1) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
+    if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI))
+      return SDValue(N, 0);
+  }
+
+  return SDValue();
+}
+
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
-                                    TargetLowering::DAGCombinerInfo &DCI,
-                                    const X86Subtarget &Subtarget) {
+                                    TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
 
   if (DCI.isBeforeLegalizeOps()) {
@@ -43426,7 +43439,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   }
 
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = N->getOperand(2);
+  SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
@@ -44465,6 +44478,27 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
                             HADDBuilder);
   }
 
+  // If vectors of i1 are legal, turn (add (zext (vXi1 X)), Y) into
+  // (sub Y, (sext (vXi1 X))).
+  // FIXME: We have the (sub Y, (zext (vXi1 X))) -> (add (sext (vXi1 X)), Y) in
+  // generic DAG combine without a legal type check, but adding this there
+  // caused regressions.
+  if (Subtarget.hasAVX512() && VT.isVector()) {
+    if (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+        Op0.getOperand(0).getValueType().getVectorElementType() == MVT::i1) {
+      SDLoc DL(N);
+      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op0.getOperand(0));
+      return DAG.getNode(ISD::SUB, DL, VT, Op1, SExt);
+    }
+
+    if (Op1.getOpcode() == ISD::ZERO_EXTEND &&
+        Op1.getOperand(0).getValueType().getVectorElementType() == MVT::i1) {
+      SDLoc DL(N);
+      SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op1.getOperand(0));
+      return DAG.getNode(ISD::SUB, DL, VT, Op0, SExt);
+    }
+  }
+
   return combineAddOrSubToADCOrSBB(N, DAG);
 }
 
@@ -45355,9 +45389,9 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::FMSUBADD:    return combineFMADDSUB(N, DAG, Subtarget);
   case X86ISD::MOVMSK:      return combineMOVMSK(N, DAG, DCI, Subtarget);
   case X86ISD::MGATHER:
-  case X86ISD::MSCATTER:
+  case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
   case ISD::MGATHER:
-  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI, Subtarget);
+  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
   case X86ISD::PMULDQ:
index b5fcc75..88910fa 100644 (file)
@@ -1414,8 +1414,7 @@ define <4 x i32> @zext_bool_logic(<4 x i64> %cond1, <4 x i64> %cond2, <4 x i32>
 ; AVX512-NEXT:    vptestnmq %zmm1, %zmm1, %k1 ## encoding: [0x62,0xf2,0xf6,0x48,0x27,0xc9]
 ; AVX512-NEXT:    korw %k1, %k0, %k1 ## encoding: [0xc5,0xfc,0x45,0xc9]
 ; AVX512-NEXT:    vpternlogd $255, %zmm0, %zmm0, %zmm0 {%k1} {z} ## encoding: [0x62,0xf3,0x7d,0xc9,0x25,0xc0,0xff]
-; AVX512-NEXT:    vpsrld $31, %xmm0, %xmm0 ## encoding: [0xc5,0xf9,0x72,0xd0,0x1f]
-; AVX512-NEXT:    vpaddd %xmm2, %xmm0, %xmm0 ## encoding: [0xc5,0xf9,0xfe,0xc2]
+; AVX512-NEXT:    vpsubd %xmm0, %xmm2, %xmm0 ## encoding: [0xc5,0xe9,0xfa,0xc0]
 ; AVX512-NEXT:    vzeroupper ## encoding: [0xc5,0xf8,0x77]
 ; AVX512-NEXT:    retq ## encoding: [0xc3]
 ;
@@ -1425,8 +1424,7 @@ define <4 x i32> @zext_bool_logic(<4 x i64> %cond1, <4 x i64> %cond2, <4 x i32>
 ; SKX-NEXT:    vptestnmq %ymm1, %ymm1, %k1 ## encoding: [0x62,0xf2,0xf6,0x28,0x27,0xc9]
 ; SKX-NEXT:    korw %k1, %k0, %k0 ## encoding: [0xc5,0xfc,0x45,0xc1]
 ; SKX-NEXT:    vpmovm2d %k0, %xmm0 ## encoding: [0x62,0xf2,0x7e,0x08,0x38,0xc0]
-; SKX-NEXT:    vpsrld $31, %xmm0, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xf9,0x72,0xd0,0x1f]
-; SKX-NEXT:    vpaddd %xmm2, %xmm0, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xf9,0xfe,0xc2]
+; SKX-NEXT:    vpsubd %xmm0, %xmm2, %xmm0 ## EVEX TO VEX Compression encoding: [0xc5,0xe9,0xfa,0xc0]
 ; SKX-NEXT:    vzeroupper ## encoding: [0xc5,0xf8,0x77]
 ; SKX-NEXT:    retq ## encoding: [0xc3]
   %a = icmp eq <4 x i64> %cond1, zeroinitializer