[X86][SSE] Improve lowering of vXi64 multiply with known zero 32-bit halves
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 17 Nov 2016 12:14:49 +0000 (12:14 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 17 Nov 2016 12:14:49 +0000 (12:14 +0000)
vXi64 multiplication is lowered into 3 calls of vpmuludq with the upper/lower 32-bit halves.

If any of these halves are zero then we can remove individual calls. Although there was isBuildVectorAllZeros code to do this I don't think it ever worked (maybe just for constant folded cases that don't seem to be tested for any longer).

This requires additional X86ISD support for computeKnownBitsForTargetNode, so far I've just added support for X86ISD::VZEXT (VPMOVZX* - helping the AVX2+ cases).

Partial fix for PR30845

Differential Revision: https://reviews.llvm.org/D26590

llvm-svn: 287223

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/pmul.ll

index 5d81deb..838e75b 100644 (file)
@@ -20144,33 +20144,43 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
   //  AloBhi = psllqi(AloBhi, 32);
   //  AhiBlo = psllqi(AhiBlo, 32);
   //  return AloBlo + AloBhi + AhiBlo;
+  APInt LowerBitsMask = APInt::getLowBitsSet(64, 32);
+  bool ALoiIsZero = DAG.MaskedValueIsZero(A, LowerBitsMask);
+  bool BLoiIsZero = DAG.MaskedValueIsZero(B, LowerBitsMask);
 
-  SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
-  SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
+  APInt UpperBitsMask = APInt::getHighBitsSet(64, 32);
+  bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask);
+  bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask);
 
-  SDValue AhiBlo = Ahi;
-  SDValue AloBhi = Bhi;
   // Bit cast to 32-bit vectors for MULUDQ
   MVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 :
                                   (VT == MVT::v4i64) ? MVT::v8i32 : MVT::v16i32;
-  A = DAG.getBitcast(MulVT, A);
-  B = DAG.getBitcast(MulVT, B);
-  Ahi = DAG.getBitcast(MulVT, Ahi);
-  Bhi = DAG.getBitcast(MulVT, Bhi);
-
-  SDValue AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, B);
-  // After shifting right const values the result may be all-zero.
-  if (!ISD::isBuildVectorAllZeros(Ahi.getNode())) {
-    AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
-    AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
-  }
-  if (!ISD::isBuildVectorAllZeros(Bhi.getNode())) {
-    AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
+  SDValue Alo = DAG.getBitcast(MulVT, A);
+  SDValue Blo = DAG.getBitcast(MulVT, B);
+
+  SDValue Res;
+
+  // Only multiply lo/hi halves that aren't known to be zero.
+  if (!ALoiIsZero && !BLoiIsZero)
+    Res = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Blo);
+
+  if (!ALoiIsZero && !BHiIsZero) {
+    SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
+    Bhi = DAG.getBitcast(MulVT, Bhi);
+    SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Bhi);
     AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG);
+    Res = (Res.getNode() ? DAG.getNode(ISD::ADD, dl, VT, Res, AloBhi) : AloBhi);
   }
 
-  SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi);
-  return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo);
+  if (!AHiIsZero && !BLoiIsZero) {
+    SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
+    Ahi = DAG.getBitcast(MulVT, Ahi);
+    SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, Blo);
+    AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
+    Res = (Res.getNode() ? DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo) : AhiBlo);
+  }
+
+  return (Res.getNode() ? Res : getZeroVector(VT, Subtarget, DAG, dl));
 }
 
 static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
@@ -25256,6 +25266,20 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - NumLoBits);
     break;
   }
+  case X86ISD::VZEXT: {
+    SDValue N0 = Op.getOperand(0);
+    unsigned NumElts = Op.getValueType().getVectorNumElements();
+    unsigned InNumElts = N0.getValueType().getVectorNumElements();
+    unsigned InBitWidth = N0.getValueType().getScalarSizeInBits();
+
+    KnownZero = KnownOne = APInt(InBitWidth, 0);
+    APInt DemandedElts = APInt::getLowBitsSet(InNumElts, NumElts);
+    DAG.computeKnownBits(N0, KnownZero, KnownOne, DemandedElts, Depth + 1);
+    KnownOne = KnownOne.zext(BitWidth);
+    KnownZero = KnownZero.zext(BitWidth);
+    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - InBitWidth);
+    break;
+  }
   }
 }
 
index 2117c9d..b72f6cf 100644 (file)
@@ -1225,15 +1225,7 @@ define <4 x i32> @mul_v4i64_zero_upper(<4 x i32> %val1, <4 x i32> %val2) {
 ; AVX2:       # BB#0: # %entry
 ; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
-; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX2-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX2-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX2-NEXT:    vpsrlq $32, %ymm0, %ymm0
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
-; AVX2-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
@@ -1245,15 +1237,7 @@ define <4 x i32> @mul_v4i64_zero_upper(<4 x i32> %val1, <4 x i32> %val2) {
 ; AVX512:       # BB#0: # %entry
 ; AVX512-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX512-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
-; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX512-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX512-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX512-NEXT:    vpsrlq $32, %ymm0, %ymm0
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
-; AVX512-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
@@ -1338,13 +1322,9 @@ define <4 x i32> @mul_v4i64_zero_upper_left(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX2:       # BB#0: # %entry
 ; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX2-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX2-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX2-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX2-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
@@ -1357,13 +1337,9 @@ define <4 x i32> @mul_v4i64_zero_upper_left(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX512:       # BB#0: # %entry
 ; AVX512-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX512-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX512-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX512-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX512-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
@@ -1455,13 +1431,9 @@ define <4 x i32> @mul_v4i64_zero_lower(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX2-NEXT:    vpxor %ymm2, %ymm2, %ymm2
 ; AVX2-NEXT:    vpblendd {{.*#+}} ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7]
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX2-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX2-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX2-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX2-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
@@ -1476,13 +1448,9 @@ define <4 x i32> @mul_v4i64_zero_lower(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX512-NEXT:    vpxor %ymm2, %ymm2, %ymm2
 ; AVX512-NEXT:    vpblendd {{.*#+}} ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7]
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX512-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX512-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX512-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX512-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]