[ARM] Fix crash in chained BFI combine due to incorrectly RAUW'ing a node.
authorAmara Emerson <amara@apple.com>
Thu, 24 Jun 2021 18:10:42 +0000 (11:10 -0700)
committerAmara Emerson <amara@apple.com>
Fri, 25 Jun 2021 06:35:47 +0000 (23:35 -0700)
For a bfi chain like:
a = bfi input, x, y
b = bfi a, x', y'

The previous code was RAUW'ing a with x, mutating the second 'b' bfi, and when
SelectionDAG's CSE code ended up deleting it unexpectedly, bad things happend.
There's no need to RAUW in this case because we can just return our newly
created replacement BFI node. It also looked incorrect because it didn't account
for other users of the 'a' bfi.

Since it seems that chains of more than 2 BFI nodes are hard/impossible to
produce without this combine kicking in at some point, I've removed that
functionality since it had no test coverage.

rdar://79095399

Differential Revision: https://reviews.llvm.org/D104868

llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/ARM/bfi-chain-cse-crash.ll [new file with mode: 0644]

index 23451af..f7c8066 100644 (file)
@@ -13954,45 +13954,32 @@ static bool BitsProperlyConcatenate(const APInt &A, const APInt &B) {
 }
 
 static SDValue FindBFIToCombineWith(SDNode *N) {
-  // We have a BFI in N. Follow a possible chain of BFIs and find a BFI it can combine with,
-  // if one exists.
+  // We have a BFI in N. Find a BFI it can combine with, if one exists.
   APInt ToMask, FromMask;
   SDValue From = ParseBFI(N, ToMask, FromMask);
   SDValue To = N->getOperand(0);
 
-  // Now check for a compatible BFI to merge with. We can pass through BFIs that
-  // aren't compatible, but not if they set the same bit in their destination as
-  // we do (or that of any BFI we're going to combine with).
   SDValue V = To;
-  APInt CombinedToMask = ToMask;
-  while (V.getOpcode() == ARMISD::BFI) {
-    APInt NewToMask, NewFromMask;
-    SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask);
-    if (NewFrom != From) {
-      // This BFI has a different base. Keep going.
-      CombinedToMask |= NewToMask;
-      V = V.getOperand(0);
-      continue;
-    }
+  if (V.getOpcode() != ARMISD::BFI)
+    return SDValue();
 
-    // Do the written bits conflict with any we've seen so far?
-    if ((NewToMask & CombinedToMask).getBoolValue())
-      // Conflicting bits - bail out because going further is unsafe.
-      return SDValue();
+  APInt NewToMask, NewFromMask;
+  SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask);
+  if (NewFrom != From)
+    return SDValue();
 
-    // Are the new bits contiguous when combined with the old bits?
-    if (BitsProperlyConcatenate(ToMask, NewToMask) &&
-        BitsProperlyConcatenate(FromMask, NewFromMask))
-      return V;
-    if (BitsProperlyConcatenate(NewToMask, ToMask) &&
-        BitsProperlyConcatenate(NewFromMask, FromMask))
-      return V;
+  // Do the written bits conflict with any we've seen so far?
+  if ((NewToMask & ToMask).getBoolValue())
+    // Conflicting bits.
+    return SDValue();
 
-    // We've seen a write to some bits, so track it.
-    CombinedToMask |= NewToMask;
-    // Keep going...
-    V = V.getOperand(0);
-  }
+  // Are the new bits contiguous when combined with the old bits?
+  if (BitsProperlyConcatenate(ToMask, NewToMask) &&
+      BitsProperlyConcatenate(FromMask, NewFromMask))
+    return V;
+  if (BitsProperlyConcatenate(NewToMask, ToMask) &&
+      BitsProperlyConcatenate(NewFromMask, FromMask))
+    return V;
 
   return SDValue();
 }
@@ -14018,40 +14005,35 @@ static SDValue PerformBFICombine(SDNode *N,
       return DCI.DAG.getNode(ARMISD::BFI, SDLoc(N), N->getValueType(0),
                              N->getOperand(0), N1.getOperand(0),
                              N->getOperand(2));
-  } else if (N->getOperand(0).getOpcode() == ARMISD::BFI) {
-    // We have a BFI of a BFI. Walk up the BFI chain to see how long it goes.
-    // Keep track of any consecutive bits set that all come from the same base
-    // value. We can combine these together into a single BFI.
-    SDValue CombineBFI = FindBFIToCombineWith(N);
-    if (CombineBFI == SDValue())
-      return SDValue();
+    return SDValue();
+  }
+  // Look for another BFI to combine with.
+  SDValue CombineBFI = FindBFIToCombineWith(N);
+  if (CombineBFI == SDValue())
+    return SDValue();
 
-    // We've found a BFI.
-    APInt ToMask1, FromMask1;
-    SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
+  // We've found a BFI.
+  APInt ToMask1, FromMask1;
+  SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
 
-    APInt ToMask2, FromMask2;
-    SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
-    assert(From1 == From2);
-    (void)From2;
+  APInt ToMask2, FromMask2;
+  SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
+  assert(From1 == From2);
+  (void)From2;
 
-    // First, unlink CombineBFI.
-    DCI.DAG.ReplaceAllUsesWith(CombineBFI, CombineBFI.getOperand(0));
-    // Then create a new BFI, combining the two together.
-    APInt NewFromMask = FromMask1 | FromMask2;
-    APInt NewToMask = ToMask1 | ToMask2;
+  // Create a new BFI, combining the two together.
+  APInt NewFromMask = FromMask1 | FromMask2;
+  APInt NewToMask = ToMask1 | ToMask2;
 
-    EVT VT = N->getValueType(0);
-    SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  SDLoc dl(N);
 
-    if (NewFromMask[0] == 0)
-      From1 = DCI.DAG.getNode(
+  if (NewFromMask[0] == 0)
+    From1 = DCI.DAG.getNode(
         ISD::SRL, dl, VT, From1,
         DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT));
-    return DCI.DAG.getNode(ARMISD::BFI, dl, VT, N->getOperand(0), From1,
-                           DCI.DAG.getConstant(~NewToMask, dl, VT));
-  }
-  return SDValue();
+  return DCI.DAG.getNode(ARMISD::BFI, dl, VT, CombineBFI.getOperand(0), From1,
+                         DCI.DAG.getConstant(~NewToMask, dl, VT));
 }
 
 /// PerformVMOVRRDCombine - Target-specific dag combine xforms for
diff --git a/llvm/test/CodeGen/ARM/bfi-chain-cse-crash.ll b/llvm/test/CodeGen/ARM/bfi-chain-cse-crash.ll
new file mode 100644 (file)
index 0000000..e58be8f
--- /dev/null
@@ -0,0 +1,41 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=thumbv7s | FileCheck %s
+target datalayout = "e-m:o-p:32:32-f64:32:64-v64:32:64-v128:32:128-a:0:32-n32-S32"
+target triple = "thumbv7s-apple-ios3.1.3"
+
+define void @bfi_chain_cse_crash(i8* %0, i8 *%ptr) {
+; CHECK-LABEL: bfi_chain_cse_crash:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    ldrb r2, [r0]
+; CHECK-NEXT:    and r3, r2, #1
+; CHECK-NEXT:    lsr.w r12, r2, #3
+; CHECK-NEXT:    bfi r3, r12, #3, #1
+; CHECK-NEXT:    strb r3, [r0]
+; CHECK-NEXT:    and r0, r2, #4
+; CHECK-NEXT:    bfi r0, r12, #3, #1
+; CHECK-NEXT:    strb r0, [r1]
+; CHECK-NEXT:    bx lr
+entry:
+  %1 = load i8, i8* %0, align 1
+  %2 = and i8 %1, 1
+  %3 = select i1 false, i8 %2, i8 0
+  %4 = and i8 %1, 4
+  %5 = icmp eq i8 %4, 0
+  %6 = zext i8 %3 to i32
+  %7 = or i32 %6, 4
+  %8 = trunc i32 %7 to i8
+  %9 = select i1 %5, i8 %3, i8 %8
+  %10 = and i8 %1, 8
+  %11 = icmp eq i8 %10, 0
+  %12 = zext i8 %2 to i32
+  %13 = or i32 %12, 8
+  %14 = trunc i32 %13 to i8
+  %15 = zext i8 %9 to i32
+  %16 = or i32 %15, 8
+  %17 = trunc i32 %16 to i8
+  %18 = select i1 %11, i8 %2, i8 %14
+  %19 = select i1 %11, i8 %9, i8 %17
+  store i8 %18, i8* %0, align 1
+  store i8 %19, i8* %ptr, align 1
+  ret void
+}