[X86] Simplify the predicates for avx2 masked gather patterns.
authorCraig Topper <craig.topper@intel.com>
Tue, 21 Nov 2017 06:01:20 +0000 (06:01 +0000)
committerCraig Topper <craig.topper@intel.com>
Tue, 21 Nov 2017 06:01:20 +0000 (06:01 +0000)
We don't need a dyn_cast and we only need to check the type of the index. The base ptr is guaranteed to be scalar.

llvm-svn: 318730

llvm/lib/Target/X86/X86InstrFragmentsSIMD.td

index 33d4557..2dc29ca 100644 (file)
@@ -1126,66 +1126,50 @@ def avx2_masked_gather_64 : SDNode<"ISD::MGATHER",
 // dword gathers
 def avx2_mvpgatherdd_ps_xmm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_32 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v4i32 ||
-            Mgt->getBasePtr().getValueType() == MVT::v4i32);
-  return false;
+  MaskedGatherSDNode *Mgt = cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v4i32;
 }]>;
 
 def avx2_mvpgatherqd_ps_xmm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_x86_masked_gather_32 node:$src1, node:$src2, node:$src3) , [{
-  if (X86MaskedGatherSDNode *Mgt = dyn_cast<X86MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v2i64 ||
-            Mgt->getBasePtr().getValueType() == MVT::v2i64);
-  return false;
+  X86MaskedGatherSDNode *Mgt = cast<X86MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v2i64;
 }]>;
 
 def avx2_mvpgatherdd_ps_ymm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_32 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v8i32 ||
-            Mgt->getBasePtr().getValueType() == MVT::v8i32);
-  return false;
+  MaskedGatherSDNode *Mgt = cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v8i32;
 }]>;
 
 def avx2_mvpgatherqd_ps_ymm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_32 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v4i64 ||
-            Mgt->getBasePtr().getValueType() == MVT::v4i64);
-  return false;
+  MaskedGatherSDNode *Mgt = cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v4i64;
 }]>;
 
 // qwords
 def avx2_mvpgatherdq_pd_xmm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_64 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v2i32 ||
-            Mgt->getBasePtr().getValueType() == MVT::v2i32);
-  return false;
+  MaskedGatherSDNode *Mgt = cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v2i32;
 }]>;
 
 def avx2_mvpgatherqq_pd_xmm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_64 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v2i64 ||
-            Mgt->getBasePtr().getValueType() == MVT::v2i64) &&
-            Mgt->getMemoryVT().is128BitVector();
-  return false;
+  MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v2i64 &&
+         Mgt->getMemoryVT().is128BitVector();
 }]>;
 
 def avx2_mvpgatherdq_pd_ymm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_64 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v4i32 ||
-            Mgt->getBasePtr().getValueType() == MVT::v4i32);
-  return false;
+  MaskedGatherSDNode *Mgt = cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v4i32;
 }]>;
 
 def avx2_mvpgatherqq_pd_ymm : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (avx2_masked_gather_64 node:$src1, node:$src2, node:$src3) , [{
-  if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
-    return (Mgt->getIndex().getValueType() == MVT::v4i64 ||
-            Mgt->getBasePtr().getValueType() == MVT::v4i64);
-  return false;
+  MaskedGatherSDNode *Mgt = cast<MaskedGatherSDNode>(N);
+  return Mgt->getIndex().getValueType() == MVT::v4i64;
 }]>;