[DAGCombine] `combineShuffleToZeroExtendVectorInReg()`: widen shuffle elements before...
authorRoman Lebedev <lebedev.ri@gmail.com>
Mon, 26 Dec 2022 20:45:37 +0000 (23:45 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Mon, 26 Dec 2022 21:47:45 +0000 (00:47 +0300)
We might have sunk a bitcast into shuffle, and now it might be operating
on more fine-grained elements than what we'd match, so we must not be
dependent on whatever the granularity the shuffle happened to be in,
but transform it into the one canonical for us - with widest elements.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/X86/vector-shuffle-512-v64.ll

index fb053a7..fd27220 100644 (file)
@@ -22670,6 +22670,7 @@ static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
   EVT VT = SVN->getValueType(0);
   assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
   unsigned NumElts = VT.getVectorNumElements();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
 
   // TODO: add support for big-endian when we have a test case.
   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
@@ -22722,15 +22723,31 @@ static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
   if (!HadZeroableElts)
     return SDValue();
 
-  // FIXME: the shuffle may be more fine-grained than we want.
+  // The shuffle may be more fine-grained than we want. Widen elements first.
+  // FIXME: should we do this before manifesting zeroable shuffle mask indices?
+  SmallVector<int, 16> ScaledMask;
+  getShuffleMaskWithWidestElts(Mask, ScaledMask);
+  assert(Mask.size() >= ScaledMask.size() &&
+         Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
+  int Prescale = Mask.size() / ScaledMask.size();
+
+  NumElts = ScaledMask.size();
+  EltSizeInBits *= Prescale;
+
+  EVT PrescaledVT = EVT::getVectorVT(
+      *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
+      NumElts);
+
+  if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
+    return SDValue();
 
   // For example,
   // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
   // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
-  auto isZeroExtend = [NumElts, &SrcMask = Mask](unsigned Scale) {
+  auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
     assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
            "Unexpected mask scaling factor.");
-    ArrayRef<int> Mask = SrcMask;
+    ArrayRef<int> Mask = ScaledMask;
     for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
          SrcElt != NumSrcElts; ++SrcElt) {
       // Analyze the shuffle mask in Scale-sized chunks.
@@ -22755,11 +22772,13 @@ static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
   for (bool Commuted : {false, true}) {
     SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
     if (Commuted)
-      ShuffleVectorSDNode::commuteMask(Mask);
+      ShuffleVectorSDNode::commuteMask(ScaledMask);
     std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
-        Opcode, VT, isZeroExtend, DAG, TLI, LegalTypes, LegalOperations);
+        Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
+        LegalOperations);
     if (OutVT)
-      return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, Op));
+      return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
+                                            DAG.getBitcast(PrescaledVT, Op)));
   }
   return SDValue();
 }
index 6394fd1..c2b3633 100644 (file)
@@ -1540,31 +1540,10 @@ define void @PR54562_mem(ptr %src, ptr %dst) {
 }
 
 define <64 x i8> @shuffle_v32i16_zextinreg_to_v16i32(<64 x i8> %a)  {
-; AVX512F-LABEL: shuffle_v32i16_zextinreg_to_v16i32:
-; AVX512F:       # %bb.0:
-; AVX512F-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
-; AVX512F-NEXT:    vextracti128 $1, %ymm0, %xmm0
-; AVX512F-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
-; AVX512F-NEXT:    vinserti64x4 $1, %ymm0, %zmm1, %zmm0
-; AVX512F-NEXT:    retq
-;
-; AVX512BW-LABEL: shuffle_v32i16_zextinreg_to_v16i32:
-; AVX512BW:       # %bb.0:
-; AVX512BW-NEXT:    vpmovzxwd {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
-; AVX512BW-NEXT:    retq
-;
-; AVX512DQ-LABEL: shuffle_v32i16_zextinreg_to_v16i32:
-; AVX512DQ:       # %bb.0:
-; AVX512DQ-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
-; AVX512DQ-NEXT:    vextracti128 $1, %ymm0, %xmm0
-; AVX512DQ-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
-; AVX512DQ-NEXT:    vinserti64x4 $1, %ymm0, %zmm1, %zmm0
-; AVX512DQ-NEXT:    retq
-;
-; AVX512VBMI-LABEL: shuffle_v32i16_zextinreg_to_v16i32:
-; AVX512VBMI:       # %bb.0:
-; AVX512VBMI-NEXT:    vpmovzxwd {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
-; AVX512VBMI-NEXT:    retq
+; ALL-LABEL: shuffle_v32i16_zextinreg_to_v16i32:
+; ALL:       # %bb.0:
+; ALL-NEXT:    vpmovzxwd {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
+; ALL-NEXT:    retq
   %b = shufflevector <64 x i8> %a, <64 x i8> <i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 0, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef
 >, <64 x i32> <i32 0, i32 1, i32 84, i32 84, i32 2, i32 3, i32 84, i32 84, i32 4, i32 5, i32 84, i32 84, i32 6, i32 7, i32 84, i32 84, i32 8, i32 9, i32 84, i32 84, i32 10, i32 11, i32 84, i32 84, i32 12, i32 13, i32 84, i32 84, i32 14, i32 15, i32 84, i32 84, i32 16, i32 17, i32 84, i32 84, i32 18, i32 19, i32 84, i32 84, i32 20, i32 21, i32 84, i32 84, i32 22, i32 23, i32 84, i32 84, i32 24, i32 25, i32 84, i32 84, i32 26, i32 27, i32 84, i32 84, i32 28, i32 29, i32 84, i32 84, i32 30, i32 31, i32 84, i32 84>
   ret <64 x i8> %b