[IR] Improve member `ShuffleVectorInst::isReplicationMask()`
authorRoman Lebedev <lebedev.ri@gmail.com>
Fri, 5 Nov 2021 16:11:55 +0000 (19:11 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Fri, 5 Nov 2021 21:09:27 +0000 (00:09 +0300)
When we have an actual shuffle, we can impose the additional restriction
that the mask replicates the elements of the first operand, so we know
the replication factor as a ratio of output and op0 vector sizes.

llvm/include/llvm/IR/Instructions.h
llvm/lib/IR/Instructions.cpp
llvm/unittests/IR/InstructionsTest.cpp

index b380e34..0ef7888 100644 (file)
@@ -2373,14 +2373,7 @@ public:
   }
 
   /// Return true if this shuffle mask is a replication mask.
-  bool isReplicationMask(int &ReplicationFactor, int &VF) const {
-    // Not possible to express a shuffle mask for a scalable vector for this
-    // case.
-    if (isa<ScalableVectorType>(getType()))
-      return false;
-
-    return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
-  }
+  bool isReplicationMask(int &ReplicationFactor, int &VF) const;
 
   /// Change values in a shuffle permute mask assuming the two vector operands
   /// of length InVecNumElts have swapped position.
index 63dd075..c42df49 100644 (file)
@@ -2502,6 +2502,21 @@ bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
   return false;
 }
 
+bool ShuffleVectorInst::isReplicationMask(int &ReplicationFactor,
+                                          int &VF) const {
+  // Not possible to express a shuffle mask for a scalable vector for this
+  // case.
+  if (isa<ScalableVectorType>(getType()))
+    return false;
+
+  VF = cast<FixedVectorType>(Op<0>()->getType())->getNumElements();
+  if (ShuffleMask.size() % VF != 0)
+    return false;
+  ReplicationFactor = ShuffleMask.size() / VF;
+
+  return isReplicationMaskWithParams(ShuffleMask, ReplicationFactor, VF);
+}
+
 //===----------------------------------------------------------------------===//
 //                             InsertValueInst Class
 //===----------------------------------------------------------------------===//
index 213435f..a4a9671 100644 (file)
@@ -1126,6 +1126,16 @@ TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
           ReplicatedMask, GuessedReplicationFactor, GuessedVF));
       EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
       EXPECT_EQ(GuessedVF, VF);
+
+      for (int OpVF : seq_inclusive(VF, 2 * VF + 1)) {
+        LLVMContext Ctx;
+        Type *OpVFTy = FixedVectorType::get(IntegerType::getInt1Ty(Ctx), OpVF);
+        Value *Op = ConstantVector::getNullValue(OpVFTy);
+        ShuffleVectorInst *SVI = new ShuffleVectorInst(Op, Op, ReplicatedMask);
+        EXPECT_EQ(SVI->isReplicationMask(GuessedReplicationFactor, GuessedVF),
+                  OpVF == VF);
+        delete SVI;
+      }
     }
   }
 }