[X86][SSE] combineMulToPMADDWD - replace sext(v8i16) -> zext(v8i16)
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 24 Sep 2021 14:30:04 +0000 (15:30 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 24 Sep 2021 15:42:01 +0000 (16:42 +0100)
As suggested on D108522, if we're sign extending a v4i16 source before multiplying as a v4i32, then we can replace that with a zero extension and rely on the implicit sign-extension of PMADDWD.

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/madd.ll

index c1e8531..1e79620 100644 (file)
@@ -44258,10 +44258,29 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
   if (DAG.ComputeNumSignBits(N1) < 17 || DAG.ComputeNumSignBits(N0) < 17)
     return SDValue();
 
-  // At least one of the elements must be zero in the upper 17 bits.
-  APInt Mask17 = APInt::getHighBitsSet(32, 17);
-  if (!DAG.MaskedValueIsZero(N1, Mask17) && !DAG.MaskedValueIsZero(N0, Mask17))
+  // At least one of the elements must be zero in the upper 17 bits, or can be
+  // safely made zero without altering the final result.
+  auto GetZeroableOp = [&](SDValue Op) {
+    APInt Mask17 = APInt::getHighBitsSet(32, 17);
+    if (DAG.MaskedValueIsZero(Op, Mask17))
+      return Op;
+    // Convert sext(vXi16) to zext(vXi16).
+    // TODO: Enable pre-SSE41 once we can prefer MULHU/MULHS first.
+    // TODO: Handle sext from smaller types as well?
+    if (Op.getOpcode() == ISD::SIGN_EXTEND && VT.is128BitVector() &&
+        Subtarget.hasSSE41() && N->isOnlyUserOf(Op.getNode())) {
+      SDValue Src = Op.getOperand(0);
+      if (Src.getScalarValueSizeInBits() == 16)
+        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, Src);
+    }
+    return SDValue();
+  };
+  SDValue ZeroN0 = GetZeroableOp(N0);
+  SDValue ZeroN1 = GetZeroableOp(N1);
+  if (!ZeroN0 && !ZeroN1)
     return SDValue();
+  N0 = ZeroN0 ? ZeroN0 : N0;
+  N1 = ZeroN1 ? ZeroN1 : N1;
 
   // Use SplitOpsAndApply to handle AVX splitting.
   auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
index 6b83d2f..330ca90 100644 (file)
@@ -40,9 +40,9 @@ define i32 @_Z10test_shortPsS_i_128(i16* nocapture readonly, i16* nocapture read
 ; AVX-NEXT:    .p2align 4, 0x90
 ; AVX-NEXT:  .LBB0_1: # %vector.body
 ; AVX-NEXT:    # =>This Inner Loop Header: Depth=1
-; AVX-NEXT:    vpmovsxwd (%rdi,%rcx,2), %xmm1
-; AVX-NEXT:    vpmovsxwd (%rsi,%rcx,2), %xmm2
-; AVX-NEXT:    vpmulld %xmm1, %xmm2, %xmm1
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm2 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm2, %xmm1
 ; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    addq $8, %rcx
 ; AVX-NEXT:    cmpq %rcx, %rax
@@ -2603,16 +2603,13 @@ define <4 x i32> @pmaddwd_bad_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovdqa (%rdi), %xmm0
 ; AVX-NEXT:    vmovdqa (%rsi), %xmm1
-; AVX-NEXT:    vpshufb {{.*#+}} xmm2 = xmm0[2,3,4,5,10,11,12,13,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,1,6,7,8,9,14,15,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpshufb {{.*#+}} xmm3 = xmm1[0,1,4,5,8,9,12,13,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpshufb {{.*#+}} xmm1 = xmm1[2,3,6,7,10,11,14,15,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX-NEXT:    vpmulld %xmm3, %xmm2, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT:    vpblendw {{.*#+}} xmm2 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX-NEXT:    vpshufb {{.*#+}} xmm3 = xmm0[2,3],zero,zero,xmm0[4,5],zero,zero,xmm0[10,11],zero,zero,xmm0[12,13],zero,zero
+; AVX-NEXT:    vpmaddwd %xmm2, %xmm3, %xmm2
+; AVX-NEXT:    vpsrld $16, %xmm1, %xmm1
+; AVX-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,1],zero,zero,xmm0[6,7],zero,zero,xmm0[8,9],zero,zero,xmm0[14,15],zero,zero
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
 ; AVX-NEXT:    retq
   %A = load <8 x i16>, <8 x i16>* %Aptr
@@ -3105,38 +3102,28 @@ define <4 x i32> @output_size_mismatch_high_subvector(<16 x i16> %x, <16 x i16>
 ;
 ; AVX1-LABEL: output_size_mismatch_high_subvector:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
 ; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm0
-; AVX1-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
-; AVX1-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX1-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
-; AVX1-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
-; AVX1-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
-; AVX1-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX1-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX1-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX1-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
-; AVX1-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX1-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm3 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm2 = xmm0[0],xmm2[1],xmm0[2],xmm2[3],xmm0[4],xmm2[5],xmm0[6],xmm2[7]
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm2, %xmm2
+; AVX1-NEXT:    vpsrld $16, %xmm1, %xmm1
+; AVX1-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX1-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX1-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
 ; AVX1-NEXT:    vzeroupper
 ; AVX1-NEXT:    retq
 ;
 ; AVX256-LABEL: output_size_mismatch_high_subvector:
 ; AVX256:       # %bb.0:
-; AVX256-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
 ; AVX256-NEXT:    vextracti128 $1, %ymm0, %xmm0
-; AVX256-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
-; AVX256-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX256-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
-; AVX256-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
-; AVX256-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
-; AVX256-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX256-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX256-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX256-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
-; AVX256-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX256-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX256-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX256-NEXT:    vpblendw {{.*#+}} xmm3 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX256-NEXT:    vpblendw {{.*#+}} xmm2 = xmm0[0],xmm2[1],xmm0[2],xmm2[3],xmm0[4],xmm2[5],xmm0[6],xmm2[7]
+; AVX256-NEXT:    vpmaddwd %xmm3, %xmm2, %xmm2
+; AVX256-NEXT:    vpsrld $16, %xmm1, %xmm1
+; AVX256-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX256-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX256-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
 ; AVX256-NEXT:    vzeroupper
 ; AVX256-NEXT:    retq