[SDAG] Allow scalable vectors in ComputeNumSignBits
authorPhilip Reames <preames@rivosinc.com>
Fri, 18 Nov 2022 17:57:42 +0000 (09:57 -0800)
committerPhilip Reames <listmail@philipreames.com>
Fri, 18 Nov 2022 18:50:06 +0000 (10:50 -0800)
This is a continuation of the series of patches adding lane wise support for scalable vectors in various knownbit-esq routines.

The basic idea here is that we track a single lane for scalable vectors which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane wise reasoning on many arithmetic operations.

Differential Revision: https://reviews.llvm.org/D137141

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
llvm/test/CodeGen/RISCV/rvv/vdiv-vp.ll
llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
llvm/test/CodeGen/RISCV/rvv/vrem-vp.ll

index ce86bf4..b562949 100644 (file)
@@ -3962,11 +3962,10 @@ bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val) const {
 unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
-  // TODO: Assume we don't know anything for now.
-  if (VT.isScalableVector())
-    return 1;
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return ComputeNumSignBits(Op, DemandedElts, Depth);
@@ -3989,7 +3988,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   if (Depth >= MaxRecursionDepth)
     return 1;  // Limit search depth.
 
-  if (!DemandedElts || VT.isScalableVector())
+  if (!DemandedElts)
     return 1;  // No demanded elts, better to assume we don't know anything.
 
   unsigned Opcode = Op.getOpcode();
@@ -4004,7 +4003,16 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::MERGE_VALUES:
     return ComputeNumSignBits(Op.getOperand(Op.getResNo()), DemandedElts,
                               Depth + 1);
+  case ISD::SPLAT_VECTOR: {
+    // Check if the sign bits of source go down as far as the truncated value.
+    unsigned NumSrcBits = Op.getOperand(0).getValueSizeInBits();
+    unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
+    if (NumSrcSignBits > (NumSrcBits - VTBits))
+      return NumSrcSignBits - (NumSrcBits - VTBits);
+    break;
+  }
   case ISD::BUILD_VECTOR:
+    assert(!VT.isScalableVector());
     Tmp = VTBits;
     for (unsigned i = 0, e = Op.getNumOperands(); (i < e) && (Tmp > 1); ++i) {
       if (!DemandedElts[i])
@@ -4049,6 +4057,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   }
 
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue N0 = Op.getOperand(0);
     EVT SrcVT = N0.getValueType();
     unsigned SrcBits = SrcVT.getScalarSizeInBits();
@@ -4106,6 +4116,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     Tmp2 = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
     return std::max(Tmp, Tmp2);
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
@@ -4323,6 +4335,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     break;
   }
   case ISD::EXTRACT_ELEMENT: {
+    if (VT.isScalableVector())
+      return 1;
     const int KnownSign = ComputeNumSignBits(Op.getOperand(0), Depth+1);
     const int BitWidth = Op.getValueSizeInBits();
     const int Items = Op.getOperand(0).getValueSizeInBits() / BitWidth;
@@ -4336,6 +4350,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return std::clamp(KnownSign - rIndex * BitWidth, 0, BitWidth);
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
     // If we know the element index, split the demand between the
     // source vector and the inserted element, otherwise assume we need
     // the original demanded vector elements and the value.
@@ -4366,6 +4382,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return Tmp;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue InVec = Op.getOperand(0);
     SDValue EltNo = Op.getOperand(1);
     EVT VecVT = InVec.getValueType();
@@ -4404,6 +4422,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return 1;
     // Determine the minimum number of sign bits across all demanded
     // elts of the input vectors. Early out if the result is already 1.
     Tmp = std::numeric_limits<unsigned>::max();
@@ -4422,6 +4442,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
     return Tmp;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return 1;
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -4492,7 +4514,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
           // We only need to handle vectors - computeKnownBits should handle
           // scalar cases.
           Type *CstTy = Cst->getType();
-          if (CstTy->isVectorTy() &&
+          if (CstTy->isVectorTy() && !VT.isScalableVector() &&
               (NumElts * VTBits) == CstTy->getPrimitiveSizeInBits() &&
               VTBits == CstTy->getScalarSizeInBits()) {
             Tmp = VTBits;
@@ -4527,6 +4549,10 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
       Opcode == ISD::INTRINSIC_WO_CHAIN ||
       Opcode == ISD::INTRINSIC_W_CHAIN ||
       Opcode == ISD::INTRINSIC_VOID) {
+    // TODO: This can probably be removed once target code is audited.  This
+    // is here purely to reduce patch size and review complexity.
+    if (VT.isScalableVector())
+      return 1;
     unsigned NumBits =
         TLI->ComputeNumSignBitsForTargetNode(Op, DemandedElts, *this, Depth);
     if (NumBits > 1)
index 8244b5f..ed3f784 100644 (file)
@@ -95,7 +95,7 @@ define <vscale x 2 x float> @masked_gather_nxv2f32(float* %base, <vscale x 2 x i
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.d
 ; CHECK-NEXT:    sxth z0.d, p1/m, z0.d
-; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, sxtw #2]
+; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, lsl #2]
 ; CHECK-NEXT:    ret
   %ptrs = getelementptr float, float* %base, <vscale x 2 x i16> %indices
   %data = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32(<vscale x 2 x float*> %ptrs, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
index f89ec1d..eebd4a2 100644 (file)
@@ -9,15 +9,10 @@ define <vscale x 2 x i8> @smulo_nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxtb z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxtb z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
@@ -35,15 +30,10 @@ define <vscale x 4 x i8> @smulo_nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    sxtb z1.s, p0/m, z1.s
 ; CHECK-NEXT:    sxtb z0.s, p0/m, z0.s
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    asr z1.s, z0.s, #31
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.s, p0/m, z0.s
-; CHECK-NEXT:    cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT:    cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.s, p0/m, z0.s
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
@@ -61,15 +51,10 @@ define <vscale x 8 x i8> @smulo_nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z0.h, p0/m, z0.h
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
-; CHECK-NEXT:    asr z1.h, z0.h, #15
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.h, p0/m, z0.h
-; CHECK-NEXT:    cmpne p1.h, p0/z, z2.h, z1.h
-; CHECK-NEXT:    cmpne p0.h, p0/z, z3.h, z0.h
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.h, p0/m, z0.h
+; CHECK-NEXT:    cmpne p0.h, p0/z, z1.h, z0.h
 ; CHECK-NEXT:    mov z0.h, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 8 x i8>, <vscale x 8 x i1> } @llvm.smul.with.overflow.nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y)
@@ -175,15 +160,10 @@ define <vscale x 2 x i16> @smulo_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxth z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxth z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxth z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxth z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
@@ -201,15 +181,10 @@ define <vscale x 4 x i16> @smulo_nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i1
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
 ; CHECK-NEXT:    sxth z0.s, p0/m, z0.s
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    asr z1.s, z0.s, #31
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxth z3.s, p0/m, z0.s
-; CHECK-NEXT:    cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT:    cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxth z1.s, p0/m, z0.s
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
@@ -315,15 +290,10 @@ define <vscale x 2 x i32> @smulo_nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i3
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxtw z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxtw z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtw z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtw z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
index dbeefd0..fd951d5 100644 (file)
@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vdiv_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vdiv.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vdiv.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
index 24414d5..c69f5fd 100644 (file)
@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vmax_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vmax.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vmax.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
index ae749e5..95c5cda 100644 (file)
@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vmin_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vmin.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vmin.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
index 5f2fca9..74a8fce 100644 (file)
@@ -12,11 +12,8 @@ define <vscale x 8 x i7> @vrem_vx_nxv8i7(<vscale x 8 x i7> %a, i7 signext %b, <v
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vrem.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vrem.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer