[KnownBits] Partially synchronize shift implementations (NFC)
authorNikita Popov <npopov@redhat.com>
Fri, 26 May 2023 12:03:12 +0000 (14:03 +0200)
committerNikita Popov <npopov@redhat.com>
Fri, 26 May 2023 12:16:14 +0000 (14:16 +0200)
And remove some bits of effectively dead code.

llvm/lib/Support/KnownBits.cpp

index c8e4a89..a7ca7c0 100644 (file)
@@ -199,11 +199,6 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
   KnownBits Known(BitWidth);
   unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
   if (LHS.isUnknown()) {
-    if (MinShiftAmount == BitWidth) {
-      // Always poison. Return zero because we don't like returning conflict.
-      Known.setAllZero();
-      return Known;
-    }
     Known.Zero.setLowBits(MinShiftAmount);
     if (NUW && NSW && MinShiftAmount != 0)
       Known.makeNonNegative();
@@ -261,120 +256,89 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
 
 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
   unsigned BitWidth = LHS.getBitWidth();
-  KnownBits Known(BitWidth);
-
-  if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
-    unsigned Shift = RHS.getConstant().getZExtValue();
-    Known = LHS;
-    Known.Zero.lshrInPlace(Shift);
-    Known.One.lshrInPlace(Shift);
+  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
+    KnownBits Known = LHS;
+    Known.Zero.lshrInPlace(ShiftAmt);
+    Known.One.lshrInPlace(ShiftAmt);
     // High bits are known zero.
-    Known.Zero.setHighBits(Shift);
+    Known.Zero.setHighBits(ShiftAmt);
     return Known;
-  }
-
-  // Minimum shift amount high bits are known zero.
-  APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.uge(BitWidth)) {
-    // Always poison. Return zero because we don't like returning conflict.
-    Known.setAllZero();
-    return Known;
-  }
+  };
 
+  // Fast path for a common case when LHS is completely unknown.
+  KnownBits Known(BitWidth);
+  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
   if (LHS.isUnknown()) {
-    // No matter the shift amount, the leading zeros will stay zero.
-    unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
-    MinLeadingZeros += MinShiftAmount.getZExtValue();
-    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-    Known.Zero.setHighBits(MinLeadingZeros);
+    Known.Zero.setHighBits(MinShiftAmount);
     return Known;
   }
 
   // Find the common bits from all possible shifts.
-  APInt MaxShiftAmount = RHS.getMaxValue();
-  uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
-  uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
-  assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+  APInt MaxValue = RHS.getMaxValue();
+  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
+  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
   Known.One.setAllBits();
-  for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
-       ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
+       ++ShiftAmt) {
     // Skip if the shift amount is impossible.
-    if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
-    KnownBits SpecificShift = LHS;
-    SpecificShift.Zero.lshrInPlace(ShiftAmt);
-    SpecificShift.Zero.setHighBits(ShiftAmt);
-    SpecificShift.One.lshrInPlace(ShiftAmt);
-    Known = Known.intersectWith(SpecificShift);
+    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
     if (Known.isUnknown())
       break;
   }
 
+  // All shift amounts may result in poison.
+  if (Known.hasConflict())
+    Known.setAllZero();
   return Known;
 }
 
 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
   unsigned BitWidth = LHS.getBitWidth();
-  KnownBits Known(BitWidth);
-
-  if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
-    unsigned Shift = RHS.getConstant().getZExtValue();
-    Known = LHS;
-    Known.Zero.ashrInPlace(Shift);
-    Known.One.ashrInPlace(Shift);
-    return Known;
-  }
-
-  // Minimum shift amount high bits are known sign bits.
-  APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.uge(BitWidth)) {
-    // Always poison. Return zero because we don't like returning conflict.
-    Known.setAllZero();
+  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
+    KnownBits Known = LHS;
+    Known.Zero.ashrInPlace(ShiftAmt);
+    Known.One.ashrInPlace(ShiftAmt);
     return Known;
-  }
+  };
 
+  // Fast path for a common case when LHS is completely unknown.
+  KnownBits Known(BitWidth);
+  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
   if (LHS.isUnknown()) {
-    // No matter the shift amount, the leading sign bits will stay.
-    unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
-    unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
-    if (MinLeadingZeros) {
-      MinLeadingZeros += MinShiftAmount.getZExtValue();
-      MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-    }
-    if (MinLeadingOnes) {
-      MinLeadingOnes += MinShiftAmount.getZExtValue();
-      MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
+    if (MinShiftAmount == BitWidth) {
+      // Always poison. Return zero because we don't like returning conflict.
+      Known.setAllZero();
+      return Known;
     }
-    Known.Zero.setHighBits(MinLeadingZeros);
-    Known.One.setHighBits(MinLeadingOnes);
     return Known;
   }
 
   // Find the common bits from all possible shifts.
-  APInt MaxShiftAmount = RHS.getMaxValue();
-  uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
-  uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
-  assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
+  APInt MaxValue = RHS.getMaxValue();
+  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
+  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
+  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
   Known.Zero.setAllBits();
   Known.One.setAllBits();
-  for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
-       ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
+  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
+      ++ShiftAmt) {
     // Skip if the shift amount is impossible.
-    if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
+    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
         (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
       continue;
-    KnownBits SpecificShift = LHS;
-    SpecificShift.Zero.ashrInPlace(ShiftAmt);
-    SpecificShift.One.ashrInPlace(ShiftAmt);
-    Known = Known.intersectWith(SpecificShift);
+    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
     if (Known.isUnknown())
       break;
   }
 
+  // All shift amounts may result in poison.
+  if (Known.hasConflict())
+    Known.setAllZero();
   return Known;
 }