[CodeGen] Fix warnings due to SelectionDAG::getSplatSourceVector
authorDavid Sherwood <david.sherwood@arm.com>
Tue, 28 Apr 2020 14:55:34 +0000 (15:55 +0100)
committerDavid Sherwood <david.sherwood@arm.com>
Tue, 5 May 2020 07:45:41 +0000 (08:45 +0100)
Summary:
I have fixed several places in getSplatSourceVector and isSplatValue
to work correctly with scalable vectors. I added new support for
the ISD::SPLAT_VECTOR DAG node as one of the obvious cases we can
support with scalable vectors. In other places I have tried to do
the sensible thing, such as bail out for vector types we don't yet
support or don't intend to support.

It's not possible to add IR test cases to cover these changes, since
they are currently only ever exercised on certain targets, e.g.
only X86 targets use the result of getSplatSourceVector. I've
assumed that X86 tests already exist to test these code paths for
fixed vectors. However, I have added some AArch64 unit tests that
test the specific functions I have changed.

Differential revision: https://reviews.llvm.org/D79083

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

index 7929682..e5f522a 100644 (file)
@@ -2276,15 +2276,42 @@ bool SelectionDAG::MaskedValueIsAllOnes(SDValue V, const APInt &Mask,
 }
 
 /// isSplatValue - Return true if the vector V has the same value
-/// across all DemandedElts.
+/// across all DemandedElts. For scalable vectors it does not make
+/// sense to specify which elements are demanded or undefined, therefore
+/// they are simply ignored.
 bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
                                 APInt &UndefElts) {
-  if (!DemandedElts)
-    return false; // No demanded elts, better to assume we don't know anything.
-
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
 
+  if (!VT.isScalableVector() && !DemandedElts)
+    return false; // No demanded elts, better to assume we don't know anything.
+
+  // Deal with some common cases here that work for both fixed and scalable
+  // vector types.
+  switch (V.getOpcode()) {
+  case ISD::SPLAT_VECTOR:
+    return true;
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::AND: {
+    APInt UndefLHS, UndefRHS;
+    SDValue LHS = V.getOperand(0);
+    SDValue RHS = V.getOperand(1);
+    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
+        isSplatValue(RHS, DemandedElts, UndefRHS)) {
+      UndefElts = UndefLHS | UndefRHS;
+      return true;
+    }
+    break;
+  }
+  }
+
+  // We don't support other cases than those above for scalable vectors at
+  // the moment.
+  if (VT.isScalableVector())
+    return false;
+
   unsigned NumElts = VT.getVectorNumElements();
   assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch");
   UndefElts = APInt::getNullValue(NumElts);
@@ -2341,19 +2368,6 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
     }
     break;
   }
-  case ISD::ADD:
-  case ISD::SUB:
-  case ISD::AND: {
-    APInt UndefLHS, UndefRHS;
-    SDValue LHS = V.getOperand(0);
-    SDValue RHS = V.getOperand(1);
-    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
-        isSplatValue(RHS, DemandedElts, UndefRHS)) {
-      UndefElts = UndefLHS | UndefRHS;
-      return true;
-    }
-    break;
-  }
   }
 
   return false;
@@ -2363,10 +2377,13 @@ bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
 bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
-  unsigned NumElts = VT.getVectorNumElements();
 
   APInt UndefElts;
-  APInt DemandedElts = APInt::getAllOnesValue(NumElts);
+  APInt DemandedElts;
+
+  // For now we don't support this with scalable vectors.
+  if (!VT.isScalableVector())
+    DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
   return isSplatValue(V, DemandedElts, UndefElts) &&
          (AllowUndefs || !UndefElts);
 }
@@ -2379,19 +2396,35 @@ SDValue SelectionDAG::getSplatSourceVector(SDValue V, int &SplatIdx) {
   switch (Opcode) {
   default: {
     APInt UndefElts;
-    APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+    APInt DemandedElts;
+
+    if (!VT.isScalableVector())
+      DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+
     if (isSplatValue(V, DemandedElts, UndefElts)) {
-      // Handle case where all demanded elements are UNDEF.
-      if (DemandedElts.isSubsetOf(UndefElts)) {
+      if (VT.isScalableVector()) {
+        // DemandedElts and UndefElts are ignored for scalable vectors, since
+        // the only supported cases are SPLAT_VECTOR nodes.
         SplatIdx = 0;
-        return getUNDEF(VT);
+      } else {
+        // Handle case where all demanded elements are UNDEF.
+        if (DemandedElts.isSubsetOf(UndefElts)) {
+          SplatIdx = 0;
+          return getUNDEF(VT);
+        }
+        SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
       }
-      SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
       return V;
     }
     break;
   }
+  case ISD::SPLAT_VECTOR:
+    SplatIdx = 0;
+    return V;
   case ISD::VECTOR_SHUFFLE: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // Check if this is a shuffle node doing a splat.
     // TODO - remove this and rely purely on SelectionDAG::isSplatValue,
     // getTargetVShiftNode currently struggles without the splat source.
index defd27e..848cbc0 100644 (file)
@@ -199,4 +199,182 @@ TEST_F(AArch64SelectionDAGTest, ComputeKnownBits_SUB) {
   EXPECT_EQ(Known.One, APInt(8, 0x1));
 }
 
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+  // Create a BUILD_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::BUILD_VECTOR);
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_FALSE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_ADD_of_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+
+  // Should create BUILD_VECTORs
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::BUILD_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_FALSE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+  // Create a SPLAT_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::SPLAT_VECTOR);
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3. These bits should be ignored.
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_ADD_of_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+
+  // Should create SPLAT_VECTORS
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::SPLAT_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3. These bits should be ignored.
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Fixed_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+  // Create a BUILD_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::BUILD_VECTOR);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Fixed_ADD_of_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+
+  // Should create BUILD_VECTORs
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::BUILD_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+  // Create a SPLAT_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::SPLAT_VECTOR);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+
+  // Should create SPLAT_VECTORS
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::SPLAT_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
 } // end anonymous namespace