[X86] Refactor the broadcast and load folding in tryVPTESTM to reduce some code.
authorCraig Topper <craig.topper@intel.com>
Sat, 1 Aug 2020 06:10:47 +0000 (23:10 -0700)
committerCraig Topper <craig.topper@intel.com>
Sat, 1 Aug 2020 06:57:13 +0000 (23:57 -0700)
Now we try to load and broadcast together for operand 1. Followed
by load and broadcast for operand 1. Previously we tried load
operand 1, load operand 1, broadcast operand 0, broadcast operand 1.

Now we have a single helper that tries load and broadcast for
one operand that we can just call twice.

llvm/lib/Target/X86/X86ISelDAGToDAG.cpp

index bb04690e04d162318149888a45a28e8bb044e1d2..58424892535a7fcc1973590582ce684423c4a6e5 100644 (file)
@@ -4207,15 +4207,15 @@ VPTESTM_CASE(v16i16, WZ256##SUFFIX) \
 VPTESTM_CASE(v64i8, BZ##SUFFIX) \
 VPTESTM_CASE(v32i16, WZ##SUFFIX)
 
-  if (FoldedLoad) {
+  if (FoldedBCast) {
     switch (TestVT.SimpleTy) {
-    VPTESTM_FULL_CASES(rm)
+    VPTESTM_BROADCAST_CASES(rmb)
     }
   }
 
-  if (FoldedBCast) {
+  if (FoldedLoad) {
     switch (TestVT.SimpleTy) {
-    VPTESTM_BROADCAST_CASES(rmb)
+    VPTESTM_FULL_CASES(rm)
     }
   }
 
@@ -4274,68 +4274,57 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     }
   }
 
-  // Without VLX we need to widen the load.
+  // Without VLX we need to widen the operation.
   bool Widen = !Subtarget->hasVLX() && !CmpVT.is512BitVector();
 
-  // We can only fold loads if the sources are unique.
-  bool CanFoldLoads = Src0 != Src1;
+  auto tryFoldLoadOrBCast = [&](SDNode *Root, SDNode *P, SDValue &L,
+                                SDValue &Base, SDValue &Scale, SDValue &Index,
+                                SDValue &Disp, SDValue &Segment) {
+    // If we need to widen, we can't fold the load.
+    if (!Widen)
+      if (tryFoldLoad(Root, P, L, Base, Scale, Index, Disp, Segment))
+        return true;
 
-  // Try to fold loads unless we need to widen.
-  bool FoldedLoad = false;
-  SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, Load;
-  if (!Widen && CanFoldLoads) {
-    Load = Src1;
-    FoldedLoad = tryFoldLoad(Root, N0.getNode(), Load, Tmp0, Tmp1, Tmp2, Tmp3,
-                             Tmp4);
-    if (!FoldedLoad) {
-      // And is computative.
-      Load = Src0;
-      FoldedLoad = tryFoldLoad(Root, N0.getNode(), Load, Tmp0, Tmp1, Tmp2,
-                               Tmp3, Tmp4);
-      if (FoldedLoad)
-        std::swap(Src0, Src1);
-    }
-  }
+    // If we didn't fold a load, try to match broadcast. No widening limitation
+    // for this. But only 32 and 64 bit types are supported.
+    if (CmpSVT != MVT::i32 && CmpSVT != MVT::i64)
+      return false;
 
-  auto findBroadcastedOp = [](SDValue Src, MVT CmpSVT, SDNode *&Parent) {
     // Look through single use bitcasts.
-    if (Src.getOpcode() == ISD::BITCAST && Src.hasOneUse()) {
-      Parent = Src.getNode();
-      Src = Src.getOperand(0);
+    if (L.getOpcode() == ISD::BITCAST && L.hasOneUse()) {
+      P = L.getNode();
+      L = L.getOperand(0);
     }
 
-    if (Src.getOpcode() == X86ISD::VBROADCAST_LOAD && Src.hasOneUse()) {
-      auto *MemIntr = cast<MemIntrinsicSDNode>(Src);
-      if (MemIntr->getMemoryVT().getSizeInBits() == CmpSVT.getSizeInBits())
-        return Src;
-    }
+    if (L.getOpcode() != X86ISD::VBROADCAST_LOAD)
+      return false;
 
-    return SDValue();
+    auto *MemIntr = cast<MemIntrinsicSDNode>(L);
+    if (MemIntr->getMemoryVT().getSizeInBits() != CmpSVT.getSizeInBits())
+      return false;
+
+    return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
   };
 
-  // If we didn't fold a load, try to match broadcast. No widening limitation
-  // for this. But only 32 and 64 bit types are supported.
-  bool FoldedBCast = false;
-  if (!FoldedLoad && CanFoldLoads &&
-      (CmpSVT == MVT::i32 || CmpSVT == MVT::i64)) {
-    SDNode *ParentNode = N0.getNode();
-    if ((Load = findBroadcastedOp(Src1, CmpSVT, ParentNode))) {
-      FoldedBCast = tryFoldBroadcast(Root, ParentNode, Load, Tmp0,
-                                     Tmp1, Tmp2, Tmp3, Tmp4);
-    }
+  // We can only fold loads if the sources are unique.
+  bool CanFoldLoads = Src0 != Src1;
 
-    // Try the other operand.
-    if (!FoldedBCast) {
-      SDNode *ParentNode = N0.getNode();
-      if ((Load = findBroadcastedOp(Src0, CmpSVT, ParentNode))) {
-        FoldedBCast = tryFoldBroadcast(Root, ParentNode, Load, Tmp0,
-                                       Tmp1, Tmp2, Tmp3, Tmp4);
-        if (FoldedBCast)
-          std::swap(Src0, Src1);
-      }
+  bool FoldedLoad = false;
+  SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  if (CanFoldLoads) {
+    FoldedLoad = tryFoldLoadOrBCast(Root, N0.getNode(), Src1, Tmp0, Tmp1, Tmp2,
+                                    Tmp3, Tmp4);
+    if (!FoldedLoad) {
+      // And is commutative.
+      FoldedLoad = tryFoldLoadOrBCast(Root, N0.getNode(), Src0, Tmp0, Tmp1,
+                                      Tmp2, Tmp3, Tmp4);
+      if (FoldedLoad)
+        std::swap(Src0, Src1);
     }
   }
 
+  bool FoldedBCast = FoldedLoad && Src1.getOpcode() == X86ISD::VBROADCAST_LOAD;
+
   bool IsMasked = InMask.getNode() != nullptr;
 
   SDLoc dl(Root);
@@ -4353,7 +4342,6 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
                                                      CmpVT), 0);
     Src0 = CurDAG->getTargetInsertSubreg(SubReg, dl, CmpVT, ImplDef, Src0);
 
-    assert(!FoldedLoad && "Shouldn't have folded the load");
     if (!FoldedBCast)
       Src1 = CurDAG->getTargetInsertSubreg(SubReg, dl, CmpVT, ImplDef, Src1);
 
@@ -4371,23 +4359,23 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
                                IsMasked);
 
   MachineSDNode *CNode;
-  if (FoldedLoad || FoldedBCast) {
+  if (FoldedLoad) {
     SDVTList VTs = CurDAG->getVTList(MaskVT, MVT::Other);
 
     if (IsMasked) {
       SDValue Ops[] = { InMask, Src0, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4,
-                        Load.getOperand(0) };
+                        Src1.getOperand(0) };
       CNode = CurDAG->getMachineNode(Opc, dl, VTs, Ops);
     } else {
       SDValue Ops[] = { Src0, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4,
-                        Load.getOperand(0) };
+                        Src1.getOperand(0) };
       CNode = CurDAG->getMachineNode(Opc, dl, VTs, Ops);
     }
 
     // Update the chain.
-    ReplaceUses(Load.getValue(1), SDValue(CNode, 1));
+    ReplaceUses(Src1.getValue(1), SDValue(CNode, 1));
     // Record the mem-refs
-    CurDAG->setNodeMemRefs(CNode, {cast<MemSDNode>(Load)->getMemOperand()});
+    CurDAG->setNodeMemRefs(CNode, {cast<MemSDNode>(Src1)->getMemOperand()});
   } else {
     if (IsMasked)
       CNode = CurDAG->getMachineNode(Opc, dl, MaskVT, InMask, Src0, Src1);