[DAGCombiner] `convertBuildVecZextToBuildVecWithZeros()`: rework split factor calculation
authorRoman Lebedev <lebedev.ri@gmail.com>
Mon, 2 Jan 2023 15:22:54 +0000 (18:22 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Mon, 2 Jan 2023 15:34:35 +0000 (18:34 +0300)
The original computation was both making assumptions that do not hold
in practice, and being overly pessimistic. We should just check
every possible split factor, and pick the best one.

Fixes https://github.com/llvm/llvm-project/issues/59781

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/build-vector-extract.ll
llvm/test/CodeGen/X86/buildvec-insertvec.ll

index 9235042c54a4a4768b56342dcdc2c2eb860c19cc..7673bf0ed5cb9e5f5366538b5cde26d2fd279279 100644 (file)
@@ -21472,6 +21472,12 @@ SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
   EVT OpVT = N->getOperand(0).getValueType();
   assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
 
+  EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
+
+  if (!TLI.isTypeLegal(OpIntVT) ||
+      (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
+    return SDValue();
+
   unsigned EltBitwidth = VT.getScalarSizeInBits();
   // NOTE: the actual width of operands may be wider than that!
 
@@ -21515,27 +21521,31 @@ SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
 
   // We have EltBitwidth bits, the *minimal* chunk size is ActiveBits,
   // into how many chunks can we split our element width?
-  unsigned Factor = divideCeil(EltBitwidth, ActiveBits);
-  assert(Factor > 1 && "Did not split the element after all?");
-  assert(EltBitwidth % Factor == 0 && "Can not split into this many chunks?");
-  unsigned ChunkBitwidth = EltBitwidth / Factor;
-  assert(ChunkBitwidth >= ActiveBits && "Underestimated chunk size?");
-  assert(ChunkBitwidth < EltBitwidth && "Failed to reduce element width?");
-
-  EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
-  EVT NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
-  EVT NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
-                                  Factor * N->getNumOperands());
-
-  // Never create illegal types.
-  if (!TLI.isTypeLegal(OpIntVT) || !TLI.isTypeLegal(NewScalarIntVT) ||
-      !TLI.isTypeLegal(NewIntVT))
-    return SDValue();
-
-  if (LegalOperations &&
-      !(TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT) &&
-        TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
-        TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT)))
+  EVT NewScalarIntVT, NewIntVT;
+  std::optional<unsigned> Factor;
+  // We can split the element into at least two chunks, but not into more
+  // than |_ EltBitwidth / ActiveBits _| chunks. Find a largest split factor
+  // for which the element width is a multiple of it,
+  // and the resulting types/operations on that chunk width are legal.
+  assert(2 * ActiveBits <= EltBitwidth &&
+         "We know that half or less bits of the element are active.");
+  for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
+    if (EltBitwidth % Scale != 0)
+      continue;
+    unsigned ChunkBitwidth = EltBitwidth / Scale;
+    assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
+    NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
+    NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
+                                Scale * N->getNumOperands());
+    if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
+        (LegalOperations &&
+         !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
+           TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
+      continue;
+    Factor = Scale;
+    break;
+  }
+  if (!Factor)
     return SDValue();
 
   SDLoc DL(N);
@@ -21546,16 +21556,16 @@ SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
   NewOps.reserve(NewIntVT.getVectorNumElements());
   for (auto I : enumerate(N->ops())) {
     SDValue Op = I.value();
-    // FIXME: after allowing UNDEF's, do handle them here.
+    assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
     unsigned SrcOpIdx = I.index();
     if (KnownZeroOps[SrcOpIdx]) {
-      NewOps.append(Factor, ZeroOp);
+      NewOps.append(*Factor, ZeroOp);
       continue;
     }
     Op = DAG.getBitcast(OpIntVT, Op);
     Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
     NewOps.emplace_back(Op);
-    NewOps.append(Factor - 1, ZeroOp);
+    NewOps.append(*Factor - 1, ZeroOp);
   }
   assert(NewOps.size() == NewIntVT.getVectorNumElements());
   SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
index 53e8b568f70963ca2f39afe86f209c7dacfc68c9..36b1b2cdcb43204e5f80fc74adc89a8569b1e300 100644 (file)
@@ -220,7 +220,7 @@ define <2 x i64> @extract0_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[0]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 0
@@ -246,7 +246,7 @@ define <2 x i64> @extract1_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[1]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 1
@@ -272,7 +272,7 @@ define <2 x i64> @extract2_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[2]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 2
@@ -298,7 +298,7 @@ define <2 x i64> @extract3_i16_zext_insert0_i64_zero(<8 x i16> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.h[3]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <8 x i16> %x, i32 3
@@ -430,7 +430,7 @@ define <2 x i64> @extract0_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[0]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 0
@@ -456,7 +456,7 @@ define <2 x i64> @extract1_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[1]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 1
@@ -482,7 +482,7 @@ define <2 x i64> @extract2_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[2]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 2
@@ -508,7 +508,7 @@ define <2 x i64> @extract3_i8_zext_insert0_i64_zero(<16 x i8> %x) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    umov w8, v0.b[3]
-; CHECK-NEXT:    mov v1.d[0], x8
+; CHECK-NEXT:    mov v1.s[0], w8
 ; CHECK-NEXT:    mov v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %e = extractelement <16 x i8> %x, i32 3
index 9c8f0344437514bd7af859860eeec1811865946c..e1f139c66db9539609969c44b7f12528091e716b 100644 (file)
@@ -829,3 +829,19 @@ define i32 @PR46586(ptr %p, <4 x i32> %v) {
   %t35 = extractelement <4 x i32> %t34, i32 3
   ret i32 %t35
 }
+
+define void @pr59781(ptr %in, ptr %out) {
+; CHECK-LABEL: pr59781:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movzwl (%rdi), %eax
+; CHECK-NEXT:    movzbl 2(%rdi), %ecx
+; CHECK-NEXT:    shlq $16, %rcx
+; CHECK-NEXT:    orq %rax, %rcx
+; CHECK-NEXT:    movq %rcx, (%rsi)
+; CHECK-NEXT:    retq
+  %bf.load = load i24, ptr %in, align 8
+  %conv = zext i24 %bf.load to i64
+  %splat.splatinsert = insertelement <1 x i64> zeroinitializer, i64 %conv, i64 0
+  store <1 x i64> %splat.splatinsert, ptr %out, align 8
+  ret void
+}