[X86][SSE] Pulled out repeated target shuffle decodes into helper functions. NFCI.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 7 Feb 2016 14:33:03 +0000 (14:33 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 7 Feb 2016 14:33:03 +0000 (14:33 +0000)
Pulled out the code used by PSHUFB/VPERMV/VPERMV3 shuffle mask decoding into common helper functions.

The helper functions handle masks coming from BROADCAST/BUILD_VECTOR and ConstantPool nodes respectively.

llvm-svn: 260032

llvm/lib/Target/X86/X86ISelLowering.cpp

index 31be8e3..99c41ad 100644 (file)
@@ -4821,6 +4821,84 @@ static SDValue getShuffleVectorZeroOrUndef(SDValue V2, unsigned Idx,
   return DAG.getVectorShuffle(VT, SDLoc(V2), V1, V2, &MaskVec[0]);
 }
 
+static bool getTargetShuffleMaskIndices(SDValue MaskNode,
+                                        unsigned MaskEltSizeInBits,
+                                        SmallVectorImpl<uint64_t> &RawMask) {
+  while (MaskNode.getOpcode() == ISD::BITCAST)
+    MaskNode = MaskNode.getOperand(0);
+
+  MVT VT = MaskNode.getSimpleValueType();
+  assert(VT.isVector() && "Can't produce a non-vector with a build_vector!");
+
+  if (MaskNode.getOpcode() == X86ISD::VBROADCAST) {
+    if (VT.getScalarSizeInBits() != MaskEltSizeInBits)
+      return false;
+    if (auto *CN = dyn_cast<ConstantSDNode>(MaskNode.getOperand(0))) {
+      APInt MaskElement = CN->getAPIntValue();
+      for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
+        APInt RawElt = MaskElement.getLoBits(MaskEltSizeInBits);
+        RawMask.push_back(RawElt.getZExtValue());
+      }
+    }
+    return false;
+  }
+
+  if (MaskNode.getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) != 0)
+    return false;
+  unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits;
+
+  for (int i = 0, e = MaskNode.getNumOperands(); i < e; ++i) {
+    SDValue Op = MaskNode.getOperand(i);
+    if (Op->getOpcode() == ISD::UNDEF) {
+      RawMask.push_back((uint64_t)SM_SentinelUndef);
+      continue;
+    }
+
+    APInt MaskElement;
+    if (auto *CN = dyn_cast<ConstantSDNode>(Op.getNode()))
+      MaskElement = CN->getAPIntValue();
+    else if (auto *CFN = dyn_cast<ConstantFPSDNode>(Op.getNode()))
+      MaskElement = CFN->getValueAPF().bitcastToAPInt();
+    else
+      return false;
+
+    // We now have to decode the element which could be any integer size and
+    // extract each byte of it.
+    for (unsigned j = 0; j < ElementSplit; ++j) {
+      // Note that this is x86 and so always little endian: the low byte is
+      // the first byte of the mask.
+      APInt RawElt = MaskElement.getLoBits(MaskEltSizeInBits);
+      RawMask.push_back(RawElt.getZExtValue());
+      MaskElement = MaskElement.lshr(MaskEltSizeInBits);
+    }
+  }
+
+  return true;
+}
+
+static const Constant *getTargetShuffleMaskConstant(SDValue MaskNode) {
+  while (MaskNode.getOpcode() == ISD::BITCAST)
+    MaskNode = MaskNode.getOperand(0);
+
+  auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
+  if (!MaskLoad)
+    return nullptr;
+
+  SDValue Ptr = MaskLoad->getBasePtr();
+  if (Ptr->getOpcode() == X86ISD::Wrapper ||
+      Ptr->getOpcode() == X86ISD::WrapperRIP)
+    Ptr = Ptr->getOperand(0);
+
+  auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
+  if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
+    return nullptr;
+
+  return dyn_cast<Constant>(MaskCP->getConstVal());
+}
+
 /// Calculates the shuffle mask corresponding to the target-specific opcode.
 /// Returns true if the Mask could be calculated. Sets IsUnary to true if only
 /// uses one source. Note that this will set IsUnary for shuffles which use a
@@ -4891,62 +4969,15 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
   case X86ISD::PSHUFB: {
     IsUnary = true;
     SDValue MaskNode = N->getOperand(1);
-    while (MaskNode->getOpcode() == ISD::BITCAST)
-      MaskNode = MaskNode->getOperand(0);
-
-    if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
-      // If we have a build-vector, then things are easy.
-      MVT VT = MaskNode.getSimpleValueType();
-      assert(VT.isVector() &&
-             "Can't produce a non-vector with a build_vector!");
-      if (!VT.isInteger())
-        return false;
-
-      int NumBytesPerElement = VT.getVectorElementType().getSizeInBits() / 8;
-
-      SmallVector<uint64_t, 32> RawMask;
-      for (int i = 0, e = MaskNode->getNumOperands(); i < e; ++i) {
-        SDValue Op = MaskNode->getOperand(i);
-        if (Op->getOpcode() == ISD::UNDEF) {
-          RawMask.push_back((uint64_t)SM_SentinelUndef);
-          continue;
-        }
-        auto *CN = dyn_cast<ConstantSDNode>(Op.getNode());
-        if (!CN)
-          return false;
-        APInt MaskElement = CN->getAPIntValue();
-
-        // We now have to decode the element which could be any integer size and
-        // extract each byte of it.
-        for (int j = 0; j < NumBytesPerElement; ++j) {
-          // Note that this is x86 and so always little endian: the low byte is
-          // the first byte of the mask.
-          RawMask.push_back(MaskElement.getLoBits(8).getZExtValue());
-          MaskElement = MaskElement.lshr(8);
-        }
-      }
+    SmallVector<uint64_t, 32> RawMask;
+    if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask)) {
       DecodePSHUFBMask(RawMask, Mask);
       break;
     }
-
-    auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
-    if (!MaskLoad)
-      return false;
-
-    SDValue Ptr = MaskLoad->getBasePtr();
-    if (Ptr->getOpcode() == X86ISD::Wrapper ||
-        Ptr->getOpcode() == X86ISD::WrapperRIP)
-      Ptr = Ptr->getOperand(0);
-
-    auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
-    if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
-      return false;
-
-    if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
+    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
       DecodePSHUFBMask(C, Mask);
       break;
     }
-
     return false;
   }
   case X86ISD::VPERMI:
@@ -4983,57 +5014,13 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
   case X86ISD::VPERMV: {
     IsUnary = true;
     SDValue MaskNode = N->getOperand(0);
-    while (MaskNode->getOpcode() == ISD::BITCAST)
-      MaskNode = MaskNode->getOperand(0);
-
-    unsigned MaskLoBits = Log2_64(VT.getVectorNumElements());
     SmallVector<uint64_t, 32> RawMask;
-    if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
-      // If we have a build-vector, then things are easy.
-      assert(MaskNode.getSimpleValueType().isInteger() &&
-             MaskNode.getSimpleValueType().getVectorNumElements() ==
-             VT.getVectorNumElements());
-
-      for (unsigned i = 0; i < MaskNode->getNumOperands(); ++i) {
-        SDValue Op = MaskNode->getOperand(i);
-        if (Op->getOpcode() == ISD::UNDEF)
-          RawMask.push_back((uint64_t)SM_SentinelUndef);
-        else if (isa<ConstantSDNode>(Op)) {
-          APInt MaskElement = cast<ConstantSDNode>(Op)->getAPIntValue();
-          RawMask.push_back(MaskElement.getLoBits(MaskLoBits).getZExtValue());
-        } else
-          return false;
-      }
+    unsigned MaskLoBits = Log2_64(VT.getVectorNumElements());
+    if (getTargetShuffleMaskIndices(MaskNode, MaskLoBits, RawMask)) {
       DecodeVPERMVMask(RawMask, Mask);
       break;
     }
-    if (MaskNode->getOpcode() == X86ISD::VBROADCAST) {
-      unsigned NumEltsInMask = MaskNode->getNumOperands();
-      MaskNode = MaskNode->getOperand(0);
-      if (auto *CN = dyn_cast<ConstantSDNode>(MaskNode)) {
-        APInt MaskEltValue = CN->getAPIntValue();
-        for (unsigned i = 0; i < NumEltsInMask; ++i)
-          RawMask.push_back(MaskEltValue.getLoBits(MaskLoBits).getZExtValue());
-        DecodeVPERMVMask(RawMask, Mask);
-        break;
-      }
-      // It may be a scalar load
-    }
-
-    auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
-    if (!MaskLoad)
-      return false;
-
-    SDValue Ptr = MaskLoad->getBasePtr();
-    if (Ptr->getOpcode() == X86ISD::Wrapper ||
-        Ptr->getOpcode() == X86ISD::WrapperRIP)
-      Ptr = Ptr->getOperand(0);
-
-    auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
-    if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
-      return false;
-
-    if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
+    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
       DecodeVPERMVMask(C, VT, Mask);
       break;
     }
@@ -5042,48 +5029,14 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero,
   case X86ISD::VPERMV3: {
     IsUnary = false;
     SDValue MaskNode = N->getOperand(1);
-    while (MaskNode->getOpcode() == ISD::BITCAST)
-      MaskNode = MaskNode->getOperand(1);
-
-    if (MaskNode->getOpcode() == ISD::BUILD_VECTOR) {
-      // If we have a build-vector, then things are easy.
-      assert(MaskNode.getSimpleValueType().isInteger() &&
-             MaskNode.getSimpleValueType().getVectorNumElements() ==
-             VT.getVectorNumElements());
-
-      SmallVector<uint64_t, 32> RawMask;
-      unsigned MaskLoBits = Log2_64(VT.getVectorNumElements()*2);
-
-      for (unsigned i = 0; i < MaskNode->getNumOperands(); ++i) {
-        SDValue Op = MaskNode->getOperand(i);
-        if (Op->getOpcode() == ISD::UNDEF)
-          RawMask.push_back((uint64_t)SM_SentinelUndef);
-        else {
-          auto *CN = dyn_cast<ConstantSDNode>(Op.getNode());
-          if (!CN)
-            return false;
-          APInt MaskElement = CN->getAPIntValue();
-          RawMask.push_back(MaskElement.getLoBits(MaskLoBits).getZExtValue());
-        }
-      }
+
+    SmallVector<uint64_t, 32> RawMask;
+    unsigned MaskLoBits = Log2_64(VT.getVectorNumElements() * 2);
+    if (getTargetShuffleMaskIndices(MaskNode, MaskLoBits, RawMask)) {
       DecodeVPERMV3Mask(RawMask, Mask);
       break;
     }
-
-    auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
-    if (!MaskLoad)
-      return false;
-
-    SDValue Ptr = MaskLoad->getBasePtr();
-    if (Ptr->getOpcode() == X86ISD::Wrapper ||
-        Ptr->getOpcode() == X86ISD::WrapperRIP)
-      Ptr = Ptr->getOperand(0);
-
-    auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
-    if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
-      return false;
-
-    if (auto *C = dyn_cast<Constant>(MaskCP->getConstVal())) {
+    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
       DecodeVPERMV3Mask(C, VT, Mask);
       break;
     }