[x86] use demanded bits to simplify masked store codegen
authorSanjay Patel <spatel@rotateright.com>
Tue, 9 Oct 2018 14:04:14 +0000 (14:04 +0000)
committerSanjay Patel <spatel@rotateright.com>
Tue, 9 Oct 2018 14:04:14 +0000 (14:04 +0000)
As noted in D52747, if we prefer IR to use trunc for bool vectors rather
than and+icmp, we can expose codegen shortcomings as seen here with masked store.

Replace a hard-coded PCMPGT simplification with the more general demanded bits call
to improve things.

Differential Revision: https://reviews.llvm.org/D52964

llvm-svn: 344048

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/masked_memop.ll

index 76d5f8e..2256fe0 100644 (file)
@@ -36521,31 +36521,31 @@ static SDValue reduceMaskedStoreToScalarStore(MaskedStoreSDNode *MS,
 }
 
 static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
                                   const X86Subtarget &Subtarget) {
   MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N);
-
   if (Mst->isCompressingStore())
     return SDValue();
 
+  EVT VT = Mst->getValue().getValueType();
   if (!Mst->isTruncatingStore()) {
     if (SDValue ScalarStore = reduceMaskedStoreToScalarStore(Mst, DAG))
       return ScalarStore;
 
-    // If the mask is checking (0 > X), we're creating a vector with all-zeros
-    // or all-ones elements based on the sign bits of X. AVX1 masked store only
-    // cares about the sign bit of each mask element, so eliminate the compare:
-    // mstore val, ptr, (pcmpgt 0, X) --> mstore val, ptr, X
-    // Note that by waiting to match an x86-specific PCMPGT node, we're
-    // eliminating potentially more complex matching of a setcc node which has
-    // a full range of predicates.
+    // If the mask value has been legalized to a non-boolean vector, try to
+    // simplify ops leading up to it. We only demand the MSB of each lane.
     SDValue Mask = Mst->getMask();
-    if (Mask.getOpcode() == X86ISD::PCMPGT &&
-        ISD::isBuildVectorAllZeros(Mask.getOperand(0).getNode())) {
-      assert(Mask.getValueType() == Mask.getOperand(1).getValueType() &&
-             "Unexpected type for PCMPGT");
-      return DAG.getMaskedStore(
-          Mst->getChain(), SDLoc(N), Mst->getValue(), Mst->getBasePtr(),
-          Mask.getOperand(1), Mst->getMemoryVT(), Mst->getMemOperand());
+    if (Mask.getScalarValueSizeInBits() != 1) {
+      TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+                                            !DCI.isBeforeLegalizeOps());
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      APInt DemandedMask(APInt::getSignMask(VT.getScalarSizeInBits()));
+      KnownBits Known;
+      if (TLI.SimplifyDemandedBits(Mask, DemandedMask, Known, TLO)) {
+        DCI.AddToWorklist(Mask.getNode());
+        DCI.CommitTargetLoweringOpt(TLO);
+        return SDValue(N, 0);
+      }
     }
 
     // TODO: AVX512 targets should also be able to simplify something like the
@@ -36556,7 +36556,6 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
   }
 
   // Resolve truncating stores.
-  EVT VT = Mst->getValue().getValueType();
   unsigned NumElems = VT.getVectorNumElements();
   EVT StVT = Mst->getMemoryVT();
   SDLoc dl(Mst);
@@ -40382,7 +40381,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::LOAD:           return combineLoad(N, DAG, DCI, Subtarget);
   case ISD::MLOAD:          return combineMaskedLoad(N, DAG, DCI, Subtarget);
   case ISD::STORE:          return combineStore(N, DAG, Subtarget);
-  case ISD::MSTORE:         return combineMaskedStore(N, DAG, Subtarget);
+  case ISD::MSTORE:         return combineMaskedStore(N, DAG, DCI, Subtarget);
   case ISD::SINT_TO_FP:     return combineSIntToFP(N, DAG, Subtarget);
   case ISD::UINT_TO_FP:     return combineUIntToFP(N, DAG, Subtarget);
   case ISD::FADD:
index 122704e..f1ab2ae 100644 (file)
@@ -1278,13 +1278,12 @@ define void @trunc_mask(<4 x float> %x, <4 x float>* %ptr, <4 x float> %y, <4 x
   ret void
 }
 
-; TODO: SimplifyDemandedBits should eliminate an ashr here.
+; SimplifyDemandedBits eliminates an ashr here.
 
 define void @masked_store_bool_mask_demand_trunc_sext(<4 x double> %x, <4 x double>* %p, <4 x i32> %masksrc) {
 ; AVX1-LABEL: masked_store_bool_mask_demand_trunc_sext:
 ; AVX1:       ## %bb.0:
 ; AVX1-NEXT:    vpslld $31, %xmm1, %xmm1
-; AVX1-NEXT:    vpsrad $31, %xmm1, %xmm1
 ; AVX1-NEXT:    vpmovsxdq %xmm1, %xmm2
 ; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[2,3,0,1]
 ; AVX1-NEXT:    vpmovsxdq %xmm1, %xmm1
@@ -1296,7 +1295,6 @@ define void @masked_store_bool_mask_demand_trunc_sext(<4 x double> %x, <4 x doub
 ; AVX2-LABEL: masked_store_bool_mask_demand_trunc_sext:
 ; AVX2:       ## %bb.0:
 ; AVX2-NEXT:    vpslld $31, %xmm1, %xmm1
-; AVX2-NEXT:    vpsrad $31, %xmm1, %xmm1
 ; AVX2-NEXT:    vpmovsxdq %xmm1, %ymm1
 ; AVX2-NEXT:    vmaskmovpd %ymm0, %ymm1, (%rdi)
 ; AVX2-NEXT:    vzeroupper
@@ -1338,7 +1336,6 @@ define void @widen_masked_store(<3 x i32> %v, <3 x i32>* %p, <3 x i1> %mask) {
 ; AVX1-NEXT:    vmovd %ecx, %xmm2
 ; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
 ; AVX1-NEXT:    vpslld $31, %xmm1, %xmm1
-; AVX1-NEXT:    vpsrad $31, %xmm1, %xmm1
 ; AVX1-NEXT:    vmaskmovps %xmm0, %xmm1, (%rdi)
 ; AVX1-NEXT:    retq
 ;
@@ -1350,7 +1347,6 @@ define void @widen_masked_store(<3 x i32> %v, <3 x i32>* %p, <3 x i1> %mask) {
 ; AVX2-NEXT:    vmovd %ecx, %xmm2
 ; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
 ; AVX2-NEXT:    vpslld $31, %xmm1, %xmm1
-; AVX2-NEXT:    vpsrad $31, %xmm1, %xmm1
 ; AVX2-NEXT:    vpmaskmovd %xmm0, %xmm1, (%rdi)
 ; AVX2-NEXT:    retq
 ;