[SLP]Fix incorrect shuffle results because of missing shuffle mask
authorAlexey Bataev <a.bataev@outlook.com>
Wed, 4 Jan 2023 16:32:25 +0000 (08:32 -0800)
committerAlexey Bataev <a.bataev@outlook.com>
Wed, 4 Jan 2023 21:10:40 +0000 (13:10 -0800)
analysis.

Missed the analysis of the shuffle mask when trying to analyze the
operands of the shuffle instruction during peeking through shuffle
instructions.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/peek-through-shuffle.ll

index 947c456..fe2e014 100644 (file)
@@ -6405,20 +6405,26 @@ protected:
       if (auto *SVOpTy =
               dyn_cast<FixedVectorType>(SV->getOperand(0)->getType()))
         LocalVF = SVOpTy->getNumElements();
+      SmallVector<int> ExtMask(Mask.size(), UndefMaskElem);
+      for (auto [Idx, I] : enumerate(Mask)) {
+         if (I == UndefMaskElem)
+           continue;
+         ExtMask[Idx] = SV->getMaskValue(I);
+      }
       bool IsOp1Undef =
           isUndefVector(SV->getOperand(0),
-                        buildUseMask(LocalVF, Mask, UseMask::FirstArg))
+                        buildUseMask(LocalVF, ExtMask, UseMask::FirstArg))
               .all();
       bool IsOp2Undef =
           isUndefVector(SV->getOperand(1),
-                        buildUseMask(LocalVF, Mask, UseMask::SecondArg))
+                        buildUseMask(LocalVF, ExtMask, UseMask::SecondArg))
               .all();
       if (!IsOp1Undef && !IsOp2Undef) {
         // Update mask and mark undef elems.
         for (auto [Idx, I] : enumerate(Mask)) {
           if (I == UndefMaskElem)
             continue;
-          if (SV->getShuffleMask()[I % SV->getShuffleMask().size()] ==
+          if (SV->getMaskValue(I % SV->getShuffleMask().size()) ==
               UndefMaskElem)
             I = UndefMaskElem;
         }
@@ -6495,14 +6501,26 @@ protected:
         // again.
         if (auto *SV1 = dyn_cast<ShuffleVectorInst>(Op1))
           if (auto *SV2 = dyn_cast<ShuffleVectorInst>(Op2)) {
+            SmallVector<int> ExtMask1(Mask.size(), UndefMaskElem);
+            for (auto [Idx, I] : enumerate(CombinedMask1)) {
+                if (I == UndefMaskElem)
+                continue;
+                ExtMask1[Idx] = SV1->getMaskValue(I);
+            }
             SmallBitVector UseMask1 = buildUseMask(
                 cast<FixedVectorType>(SV1->getOperand(1)->getType())
                     ->getNumElements(),
-                CombinedMask1, UseMask::FirstArg);
+                ExtMask1, UseMask::SecondArg);
+            SmallVector<int> ExtMask2(CombinedMask2.size(), UndefMaskElem);
+            for (auto [Idx, I] : enumerate(CombinedMask2)) {
+                if (I == UndefMaskElem)
+                continue;
+                ExtMask2[Idx] = SV2->getMaskValue(I);
+            }
             SmallBitVector UseMask2 = buildUseMask(
                 cast<FixedVectorType>(SV2->getOperand(1)->getType())
                     ->getNumElements(),
-                CombinedMask2, UseMask::FirstArg);
+                ExtMask2, UseMask::SecondArg);
             if (SV1->getOperand(0)->getType() ==
                     SV2->getOperand(0)->getType() &&
                 SV1->getOperand(0)->getType() != SV1->getType() &&
index f9e0e4f..047a0d4 100644 (file)
@@ -9,7 +9,7 @@ define void @foo(ptr %0, <4 x float> %1) {
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x float> [[TMP3]], <2 x float> poison, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x float> zeroinitializer, <4 x float> [[TMP4]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x float> [[TMP1:%.*]], <4 x float> zeroinitializer, <4 x i32> <i32 0, i32 5, i32 6, i32 undef>
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 2, i32 4>
+; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x float> [[TMP6]], <4 x float> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 2, i32 4>
 ; CHECK-NEXT:    [[TMP8:%.*]] = fpext <4 x float> [[TMP7]] to <4 x double>
 ; CHECK-NEXT:    store <4 x double> [[TMP8]], ptr [[TMP0:%.*]], align 32
 ; CHECK-NEXT:    ret void