[KnownBits] Make shl/lshr/ashr implementations optimal
authorNikita Popov <npopov@redhat.com>
Mon, 15 May 2023 15:38:49 +0000 (17:38 +0200)
committerNikita Popov <npopov@redhat.com>
Tue, 16 May 2023 07:44:26 +0000 (09:44 +0200)
The implementations for shifts were suboptimal in the case where
the max shift amount was >= bitwidth. In that case we should still
use the usual code clamped to BitWidth-1 rather than just giving up
entirely.

Additionally, there was an implementation bug where the known zero
bits for the individual shift amounts were not set in the shl/lshr
implementations. I think after these changes, we'll be able to drop
some of the code in ValueTracking which *also* evaluates all possible
shift amounts and has been papering over this issue.

For the "all poison" case I've opted to return an unknown value for
now. It would be better to return zero, but this has fairly
substantial test fallout, so I figured it's best to not mix it into
this change. (The "correct" return value would be a conflict, but
given that a lot of our APIs assert conflict-freedom, that's probably
not the best idea to actually return.)

Differential Revision: https://reviews.llvm.org/D150587

llvm/lib/Support/KnownBits.cpp
llvm/test/CodeGen/AMDGPU/amdgpu.private-memory.ll
llvm/test/Transforms/InstCombine/not-add.ll
llvm/unittests/Support/KnownBitsTest.cpp

index ad6e1c8..3377dd3 100644 (file)
@@ -182,24 +182,26 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
   // No matter the shift amount, the trailing zeros will stay zero.
   unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
 
-  // Minimum shift amount low bits are known zero.
   APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.ult(BitWidth)) {
-    MinTrailingZeros += MinShiftAmount.getZExtValue();
-    MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
-  }
+  if (MinShiftAmount.uge(BitWidth))
+    // Always poison. Return unknown because we don't like returning conflict.
+    return Known;
+
+  // Minimum shift amount low bits are known zero.
+  MinTrailingZeros += MinShiftAmount.getZExtValue();
+  MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
 
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
+  if (!LHS.isUnknown()) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
       // Skip if the shift amount is impossible.
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
@@ -207,6 +209,7 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
         continue;
       KnownBits SpecificShift;
       SpecificShift.Zero = LHS.Zero << ShiftAmt;
+      SpecificShift.Zero.setLowBits(ShiftAmt);
       SpecificShift.One = LHS.One << ShiftAmt;
       Known = KnownBits::commonBits(Known, SpecificShift);
       if (Known.isUnknown())
@@ -237,22 +240,24 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
 
   // Minimum shift amount high bits are known zero.
   APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.ult(BitWidth)) {
-    MinLeadingZeros += MinShiftAmount.getZExtValue();
-    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-  }
+  if (MinShiftAmount.uge(BitWidth))
+    // Always poison. Return unknown because we don't like returning conflict.
+    return Known;
+
+  MinLeadingZeros += MinShiftAmount.getZExtValue();
+  MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
 
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
+  if (!LHS.isUnknown()) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
       // Skip if the shift amount is impossible.
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
@@ -260,6 +265,7 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
         continue;
       KnownBits SpecificShift = LHS;
       SpecificShift.Zero.lshrInPlace(ShiftAmt);
+      SpecificShift.Zero.setHighBits(ShiftAmt);
       SpecificShift.One.lshrInPlace(ShiftAmt);
       Known = KnownBits::commonBits(Known, SpecificShift);
       if (Known.isUnknown())
@@ -289,28 +295,30 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
 
   // Minimum shift amount high bits are known sign bits.
   APInt MinShiftAmount = RHS.getMinValue();
-  if (MinShiftAmount.ult(BitWidth)) {
-    if (MinLeadingZeros) {
-      MinLeadingZeros += MinShiftAmount.getZExtValue();
-      MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
-    }
-    if (MinLeadingOnes) {
-      MinLeadingOnes += MinShiftAmount.getZExtValue();
-      MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
-    }
+  if (MinShiftAmount.uge(BitWidth))
+    // Always poison. Return unknown because we don't like returning conflict.
+    return Known;
+
+  if (MinLeadingZeros) {
+    MinLeadingZeros += MinShiftAmount.getZExtValue();
+    MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
+  }
+  if (MinLeadingOnes) {
+    MinLeadingOnes += MinShiftAmount.getZExtValue();
+    MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
   }
 
   // If the maximum shift is in range, then find the common bits from all
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
-  if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
+  if (!LHS.isUnknown()) {
     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
-                  MaxShiftAmt = MaxShiftAmount.getZExtValue();
+                  MaxShiftAmt = MaxShiftAmount.getLimitedValue(BitWidth - 1);
          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
       // Skip if the shift amount is impossible.
       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
index c734639..452459a 100644 (file)
@@ -221,7 +221,7 @@ for.end:
 ; SI-PROMOTE-VECT: s_load_dword [[IDX:s[0-9]+]]
 ; SI-PROMOTE-VECT: s_lshl_b32 [[SCALED_IDX:s[0-9]+]], [[IDX]], 4
 ; SI-PROMOTE-VECT: s_lshr_b32 [[SREG:s[0-9]+]], 0x10000, [[SCALED_IDX]]
-; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 0xffff
+; SI-PROMOTE-VECT: s_and_b32 s{{[0-9]+}}, [[SREG]], 1
 define amdgpu_kernel void @short_array(ptr addrspace(1) %out, i32 %index) #0 {
 entry:
   %0 = alloca [2 x i16], addrspace(5)
index 48cd4f5..03f4f44 100644 (file)
@@ -172,7 +172,7 @@ define void @pr50370(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X:%.*]], 1
 ; CHECK-NEXT:    [[B15:%.*]] = srem i32 ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537)), [[XOR]]
-; CHECK-NEXT:    [[B12:%.*]] = add nuw nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537))
+; CHECK-NEXT:    [[B12:%.*]] = add nsw i32 [[B15]], ashr (i32 65536, i32 or (i32 zext (i1 icmp eq (ptr @g, ptr null) to i32), i32 65537))
 ; CHECK-NEXT:    [[B:%.*]] = xor i32 [[B12]], -1
 ; CHECK-NEXT:    store i32 [[B]], ptr undef, align 4
 ; CHECK-NEXT:    ret void
index e8daae3..28f904e 100644 (file)
@@ -270,7 +270,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       },
       checkCorrectnessOnlyBinary);
 
-  // TODO: Make optimal for non-constant cases.
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::shl(Known1, Known2);
@@ -279,9 +278,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         if (N2.uge(N2.getBitWidth()))
           return std::nullopt;
         return N1.shl(N2);
-      },
-      [](const KnownBits &, const KnownBits &Known) {
-        return Known.isConstant();
       });
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -291,9 +287,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         if (N2.uge(N2.getBitWidth()))
           return std::nullopt;
         return N1.lshr(N2);
-      },
-      [](const KnownBits &, const KnownBits &Known) {
-        return Known.isConstant();
       });
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -303,9 +296,6 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         if (N2.uge(N2.getBitWidth()))
           return std::nullopt;
         return N1.ashr(N2);
-      },
-      [](const KnownBits &, const KnownBits &Known) {
-        return Known.isConstant();
       });
 
   testBinaryOpExhaustive(