[X86] Custom type legalize v2f32 masked gathers instead of trying to cleanup after...
authorCraig Topper <craig.topper@intel.com>
Thu, 16 Nov 2017 02:07:45 +0000 (02:07 +0000)
committerCraig Topper <craig.topper@intel.com>
Thu, 16 Nov 2017 02:07:45 +0000 (02:07 +0000)
llvm-svn: 318368

llvm/lib/Target/X86/X86ISelLowering.cpp

index 532097bbc13907d58754a81d89eb27fa0b01247d..a89920d553777694fc26b03bdd6bcb0a8d80dc99 100644 (file)
@@ -1370,6 +1370,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
         setOperationAction(ISD::CTPOP, VT, Legal);
     }
 
+    // Custom legalize 2x32 to get a little better code.
+    if (Subtarget.hasVLX()) {
+      setOperationAction(ISD::MGATHER, MVT::v2f32, Custom);
+    }
+
     // Custom lower several nodes.
     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) {
@@ -24378,32 +24383,6 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
     SDValue RetOps[] = { Sext, NewGather.getValue(1) };
     return DAG.getMergeValues(RetOps, dl);
   }
-  if (N->getMemoryVT() == MVT::v2f32 && Subtarget.hasVLX()) {
-    // This transformation is for optimization only.
-    // The type legalizer extended mask and index to 4 elements vector
-    // in order to match requirements of the common gather node - same
-    // vector width of index and value. X86 Gather node allows mismatch
-    // of vector width in order to select more optimal instruction at the
-    // end.
-    assert(VT == MVT::v4f32 && Src0.getValueType() == MVT::v4f32 &&
-           "Unexpected type in masked gather");
-    if (Mask.getOpcode() == ISD::CONCAT_VECTORS &&
-        ISD::isBuildVectorAllZeros(Mask.getOperand(1).getNode()) &&
-        Index.getOpcode() == ISD::CONCAT_VECTORS &&
-        Index.getOperand(1).isUndef()) {
-      Mask = Mask.getOperand(0);
-      Index = Index.getOperand(0);
-    } else
-      return Op;
-    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
-    SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
-      DAG.getVTList(MVT::v4f32, MVT::v2i1, MVT::Other), Ops, dl,
-      N->getMemoryVT(), N->getMemOperand());
-
-    SDValue RetOps[] = { NewGather.getValue(0), NewGather.getValue(2) };
-    return DAG.getMergeValues(RetOps, dl);
-
-  }
   return Op;
 }
 
@@ -24902,6 +24881,29 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(DAG.getBuildVector(DstVT, dl, Elts));
     return;
   }
+  case ISD::MGATHER: {
+    EVT VT = N->getValueType(0);
+    if (VT == MVT::v2f32 && Subtarget.hasVLX()) {
+      auto *Gather = cast<MaskedGatherSDNode>(N);
+      SDValue Index = Gather->getIndex();
+      if (Index.getValueType() != MVT::v2i64)
+        return;
+      SDValue Mask = Gather->getMask();
+      assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
+      SDValue Src0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32,
+                                 Gather->getValue(),
+                                 DAG.getUNDEF(MVT::v2f32));
+      SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
+                        Index };
+      SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
+        DAG.getVTList(MVT::v4f32, MVT::v2i1, MVT::Other), Ops, dl,
+        Gather->getMemoryVT(), Gather->getMemOperand());
+      Results.push_back(Res);
+      Results.push_back(Res.getValue(2));
+      return;
+    }
+    break;
+  }
   }
 }