[SDAG] Allow scalable vectors in SimplifyDemanded routines
authorPhilip Reames <preames@rivosinc.com>
Mon, 5 Dec 2022 19:59:06 +0000 (11:59 -0800)
committerPhilip Reames <listmail@philipreames.com>
Mon, 5 Dec 2022 20:42:16 +0000 (12:42 -0800)
This is a continuation of the series of patches adding lane wise support for scalable vectors in various knownbit-esq routines.

The basic idea here is that we track a single lane for scalable vectors which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane wise reasoning on many arithmetic operations.

Differential Revision: https://reviews.llvm.org/D137190

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/test/CodeGen/AArch64/active_lane_mask.ll
llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

index 466a2ed..553facc 100644 (file)
@@ -634,16 +634,10 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                                           bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
 
-  // TODO: We can probably do more work on calculating the known bits and
-  // simplifying the operations for scalable vectors, but for now we just
-  // bail out.
-  if (VT.isScalableVector()) {
-    // Pretend we don't know anything for now.
-    Known = KnownBits(DemandedBits.getBitWidth());
-    return false;
-  }
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
@@ -656,12 +650,6 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
-  // Pretend we don't know anything about scalable vectors for now.
-  // TODO: We can probably do more work on simplifying the operations for
-  // scalable vectors, but for now we just bail out.
-  if (VT.isScalableVector())
-    return SDValue();
-
   // Limit search depth.
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return SDValue();
@@ -680,6 +668,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   KnownBits LHSKnown, RHSKnown;
   switch (Op.getOpcode()) {
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     SDValue Src = peekThroughBitcasts(Op.getOperand(0));
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
@@ -825,6 +816,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::ANY_EXTEND_VECTOR_INREG:
   case ISD::SIGN_EXTEND_VECTOR_INREG:
   case ISD::ZERO_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // If we only want the lowest element and none of extended bits, then we can
     // return the bitcasted source vector.
     SDValue Src = Op.getOperand(0);
@@ -838,6 +832,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // If we don't demand the inserted element, return the base vector.
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
@@ -848,6 +845,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     SDValue Vec = Op.getOperand(0);
     SDValue Sub = Op.getOperand(1);
     uint64_t Idx = Op.getConstantOperandVal(2);
@@ -868,6 +868,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
+    assert(!VT.isScalableVector());
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // If all the demanded elts are from one operand and are inline,
@@ -891,6 +892,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     break;
   }
   default:
+    // TODO: Probably okay to remove after audit; here to reduce change size
+    // in initial enablement patch for scalable vectors
+    if (VT.isScalableVector())
+      return SDValue();
+
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
               Op, DemandedBits, DemandedElts, DAG, Depth))
@@ -904,14 +910,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
   EVT VT = Op.getValueType();
-
-  // Pretend we don't know anything about scalable vectors for now.
-  // TODO: We can probably do more work on simplifying the operations for
-  // scalable vectors, but for now we just bail out.
-  if (VT.isScalableVector())
-    return SDValue();
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
@@ -1070,16 +1072,10 @@ bool TargetLowering::SimplifyDemandedBits(
   // Don't know anything.
   Known = KnownBits(BitWidth);
 
-  // TODO: We can probably do more work on calculating the known bits and
-  // simplifying the operations for scalable vectors, but for now we just
-  // bail out.
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return false;
-
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
   unsigned NumElts = OriginalDemandedElts.getBitWidth();
-  assert((!VT.isVector() || NumElts == VT.getVectorNumElements()) &&
+  assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
          "Unexpected vector size");
 
   APInt DemandedBits = OriginalDemandedBits;
@@ -1130,6 +1126,8 @@ bool TargetLowering::SimplifyDemandedBits(
   KnownBits Known2;
   switch (Op.getOpcode()) {
   case ISD::SCALAR_TO_VECTOR: {
+    if (VT.isScalableVector())
+      return false;
     if (!DemandedElts[0])
       return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
 
@@ -1167,6 +1165,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return false;
     SDValue Vec = Op.getOperand(0);
     SDValue Scl = Op.getOperand(1);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
@@ -1203,6 +1203,8 @@ bool TargetLowering::SimplifyDemandedBits(
     return false;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return false;
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -1246,6 +1248,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::EXTRACT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return false;
     // Offset the demanded elts by the subvector index.
     SDValue Src = Op.getOperand(0);
     if (Src.getValueType().isScalableVector())
@@ -1271,6 +1275,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return false;
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     EVT SubVT = Op.getOperand(0).getValueType();
@@ -1289,6 +1295,7 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
+    assert(!VT.isScalableVector());
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // Collect demanded elements from shuffle operands..
@@ -1366,7 +1373,7 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
     // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
-    if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR &&
+    if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
         (Op0.getOperand(0).isUndef() ||
          ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
         Op0->hasOneUse()) {
@@ -2226,12 +2233,15 @@ bool TargetLowering::SimplifyDemandedBits(
     Known = KnownHi.concat(KnownLo);
     break;
   }
-  case ISD::ZERO_EXTEND:
-  case ISD::ZERO_EXTEND_VECTOR_INREG: {
+  case ISD::ZERO_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::ZERO_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
 
     // If none of the top bits are demanded, convert this into an any_extend.
@@ -2263,12 +2273,15 @@ bool TargetLowering::SimplifyDemandedBits(
       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
     break;
   }
-  case ISD::SIGN_EXTEND:
-  case ISD::SIGN_EXTEND_VECTOR_INREG: {
+  case ISD::SIGN_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::SIGN_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
 
     // If none of the top bits are demanded, convert this into an any_extend.
@@ -2315,12 +2328,15 @@ bool TargetLowering::SimplifyDemandedBits(
       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
     break;
   }
-  case ISD::ANY_EXTEND:
-  case ISD::ANY_EXTEND_VECTOR_INREG: {
+  case ISD::ANY_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::ANY_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
 
     // If we only need the bottom element then we can just bitcast.
@@ -2459,6 +2475,8 @@ bool TargetLowering::SimplifyDemandedBits(
     break;
   }
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return false;
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
@@ -2680,6 +2698,10 @@ bool TargetLowering::SimplifyDemandedBits(
     // We also ask the target about intrinsics (which could be specific to it).
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
         Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
+      // TODO: Probably okay to remove after audit; here to reduce change size
+      // in initial enablement patch for scalable vectors
+      if (Op.getValueType().isScalableVector())
+        break;
       if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
                                             Known, TLO, Depth))
         return true;
@@ -2749,7 +2771,7 @@ static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG,
          "Vector binop only");
 
   EVT EltVT = VT.getVectorElementType();
-  unsigned NumElts = VT.getVectorNumElements();
+  unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
   assert(UndefOp0.getBitWidth() == NumElts &&
          UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
 
index 211361d..cb2b498 100644 (file)
@@ -113,14 +113,13 @@ define <vscale x 4 x i1> @lane_mask_nxv4i1_i8(i8 %index, i8 %TC) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    and w8, w0, #0xff
 ; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    and w9, w1, #0xff
 ; CHECK-NEXT:    and z0.s, z0.s, #0xff
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    mov z1.s, w8
-; CHECK-NEXT:    and w8, w1, #0xff
 ; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, w9
 ; CHECK-NEXT:    umin z0.s, z0.s, #255
-; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    mov z1.s, w8
 ; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i8(i8 %index, i8 %TC)
@@ -132,17 +131,16 @@ define <vscale x 2 x i1> @lane_mask_nxv2i1_i8(i8 %index, i8 %TC) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
 ; CHECK-NEXT:    and x8, x0, #0xff
-; CHECK-NEXT:    index z0.d, #0, #1
 ; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
 ; CHECK-NEXT:    and x9, x1, #0xff
-; CHECK-NEXT:    and z0.d, z0.d, #0xff
+; CHECK-NEXT:    index z0.d, #0, #1
 ; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z0.d, z0.d, #0xff
 ; CHECK-NEXT:    mov z1.d, x8
+; CHECK-NEXT:    mov z2.d, x9
 ; CHECK-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEXT:    mov z1.d, x9
 ; CHECK-NEXT:    umin z0.d, z0.d, #255
-; CHECK-NEXT:    and z0.d, z0.d, #0xff
-; CHECK-NEXT:    cmphi p0.d, p0/z, z1.d, z0.d
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z0.d
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i8(i8 %index, i8 %TC)
   ret <vscale x 2 x i1> %active.lane.mask
index 67ea38b..518e714 100644 (file)
@@ -224,11 +224,15 @@ TEST_F(AArch64SelectionDAGTest, SimplifyDemandedBitsSVE) {
 
   SDValue Op = DAG->getNode(ISD::AND, Loc, InVecVT, N0, Mask2V);
 
+  // N0 = ?000?0?0
+  // Mask2V = 01010101
+  //  =>
+  // Known.Zero = 00100000 (0xAA)
   KnownBits Known;
   APInt DemandedBits = APInt(8, 0xFF);
   TargetLowering::TargetLoweringOpt TLO(*DAG, false, false);
-  EXPECT_FALSE(TL.SimplifyDemandedBits(Op, DemandedBits, Known, TLO));
-  EXPECT_EQ(Known.Zero, APInt(8, 0));
+  EXPECT_TRUE(TL.SimplifyDemandedBits(Op, DemandedBits, Known, TLO));
+  EXPECT_EQ(Known.Zero, APInt(8, 0xAA));
 }
 
 // Piggy-backing on the AArch64 tests to verify SelectionDAG::computeKnownBits.