[DAGCombiner] `convertBuildVecZextToBuildVecWithZeros()`: rework split factor calculation
authorRoman Lebedev <lebedev.ri@gmail.com>
Mon, 2 Jan 2023 15:22:54 +0000 (18:22 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Mon, 2 Jan 2023 15:34:35 +0000 (18:34 +0300)
The original computation was both making assumptions that do not hold
in practice, and being overly pessimistic. We should just check
every possible split factor, and pick the best one.

Fixes https://github.com/llvm/llvm-project/issues/59781

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/build-vector-extract.ll
llvm/test/CodeGen/X86/buildvec-insertvec.ll

index 9235042..7673bf0 100644 (file)
@@ -21472,6 +21472,12 @@ SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
   EVT OpVT = N->getOperand(0).getValueType();
   assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
 
+  EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
+
+  if (!TLI.isTypeLegal(OpIntVT) ||
+      (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
+    return SDValue();
+
   unsigned EltBitwidth = VT.getScalarSizeInBits();
   // NOTE: the actual width of operands may be wider than that!
 
@@ -21515,27 +21521,31 @@ SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
 
   // We have EltBitwidth bits, the *minimal* chunk size is ActiveBits,
   // into how many chunks can we split our element width?
-  unsigned Factor = divideCeil(EltBitwidth, ActiveBits);
-  assert(Factor > 1 && "Did not split the element after all?");
-  assert(EltBitwidth % Factor == 0 && "Can not split into this many chunks?");
-  unsigned ChunkBitwidth = EltBitwidth / Factor;
-  assert(ChunkBitwidth >= ActiveBits && "Underestimated chunk size?");
-  assert(ChunkBitwidth < EltBitwidth && "Failed to reduce element width?");
-
-  EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
-  EVT NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
-  EVT NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
-                                  Factor * N->getNumOperands());
-
-  // Never create illegal types.
-  if (!TLI.isTypeLegal(OpIntVT) || !TLI.isTypeLegal(NewScalarIntVT) ||
-      !TLI.isTypeLegal(NewIntVT))
-    return SDValue();
-
-  if (LegalOperations &&
-      !(TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT) &&
-        TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
-        TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT)))
+  EVT NewScalarIntVT, NewIntVT;
+  std::optional<unsigned> Factor;
+  // We can split the element into at least two chunks, but not into more
+  // than |_ EltBitwidth / ActiveBits _| chunks. Find a largest split factor
+  // for which the element width is a multiple of it,
+  // and the resulting types/operations on that chunk width are legal.
+  assert(2 * ActiveBits <= EltBitwidth &&
+         "We know that half or less bits of the element are active.");
+  for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
+    if (EltBitwidth % Scale != 0)
+      continue;
+    unsigned ChunkBitwidth = EltBitwidth / Scale;
+    assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
+    NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
+    NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
+                                Scale * N->getNumOperands());
+    if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
+        (LegalOperations &&
+         !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
+           TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
+      continue;
+    Factor = Scale;
+    break;
+  }
+  if (!Factor)
     return SDValue();
 
   SDLoc DL(N);
@@ -21546,16 +21556,16 @@ SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
   NewOps.reserve(NewIntVT.getVectorNumElements());
   for (auto I : enumerate(N->ops())) {
     SDValue Op = I.value();
-    // FIXME: after allowing UNDEF's, do handle them here.
+    assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
     unsigned SrcOpIdx = I.index();
     if (KnownZeroOps[SrcOpIdx]) {
-      NewOps.append(Factor, ZeroOp);
+      NewOps.append(*Factor, ZeroOp);
       continue;
     }
     Op = DAG.getBitcast(OpIntVT, Op);
     Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
     NewOps.emplace_back(Op);
-    NewOps.append(Factor - 1, ZeroOp);
+    NewOps.append(*Factor - 1, ZeroOp);
   }
   assert(NewOps.size() == NewIntVT.getVectorNumElements());
   SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
index 53e8b56..36b1b2c 100644 (file)
@@ -220,7 +220,7 @@ define <2 x i64> @extract0_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[0]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 0
@@ -246,7 +246,7 @@ define <2 x i64> @extract1_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[1]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 1
@@ -272,7 +272,7 @@ define <2 x i64> @extract2_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[2]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 2
@@ -298,7 +298,7 @@ define <2 x i64> @extract3_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[3]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 3
@@ -430,7 +430,7 @@ define <2 x i64> @extract0_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[0]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 0
@@ -456,7 +456,7 @@ define <2 x i64> @extract1_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[1]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 1
@@ -482,7 +482,7 @@ define <2 x i64> @extract2_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[2]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 2
@@ -508,7 +508,7 @@ define <2 x i64> @extract3_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[3]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 3
index 9c8f034..e1f139c 100644 (file)
@@ -829,3 +829,19 @@ define i32 @PR46586(ptr %p, <4 x i32> %v) {
   %t35 = extractelement <4 x i32> %t34, i32 3
   ret i32 %t35
 }
+
+define void @pr59781(ptr %in, ptr %out) {
+; CHECK-LABEL: pr59781:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    movzbl 2(%rdi), %ecx
+; CHECK-NEXT:    shlq $16, %rcx
+; CHECK-NEXT:    orq %rax, %rcx
+; CHECK-NEXT:    movq %rcx, (%rsi)
+; CHECK-NEXT:    retq
+  %bf.load = load i24, ptr %in, align 8
+  %conv = zext i24 %bf.load to i64
+  %splat.splatinsert = insertelement <1 x i64> zeroinitializer, i64 %conv, i64 0
+  store <1 x i64> %splat.splatinsert, ptr %out, align 8
+  ret void
+}