[X86][SSE] Improved support for decoding target shuffle masks through bitcasts
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 24 Apr 2016 14:53:54 +0000 (14:53 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 24 Apr 2016 14:53:54 +0000 (14:53 +0000)
Reused the ability to split constants of a type wider than the shuffle mask to work with masks generated from scalar constants transfered to xmm.

This fixes an issue preventing PSHUFB target shuffle masks decoding rematerialized scalar constants and also exposes the XOP VPPERM bug described in PR27472.

llvm-svn: 267343

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/vector-shuffle-combining-ssse3.ll
llvm/test/CodeGen/X86/vector-shuffle-combining-xop.ll

index 2ec27bd..966af04 100644 (file)
@@ -4678,7 +4678,22 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode,
   MVT VT = MaskNode.getSimpleValueType();
   assert(VT.isVector() && "Can't produce a non-vector with a build_vector!");
 
+  // Split an APInt element into MaskEltSizeInBits sized pieces and
+  // insert into the shuffle mask.
+  auto SplitElementToMask = [&](APInt Element) {
+    // Note that this is x86 and so always little endian: the low byte is
+    // the first byte of the mask.
+    int Split = VT.getScalarSizeInBits() / MaskEltSizeInBits;
+    for (int i = 0; i < Split; ++i) {
+      APInt RawElt = Element.getLoBits(MaskEltSizeInBits);
+      Element = Element.lshr(MaskEltSizeInBits);
+      RawMask.push_back(RawElt.getZExtValue());
+    }
+  };
+
   if (MaskNode.getOpcode() == X86ISD::VBROADCAST) {
+    // TODO: Handle (MaskEltSizeInBits % VT.getScalarSizeInBits()) == 0
+    // TODO: Handle (VT.getScalarSizeInBits() % MaskEltSizeInBits) == 0
     if (VT.getScalarSizeInBits() != MaskEltSizeInBits)
       return false;
     if (auto *CN = dyn_cast<ConstantSDNode>(MaskNode.getOperand(0))) {
@@ -4693,13 +4708,16 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode,
 
   if (MaskNode.getOpcode() == X86ISD::VZEXT_MOVL &&
       MaskNode.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR) {
-    if (VT.getScalarSizeInBits() != MaskEltSizeInBits)
+
+    // TODO: Handle (MaskEltSizeInBits % VT.getScalarSizeInBits()) == 0
+    if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) != 0)
       return false;
-    SDValue MaskElement = MaskNode.getOperand(0).getOperand(0);
-    if (auto *CN = dyn_cast<ConstantSDNode>(MaskElement)) {
-      APInt RawElt = CN->getAPIntValue().getLoBits(MaskEltSizeInBits);
-      RawMask.push_back(RawElt.getZExtValue());
-      RawMask.append(VT.getVectorNumElements() - 1, 0);
+    unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits;
+
+    SDValue MaskOp = MaskNode.getOperand(0).getOperand(0);
+    if (auto *CN = dyn_cast<ConstantSDNode>(MaskOp)) {
+      SplitElementToMask(CN->getAPIntValue());
+      RawMask.append((VT.getVectorNumElements() - 1) * ElementSplit, 0);
       return true;
     }
     return false;
@@ -4711,7 +4729,6 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode,
   // TODO: Handle (MaskEltSizeInBits % VT.getScalarSizeInBits()) == 0
   if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) != 0)
     return false;
-  unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits;
 
   for (int i = 0, e = MaskNode.getNumOperands(); i < e; ++i) {
     SDValue Op = MaskNode.getOperand(i);
@@ -4720,23 +4737,12 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode,
       continue;
     }
 
-    APInt MaskElement;
     if (auto *CN = dyn_cast<ConstantSDNode>(Op.getNode()))
-      MaskElement = CN->getAPIntValue();
+      SplitElementToMask(CN->getAPIntValue());
     else if (auto *CFN = dyn_cast<ConstantFPSDNode>(Op.getNode()))
-      MaskElement = CFN->getValueAPF().bitcastToAPInt();
+      SplitElementToMask(CFN->getValueAPF().bitcastToAPInt());
     else
       return false;
-
-    // We now have to decode the element which could be any integer size and
-    // extract each byte of it.
-    for (unsigned j = 0; j < ElementSplit; ++j) {
-      // Note that this is x86 and so always little endian: the low byte is
-      // the first byte of the mask.
-      APInt RawElt = MaskElement.getLoBits(MaskEltSizeInBits);
-      RawMask.push_back(RawElt.getZExtValue());
-      MaskElement = MaskElement.lshr(MaskEltSizeInBits);
-    }
   }
 
   return true;
index cee102c..0f1cae4 100644 (file)
@@ -12,18 +12,12 @@ declare <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8>, <16 x i8>)
 define <16 x i8> @combine_vpshufb_zero(<16 x i8> %a0) {
 ; SSE-LABEL: combine_vpshufb_zero:
 ; SSE:       # BB#0:
-; SSE-NEXT:    movl $128, %eax
-; SSE-NEXT:    movd %eax, %xmm1
-; SSE-NEXT:    pshufb %xmm1, %xmm0
-; SSE-NEXT:    pshufb {{.*#+}} xmm0 = xmm0[0],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vpshufb_zero:
 ; AVX:       # BB#0:
-; AVX-NEXT:    movl $128, %eax
-; AVX-NEXT:    vmovd %eax, %xmm1
-; AVX-NEXT:    vpshufb %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %res0 = call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %a0, <16 x i8> <i8 128, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>)
   %res1 = call <16 x i8> @llvm.x86.ssse3.pshuf.b.128(<16 x i8> %res0, <16 x i8> <i8 0, i8 128, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>)
index a6376dc..0fba831 100644 (file)
@@ -23,11 +23,7 @@ define <16 x i8> @combine_vpperm_identity(<16 x i8> %a0, <16 x i8> %a1) {
 define <16 x i8> @combine_vpperm_zero(<16 x i8> %a0, <16 x i8> %a1) {
 ; CHECK-LABEL: combine_vpperm_zero:
 ; CHECK:       # BB#0:
-; CHECK-NEXT:    movl $128, %eax
-; CHECK-NEXT:    vmovd %eax, %xmm2
-; CHECK-NEXT:    vpperm %xmm2, %xmm1, %xmm0, %xmm0
-; CHECK-NEXT:    vpperm {{.*#+}} xmm0 = xmm0[0],zero,xmm0[0,0,0,0,0,0,0,0,0,0,0,0,0,0]
-; CHECK-NEXT:    vpperm {{.*#+}} xmm0 = xmm0[0,1],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
+; CHECK-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
 ; CHECK-NEXT:    retq
   %res0 = call <16 x i8> @llvm.x86.xop.vpperm(<16 x i8> %a0, <16 x i8> %a1, <16 x i8> <i8 128, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>)
   %res1 = call <16 x i8> @llvm.x86.xop.vpperm(<16 x i8> %res0, <16 x i8> undef, <16 x i8> <i8 0, i8 128, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>)