[NFC][InstCombine] Extract canTryToConstantAddTwoShiftAmounts() as helper
authorRoman Lebedev <lebedev.ri@gmail.com>
Sun, 4 Apr 2021 20:23:10 +0000 (23:23 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Sun, 4 Apr 2021 20:26:41 +0000 (23:26 +0300)
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

index 2007cf0..8413ae9 100644 (file)
@@ -21,6 +21,30 @@ using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
+bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1,
+                                        Value *ShAmt1) {
+  // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case let's bailout now..
+  if (ShAmt0->getType() != ShAmt1->getType())
+    return false;
+
+  // As input, we have the following pattern:
+  //   Sh0 (Sh1 X, Q), K
+  // We want to rewrite that as:
+  //   Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
+  // While we know that originally (Q+K) would not overflow
+  // (because  2 * (N-1) u<= iN -1), we have looked past extensions of
+  // shift amounts. so it may now overflow in smaller bitwidth.
+  // To ensure that does not happen, we need to ensure that the total maximal
+  // shift amount is still representable in that smaller bit width.
+  unsigned MaximalPossibleTotalShiftAmount =
+      (Sh0->getType()->getScalarSizeInBits() - 1) +
+      (Sh1->getType()->getScalarSizeInBits() - 1);
+  APInt MaximalRepresentableShiftAmount =
+      APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
+  return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount);
+}
+
 // Given pattern:
 //   (x shiftopcode Q) shiftopcode K
 // we should rewrite it as
@@ -57,26 +81,8 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
   if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
     return nullptr;
 
-  // We have two shift amounts from two different shifts. The types of those
-  // shift amounts may not match. If that's the case let's bailout now..
-  if (ShAmt0->getType() != ShAmt1->getType())
-    return nullptr;
-
-  // As input, we have the following pattern:
-  //   Sh0 (Sh1 X, Q), K
-  // We want to rewrite that as:
-  //   Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
-  // While we know that originally (Q+K) would not overflow
-  // (because  2 * (N-1) u<= iN -1), we have looked past extensions of
-  // shift amounts. so it may now overflow in smaller bitwidth.
-  // To ensure that does not happen, we need to ensure that the total maximal
-  // shift amount is still representable in that smaller bit width.
-  unsigned MaximalPossibleTotalShiftAmount =
-      (Sh0->getType()->getScalarSizeInBits() - 1) +
-      (Sh1->getType()->getScalarSizeInBits() - 1);
-  APInt MaximalRepresentableShiftAmount =
-      APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
-  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+  // Verify that it would be safe to try to add those two shift amounts.
+  if (!canTryToConstantAddTwoShiftAmounts(Sh0, ShAmt0, Sh1, ShAmt1))
     return nullptr;
 
   // We are only looking for signbit extraction if we have two right shifts.