[ARM] Reassociate BFI
authorDavid Green <david.green@arm.com>
Thu, 1 Jul 2021 20:08:13 +0000 (21:08 +0100)
committerDavid Green <david.green@arm.com>
Thu, 1 Jul 2021 20:08:13 +0000 (21:08 +0100)
D104868 removed an (incorrect) fold for distributing BFI instructions in
a chain, combining them into a single instruction. BFIs like that are
hard to test, as the patterns are often destroyed before they become
BFIs. But it can come up in places, with chains of BFIs that can be
combined.

This patch adds a replacement, which reassociates BFI instructions with
non-overlapping insertion masks so that low bits are inserted first.
This can end up sorting the nodes so that adjacent inserts are next to
one another, allowing the existing folds to combine into a single BFI.

Differential Revision: https://reviews.llvm.org/D105096

llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/ARM/bfi.ll

index 43b8cec..653dbdf 100644 (file)
@@ -14076,7 +14076,9 @@ static SDValue FindBFIToCombineWith(SDNode *N) {
 
 static SDValue PerformBFICombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
+
   if (N1.getOpcode() == ISD::AND) {
     // (bfi A, (and B, Mask1), Mask2) -> (bfi A, B, Mask2) iff
     // the bits being cleared by the AND are not demanded by the BFI.
@@ -14097,33 +14099,54 @@ static SDValue PerformBFICombine(SDNode *N,
                              N->getOperand(2));
     return SDValue();
   }
+
   // Look for another BFI to combine with.
-  SDValue CombineBFI = FindBFIToCombineWith(N);
-  if (CombineBFI == SDValue())
-    return SDValue();
+  if (SDValue CombineBFI = FindBFIToCombineWith(N)) {
+    // We've found a BFI.
+    APInt ToMask1, FromMask1;
+    SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
 
-  // We've found a BFI.
-  APInt ToMask1, FromMask1;
-  SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
+    APInt ToMask2, FromMask2;
+    SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
+    assert(From1 == From2);
+    (void)From2;
 
-  APInt ToMask2, FromMask2;
-  SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
-  assert(From1 == From2);
-  (void)From2;
+    // Create a new BFI, combining the two together.
+    APInt NewFromMask = FromMask1 | FromMask2;
+    APInt NewToMask = ToMask1 | ToMask2;
 
-  // Create a new BFI, combining the two together.
-  APInt NewFromMask = FromMask1 | FromMask2;
-  APInt NewToMask = ToMask1 | ToMask2;
+    EVT VT = N->getValueType(0);
+    SDLoc dl(N);
 
-  EVT VT = N->getValueType(0);
-  SDLoc dl(N);
+    if (NewFromMask[0] == 0)
+      From1 = DCI.DAG.getNode(
+          ISD::SRL, dl, VT, From1,
+          DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT));
+    return DCI.DAG.getNode(ARMISD::BFI, dl, VT, CombineBFI.getOperand(0), From1,
+                          DCI.DAG.getConstant(~NewToMask, dl, VT));
+  }
 
-  if (NewFromMask[0] == 0)
-    From1 = DCI.DAG.getNode(
-        ISD::SRL, dl, VT, From1,
-        DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT));
-  return DCI.DAG.getNode(ARMISD::BFI, dl, VT, CombineBFI.getOperand(0), From1,
-                         DCI.DAG.getConstant(~NewToMask, dl, VT));
+  // Reassociate BFI(BFI (A, B, M1), C, M2) to BFI(BFI (A, C, M2), B, M1) so
+  // that lower bit insertions are performed first, providing that M1 and M2
+  // do no overlap. This can allow multiple BFI instructions to be combined
+  // together by the other folds above.
+  if (N->getOperand(0).getOpcode() == ARMISD::BFI) {
+    APInt ToMask1 = ~N->getConstantOperandAPInt(2);
+    APInt ToMask2 = ~N0.getConstantOperandAPInt(2);
+
+    if (!N0.hasOneUse() || (ToMask1 & ToMask2) != 0 ||
+        ToMask1.countLeadingZeros() < ToMask2.countLeadingZeros())
+      return SDValue();
+
+    EVT VT = N->getValueType(0);
+    SDLoc dl(N);
+    SDValue BFI1 = DCI.DAG.getNode(ARMISD::BFI, dl, VT, N0.getOperand(0),
+                                   N->getOperand(1), N->getOperand(2));
+    return DCI.DAG.getNode(ARMISD::BFI, dl, VT, BFI1, N0.getOperand(1),
+                           N0.getOperand(2));
+  }
+
+  return SDValue();
 }
 
 /// PerformVMOVRRDCombine - Target-specific dag combine xforms for
index b6126ab..786bf1c 100644 (file)
@@ -397,23 +397,21 @@ define void @bfi3_uses(i32 %a, i32 %b) {
 define i32 @bfi4(i32 %A, i2 zeroext %BB, i32* %d) {
 ; CHECK-LABEL: bfi4:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    lsr r3, r0, #1
-; CHECK-NEXT:    mov r12, #96
-; CHECK-NEXT:    bfi r1, r3, #2, #1
-; CHECK-NEXT:    tst r0, #32
-; CHECK-NEXT:    movweq r12, #32
-; CHECK-NEXT:    bfi r1, r3, #9, #1
-; CHECK-NEXT:    lsr r3, r0, #2
-; CHECK-NEXT:    bfi r1, r3, #3, #1
-; CHECK-NEXT:    bfi r1, r3, #10, #1
+; CHECK-NEXT:    push {r11, lr}
+; CHECK-NEXT:    lsr r12, r0, #1
 ; CHECK-NEXT:    and r3, r0, #8
+; CHECK-NEXT:    bfi r1, r12, #2, #2
+; CHECK-NEXT:    mov lr, #96
+; CHECK-NEXT:    tst r0, #32
+; CHECK-NEXT:    bfi r1, r12, #9, #2
+; CHECK-NEXT:    movweq lr, #32
 ; CHECK-NEXT:    orr r1, r1, r3, lsl #8
 ; CHECK-NEXT:    and r3, r0, #64
 ; CHECK-NEXT:    and r0, r0, #128
-; CHECK-NEXT:    orr r1, r1, r12
+; CHECK-NEXT:    orr r1, r1, lr
 ; CHECK-NEXT:    orr r1, r1, r3, lsl #1
 ; CHECK-NEXT:    str r1, [r2]
-; CHECK-NEXT:    bx lr
+; CHECK-NEXT:    pop {r11, pc}
 entry:
   %B = zext i2 %BB to i32
   %and = and i32 %A, 2