[X86] Fold PMADD(x,0) or PMADD(0,x) -> 0
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 2 Sep 2021 09:10:08 +0000 (10:10 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 2 Sep 2021 09:48:50 +0000 (10:48 +0100)
Pulled out of D108522 - handle zero-operand cases for PMADDWD/VPMADDUBSW ops

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/combine-pmadd.ll

index 17fd32e..f308163 100644 (file)
@@ -51824,6 +51824,21 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Simplify VPMADDUBSW/VPMADDWD operations.
+static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
+                             TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // Multiply by zero.
+  // Don't return LHS/RHS as it may contain UNDEFs.
+  if (ISD::isBuildVectorAllZeros(LHS.getNode()) ||
+      ISD::isBuildVectorAllZeros(RHS.getNode()))
+    return DAG.getConstant(0, SDLoc(N), N->getValueType(0));
+
+  return SDValue();
+}
+
 static SDValue combineEXTEND_VECTOR_INREG(SDNode *N, SelectionDAG &DAG,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                           const X86Subtarget &Subtarget) {
@@ -52274,6 +52289,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
   case X86ISD::PMULDQ:
   case X86ISD::PMULUDQ:     return combinePMULDQ(N, DAG, DCI, Subtarget);
+  case X86ISD::VPMADDUBSW:
+  case X86ISD::VPMADDWD:    return combineVPMADD(N, DAG, DCI);
   case X86ISD::KSHIFTL:
   case X86ISD::KSHIFTR:     return combineKSHIFT(N, DAG, DCI);
   case ISD::FP16_TO_FP:     return combineFP16_TO_FP(N, DAG, Subtarget);
index f84e2e8..e8a61ee 100644 (file)
@@ -9,14 +9,12 @@ declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind
 define <4 x i32> @combine_pmaddwd_zero(<8 x i16> %a0, <8 x i16> %a1) {
 ; SSE-LABEL: combine_pmaddwd_zero:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddwd %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddwd_zero:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a0, <8 x i16> zeroinitializer)
   ret <4 x i32> %1
@@ -25,14 +23,12 @@ define <4 x i32> @combine_pmaddwd_zero(<8 x i16> %a0, <8 x i16> %a1) {
 define <4 x i32> @combine_pmaddwd_zero_commute(<8 x i16> %a0, <8 x i16> %a1) {
 ; SSE-LABEL: combine_pmaddwd_zero_commute:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddwd %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddwd_zero_commute:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddwd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> zeroinitializer, <8 x i16> %a0)
   ret <4 x i32> %1
@@ -41,14 +37,12 @@ define <4 x i32> @combine_pmaddwd_zero_commute(<8 x i16> %a0, <8 x i16> %a1) {
 define <8 x i16> @combine_pmaddubsw_zero(<16 x i8> %a0, <16 x i8> %a1) {
 ; SSE-LABEL: combine_pmaddubsw_zero:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddubsw %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddubsw_zero:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddubsw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a0, <16 x i8> zeroinitializer)
   ret <8 x i16> %1
@@ -57,15 +51,12 @@ define <8 x i16> @combine_pmaddubsw_zero(<16 x i8> %a0, <16 x i8> %a1) {
 define <8 x i16> @combine_pmaddubsw_zero_commute(<16 x i8> %a0, <16 x i8> %a1) {
 ; SSE-LABEL: combine_pmaddubsw_zero_commute:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddubsw %xmm0, %xmm1
-; SSE-NEXT:    movdqa %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddubsw_zero_commute:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddubsw %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> zeroinitializer, <16 x i8> %a0)
   ret <8 x i16> %1