[SDAG] Treat DemandedElts argument to isSplatVector as splat for scalable vectors...
authorPhilip Reames <preames@rivosinc.com>
Tue, 11 Oct 2022 16:32:49 +0000 (09:32 -0700)
committerPhilip Reames <listmail@philipreames.com>
Tue, 11 Oct 2022 16:49:28 +0000 (09:49 -0700)
The previous code used a APInt(1, 0) to represent the demanded elts of a scalable vector, and then ignored that argument if type was scalable.  This was inconsistent with the UndefElts parameter which is set to either APInt(1, 0) or APInt(1,1) - that is, implicitly broadcast across all lanes.  Particularly since the undef code relied on the DemandedElts parameter having bitwidth 1 to achieve that result!

This change switches the demanded parameter to APInt(1,1), documents the broadcast semantics, and takes advantage of it to remove one special case for scalable vectors which is no longer required.

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

index 219d58a..cd98f04 100644 (file)
@@ -2542,10 +2542,10 @@ bool SelectionDAG::MaskedValueIsAllOnes(SDValue V, const APInt &Mask,
 }
 
 /// isSplatValue - Return true if the vector V has the same value
-/// across all DemandedElts. For scalable vectors it does not make
-/// sense to specify which elements are demanded, therefore they are
-/// simply ignored.  For undef elts this means either all lanes are
-/// undef, or no lanes are known undef.
+/// across all DemandedElts. For scalable vectors, we don't know the
+/// number of lanes at compile time.  Instead, we use a 1 bit APInt
+/// to represent a conservative value for all lanes; that is, that
+/// one bit value is implicitly splatted across all lanes.
 bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
                                 APInt &UndefElts, unsigned Depth) const {
   unsigned Opcode = V.getOpcode();
@@ -2554,7 +2554,7 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
   assert((!VT.isScalableVector() || DemandedElts.getBitWidth() == 1) &&
          "scalable demanded bits are ignored");
 
-  if (!VT.isScalableVector() && !DemandedElts)
+  if (!DemandedElts)
     return false; // No demanded elts, better to assume we don't know anything.
 
   if (Depth >= MaxRecursionDepth)
@@ -2736,11 +2736,11 @@ bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) const {
   assert(VT.isVector() && "Vector type expected");
 
   APInt UndefElts;
-  APInt DemandedElts;
-
-  // For now we don't support this with scalable vectors.
-  if (!VT.isScalableVector())
-    DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts
+    = APInt::getAllOnes(VT.isScalableVector() ? 1 : VT.getVectorNumElements());
   return isSplatValue(V, DemandedElts, UndefElts) &&
          (AllowUndefs || !UndefElts);
 }
@@ -2753,10 +2753,11 @@ SDValue SelectionDAG::getSplatSourceVector(SDValue V, int &SplatIdx) {
   switch (Opcode) {
   default: {
     APInt UndefElts;
-    APInt DemandedElts;
-
-    if (!VT.isScalableVector())
-      DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
+    // Since the number of lanes in a scalable vector is unknown at compile time,
+    // we track one bit which is implicitly broadcast to all lanes.  This means
+    // that all lanes in a scalable vector are considered demanded.
+    APInt DemandedElts
+      = APInt::getAllOnes(VT.isScalableVector() ? 1 : VT.getVectorNumElements());
 
     if (isSplatValue(V, DemandedElts, UndefElts)) {
       if (VT.isScalableVector()) {
index 428309b..4b41957 100644 (file)
@@ -325,7 +325,7 @@ TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_SPLAT_VECTOR) {
   EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
 
   APInt UndefElts;
-  APInt DemandedElts;
+  APInt DemandedElts(1,1);
   EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
 }
 
@@ -345,7 +345,7 @@ TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_ADD_of_SPLAT_VECTOR) {
   EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
 
   APInt UndefElts;
-  APInt DemandedElts;
+  APInt DemandedElts(1, 1);
   EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
 }