[KnownBits] Handle shifts over wide types
authorNikita Popov <npopov@redhat.com>
Tue, 16 May 2023 09:23:40 +0000 (11:23 +0200)
committerNikita Popov <npopov@redhat.com>
Tue, 16 May 2023 09:26:39 +0000 (11:26 +0200)
Do not assert if the bit width is larger than 64 bits. This case
is currently hidden from the IR layer by other checks, but gets
exposed with future changes.

llvm/lib/Support/KnownBits.cpp
llvm/unittests/Support/KnownBitsTest.cpp

index 0fd3e5a..ddeb6a4 100644 (file)
@@ -195,8 +195,8 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
   if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
+    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
@@ -251,8 +251,8 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
   if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
+    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
@@ -312,8 +312,8 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
   // possible shifts.
   APInt MaxShiftAmount = RHS.getMaxValue();
   if (!LHS.isUnknown()) {
-    uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
-    uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
+    uint64_t ShiftAmtZeroMask = (~RHS.Zero).zextOrTrunc(64).getZExtValue();
+    uint64_t ShiftAmtOneMask = RHS.One.zextOrTrunc(64).getZExtValue();
     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
     Known.Zero.setAllBits();
     Known.One.setAllBits();
index 28f904e..ece7e80 100644 (file)
@@ -352,6 +352,20 @@ TEST(KnownBitsTest, UnaryExhaustive) {
       [](const APInt &N) { return N * N; }, checkCorrectnessOnlyUnary);
 }
 
+TEST(KnownBitsTest, WideShifts) {
+  unsigned BitWidth = 128;
+  KnownBits Unknown(BitWidth);
+  KnownBits AllOnes = KnownBits::makeConstant(APInt::getAllOnes(BitWidth));
+
+  KnownBits ShlResult(BitWidth);
+  ShlResult.makeNegative();
+  EXPECT_EQ(KnownBits::shl(AllOnes, Unknown), ShlResult);
+  KnownBits LShrResult(BitWidth);
+  LShrResult.One.setBit(0);
+  EXPECT_EQ(KnownBits::lshr(AllOnes, Unknown), LShrResult);
+  EXPECT_EQ(KnownBits::ashr(AllOnes, Unknown), AllOnes);
+}
+
 TEST(KnownBitsTest, ICmpExhaustive) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {