[APInt] Enable APInt to support zero bit integers.
authorChris Lattner <clattner@nondot.org>
Thu, 9 Sep 2021 23:20:25 +0000 (16:20 -0700)
committerChris Lattner <clattner@nondot.org>
Fri, 10 Sep 2021 05:43:54 +0000 (22:43 -0700)
Motivation: APInt not supporting zero bit values leads to
a lot of special cases in various bits of code, particularly
when using APInt as a bit vector (where you want to start with
zero bits and then concat on more.  This is particularly
challenging in the CIRCT project, where the absence of zero-bit
ConstantOp forces duplication of ops and makes instcombine-like
logic far more complicated.

Approach: zero bit integers are weird.  There are two reasonable
approaches: either make it illegal to do general arithmetic on
them (e.g. sign extends), or treat them as as implicitly having
a zero value.  This patch takes the conservative approach, which
enables their use in bitvector applications.

Differential Revision: https://reviews.llvm.org/D109555

llvm/include/llvm/ADT/APInt.h
llvm/lib/Support/APInt.cpp
llvm/unittests/ADT/APIntTest.cpp

index 4629250..4aad2c9 100644 (file)
@@ -66,6 +66,11 @@ inline APInt operator-(APInt);
 ///     not.
 ///   * In general, the class tries to follow the style of computation that LLVM
 ///     uses in its IR. This simplifies its use for LLVM.
+///   * APInt supports zero-bit-width values, but operations that require bits
+///     are not defined on it (e.g. you cannot ask for the sign of a zero-bit
+///     integer).  This means that operations like zero extension and logical
+///     shifts are defined, but sign extension and ashr is not.  Zero bit values
+///     compare and hash equal to themselves, and countLeadingZeros returns 0.
 ///
 class LLVM_NODISCARD APInt {
 public:
@@ -102,7 +107,6 @@ public:
   /// \param isSigned how to treat signedness of val
   APInt(unsigned numBits, uint64_t val, bool isSigned = false)
       : BitWidth(numBits) {
-    assert(BitWidth && "bitwidth too small");
     if (isSingleWord()) {
       U.VAL = val;
       clearUnusedBits();
@@ -142,11 +146,7 @@ public:
   /// \param radix the radix to use for the conversion
   APInt(unsigned numBits, StringRef str, uint8_t radix);
 
-  /// Default constructor that creates an uninteresting APInt
-  /// representing a 1-bit zero value.
-  ///
-  /// This is useful for object deserialization (pair this with the static
-  ///  method Read).
+  /// Default constructor that creates an APInt with a 1-bit zero value.
   explicit APInt() : BitWidth(1) { U.VAL = 0; }
 
   /// Copy Constructor.
@@ -179,6 +179,9 @@ public:
   /// NOTE: This is soft-deprecated.  Please use `getZero()` instead.
   static APInt getNullValue(unsigned numBits) { return getZero(numBits); }
 
+  /// Return an APInt zero bits wide.
+  static APInt getZeroWidth() { return getZero(0); }
+
   /// Gets maximum unsigned value of APInt for specific bit width.
   static APInt getMaxValue(unsigned numBits) {
     return getAllOnesValue(numBits);
@@ -238,7 +241,6 @@ public:
   ///
   /// \returns An APInt value with the requested bits set.
   static APInt getBitsSet(unsigned numBits, unsigned loBit, unsigned hiBit) {
-    assert(loBit <= hiBit && "loBit greater than hiBit");
     APInt Res(numBits, 0);
     Res.setBits(loBit, hiBit);
     return Res;
@@ -257,8 +259,6 @@ public:
     return Res;
   }
 
-  /// Get a value with upper bits starting at loBit set.
-  ///
   /// Constructs an APInt value that has a contiguous range of bits set. The
   /// bits from loBit (inclusive) to numBits (exclusive) will be set. All other
   /// bits will be zero. For example, with parameters(32, 12) you would get
@@ -274,8 +274,6 @@ public:
     return Res;
   }
 
-  /// Get a value with high bits set
-  ///
   /// Constructs an APInt value that has the top hiBitsSet bits set.
   ///
   /// \param numBits the bitwidth of the result
@@ -286,8 +284,6 @@ public:
     return Res;
   }
 
-  /// Get a value with low bits set
-  ///
   /// Constructs an APInt value that has the bottom loBitsSet bits set.
   ///
   /// \param numBits the bitwidth of the result
@@ -351,8 +347,11 @@ public:
 
   /// Determine if all bits are set.
   bool isAllOnes() const {
-    if (isSingleWord())
+    if (isSingleWord()) {
+      if (BitWidth == 0)
+        return false;
       return U.VAL == WORDTYPE_MAX >> (APINT_BITS_PER_WORD - BitWidth);
+    }
     return countTrailingOnesSlowCase() == BitWidth;
   }
 
@@ -360,7 +359,11 @@ public:
   bool isAllOnesValue() const { return isAllOnes(); }
 
   /// Determine if this value is zero, i.e. all bits are clear.
-  bool isZero() const { return !*this; }
+  bool isZero() const {
+    if (isSingleWord())
+      return U.VAL == 0;
+    return countLeadingZerosSlowCase() == BitWidth;
+  }
 
   /// NOTE: This is soft-deprecated.  Please use `isZero()` instead.
   bool isNullValue() const { return isZero(); }
@@ -388,8 +391,10 @@ public:
   /// This checks to see if the value of this APInt is the maximum signed
   /// value for the APInt's bit width.
   bool isMaxSignedValue() const {
-    if (isSingleWord())
+    if (isSingleWord()) {
+      assert(BitWidth && "zero width values not allowed");
       return U.VAL == ((WordType(1) << (BitWidth - 1)) - 1);
+    }
     return !isNegative() && countTrailingOnesSlowCase() == BitWidth - 1;
   }
 
@@ -404,29 +409,27 @@ public:
   /// This checks to see if the value of this APInt is the minimum signed
   /// value for the APInt's bit width.
   bool isMinSignedValue() const {
-    if (isSingleWord())
+    if (isSingleWord()) {
+      assert(BitWidth && "zero width values not allowed");
       return U.VAL == (WordType(1) << (BitWidth - 1));
+    }
     return isNegative() && countTrailingZerosSlowCase() == BitWidth - 1;
   }
 
   /// Check if this APInt has an N-bits unsigned integer value.
-  bool isIntN(unsigned N) const {
-    assert(N && "0 bit APInt not supported");
-    return getActiveBits() <= N;
-  }
+  bool isIntN(unsigned N) const { return getActiveBits() <= N; }
 
   /// Check if this APInt has an N-bits signed integer value.
-  bool isSignedIntN(unsigned N) const {
-    assert(N && "0 bit APInt not supported");
-    return getMinSignedBits() <= N;
-  }
+  bool isSignedIntN(unsigned N) const { return getMinSignedBits() <= N; }
 
   /// Check if this APInt's value is a power of two greater than zero.
   ///
   /// \returns true if the argument APInt value is a power of two > 0.
   bool isPowerOf2() const {
-    if (isSingleWord())
+    if (isSingleWord()) {
+      assert(BitWidth && "zero width values not allowed");
       return isPowerOf2_64(U.VAL);
+    }
     return countPopulationSlowCase() == 1;
   }
 
@@ -438,7 +441,7 @@ public:
   /// Convert APInt to a boolean value.
   ///
   /// This converts the APInt to a boolean value as a test against zero.
-  bool getBoolValue() const { return !!*this; }
+  bool getBoolValue() const { return !isZero(); }
 
   /// If this value is smaller than the specified limit, return it, otherwise
   /// return the limit value.  This causes the value to saturate to the limit.
@@ -487,16 +490,16 @@ public:
 
   /// Compute an APInt containing numBits highbits from this APInt.
   ///
-  /// Get an APInt with the same BitWidth as this APInt, just zero mask
-  /// the low bits and right shift to the least significant bit.
+  /// Get an APInt with the same BitWidth as this APInt, just zero mask the low
+  /// bits and right shift to the least significant bit.
   ///
   /// \returns the high "numBits" bits of this APInt.
   APInt getHiBits(unsigned numBits) const;
 
   /// Compute an APInt containing numBits lowbits from this APInt.
   ///
-  /// Get an APInt with the same BitWidth as this APInt, just zero mask
-  /// the high bits.
+  /// Get an APInt with the same BitWidth as this APInt, just zero mask the high
+  /// bits.
   ///
   /// \returns the low "numBits" bits of this APInt.
   APInt getLoBits(unsigned numBits) const;
@@ -529,9 +532,7 @@ public:
   /// \name Unary Operators
   /// @{
 
-  /// Postfix increment operator.
-  ///
-  /// Increments *this by 1.
+  /// Postfix increment operator.  Increment *this by 1.
   ///
   /// \returns a new APInt value representing the original value of *this.
   APInt operator++(int) {
@@ -545,9 +546,7 @@ public:
   /// \returns *this incremented by one
   APInt &operator++();
 
-  /// Postfix decrement operator.
-  ///
-  /// Decrements *this by 1.
+  /// Postfix decrement operator. Decrement *this by 1.
   ///
   /// \returns a new APInt value representing the original value of *this.
   APInt operator--(int) {
@@ -561,16 +560,9 @@ public:
   /// \returns *this decremented by one.
   APInt &operator--();
 
-  /// Logical negation operator.
-  ///
-  /// Performs logical negation operation on this APInt.
-  ///
-  /// \returns true if *this is zero, false otherwise.
-  bool operator!() const {
-    if (isSingleWord())
-      return U.VAL == 0;
-    return countLeadingZerosSlowCase() == BitWidth;
-  }
+  /// Logical negation operation on this APInt returns true if zero, like normal
+  /// integers.
+  bool operator!() const { return isZero(); }
 
   /// @}
   /// \name Assignment Operators
@@ -580,11 +572,12 @@ public:
   ///
   /// \returns *this after assignment of RHS.
   APInt &operator=(const APInt &RHS) {
-    // If the bitwidths are the same, we can avoid mucking with memory
+    // The common case (both source or dest being inline) doesn't require
+    // allocation or deallocation.
     if (isSingleWord() && RHS.isSingleWord()) {
       U.VAL = RHS.U.VAL;
       BitWidth = RHS.BitWidth;
-      return clearUnusedBits();
+      return *this;
     }
 
     AssignSlowCase(RHS);
@@ -608,7 +601,6 @@ public:
 
     BitWidth = that.BitWidth;
     that.BitWidth = 0;
-
     return *this;
   }
 
@@ -1264,8 +1256,6 @@ public:
     clearUnusedBits();
   }
 
-  /// Set a given bit to 1.
-  ///
   /// Set the given bit to 1 whose position is given as "bitPosition".
   void setBit(unsigned BitPosition) {
     assert(BitPosition < BitWidth && "BitPosition out of range");
@@ -1449,8 +1439,10 @@ public:
   /// uint64_t. The bitwidth must be <= 64 or the value must fit within a
   /// uint64_t. Otherwise an assertion will result.
   uint64_t getZExtValue() const {
-    if (isSingleWord())
+    if (isSingleWord()) {
+      assert(BitWidth && "zero width values not allowed");
       return U.VAL;
+    }
     assert(getActiveBits() <= 64 && "Too many bits for uint64_t");
     return U.pVal[0];
   }
@@ -1498,8 +1490,11 @@ public:
   /// \returns 0 if the high order bit is not set, otherwise returns the number
   /// of 1 bits from the most significant to the least
   unsigned countLeadingOnes() const {
-    if (isSingleWord())
+    if (isSingleWord()) {
+      if (BitWidth == 0)
+        return 0;
       return llvm::countLeadingOnes(U.VAL << (APINT_BITS_PER_WORD - BitWidth));
+    }
     return countLeadingOnesSlowCase();
   }
 
@@ -1807,10 +1802,9 @@ private:
 
   friend class APSInt;
 
-  /// Fast internal constructor
-  ///
   /// This constructor is used only internally for speed of construction of
-  /// temporaries. It is unsafe for general use so it is not public.
+  /// temporaries. It is unsafe since it takes ownership of the pointer, so it
+  /// is not public.
   APInt(uint64_t *val, unsigned bits) : BitWidth(bits) { U.pVal = val; }
 
   /// Determine which word a bit is in.
@@ -1820,10 +1814,7 @@ private:
     return bitPosition / APINT_BITS_PER_WORD;
   }
 
-  /// Determine which bit in a word a bit is in.
-  ///
-  /// \returns the bit position in a word for the specified bit position
-  /// in the APInt.
+  /// Determine which bit in a word the specified bit position is in.
   static unsigned whichBit(unsigned bitPosition) {
     return bitPosition % APINT_BITS_PER_WORD;
   }
@@ -1845,11 +1836,14 @@ private:
   /// significant word is assigned a value to ensure that those bits are
   /// zero'd out.
   APInt &clearUnusedBits() {
-    // Compute how many bits are used in the final word
+    // Compute how many bits are used in the final word.
     unsigned WordBits = ((BitWidth - 1) % APINT_BITS_PER_WORD) + 1;
 
     // Mask out the high bits.
     uint64_t mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - WordBits);
+    if (BitWidth == 0)
+      mask = 0;
+
     if (isSingleWord())
       U.VAL &= mask;
     else
index 69787d7..3982490 100644 (file)
@@ -89,7 +89,6 @@ void APInt::initSlowCase(const APInt& that) {
 }
 
 void APInt::initFromArray(ArrayRef<uint64_t> bigVal) {
-  assert(BitWidth && "Bitwidth too small");
   assert(bigVal.data() && "Null pointer detected!");
   if (isSingleWord())
     U.VAL = bigVal[0];
@@ -105,19 +104,17 @@ void APInt::initFromArray(ArrayRef<uint64_t> bigVal) {
   clearUnusedBits();
 }
 
-APInt::APInt(unsigned numBits, ArrayRef<uint64_t> bigVal)
-  : BitWidth(numBits) {
+APInt::APInt(unsigned numBits, ArrayRef<uint64_t> bigVal) : BitWidth(numBits) {
   initFromArray(bigVal);
 }
 
 APInt::APInt(unsigned numBits, unsigned numWords, const uint64_t bigVal[])
-  : BitWidth(numBits) {
+    : BitWidth(numBits) {
   initFromArray(makeArrayRef(bigVal, numWords));
 }
 
 APInt::APInt(unsigned numbits, StringRef Str, uint8_t radix)
-  : BitWidth(numbits) {
-  assert(BitWidth && "Bitwidth too small");
+    : BitWidth(numbits) {
   fromString(numbits, Str, radix);
 }
 
@@ -233,9 +230,7 @@ APInt APInt::operator*(const APInt& RHS) const {
     return APInt(BitWidth, U.VAL * RHS.U.VAL);
 
   APInt Result(getMemory(getNumWords()), getBitWidth());
-
   tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords());
-
   Result.clearUnusedBits();
   return Result;
 }
@@ -258,8 +253,7 @@ void APInt::XorAssignSlowCase(const APInt &RHS) {
     dst[i] ^= rhs[i];
 }
 
-APInt& APInt::operator*=(const APInt& RHS) {
-  assert(BitWidth == RHS.BitWidth && "Bit widths must be the same");
+APInt &APInt::operator*=(const APInt &RHS) {
   *this = *this * RHS;
   return *this;
 }
@@ -714,6 +708,8 @@ APInt APInt::reverseBits() const {
     return APInt(BitWidth, llvm::reverseBits<uint16_t>(U.VAL));
   case 8:
     return APInt(BitWidth, llvm::reverseBits<uint8_t>(U.VAL));
+  case 0:
+    return *this;
   default:
     break;
   }
@@ -873,7 +869,6 @@ double APInt::roundToDouble(bool isSigned) const {
 // Truncate to new width.
 APInt APInt::trunc(unsigned width) const {
   assert(width < BitWidth && "Invalid APInt Truncate request");
-  assert(width && "Can't truncate to 0 bits");
 
   if (width <= APINT_BITS_PER_WORD)
     return APInt(width, getRawData()[0]);
@@ -896,7 +891,6 @@ APInt APInt::trunc(unsigned width) const {
 // Truncate to new width with unsigned saturation.
 APInt APInt::truncUSat(unsigned width) const {
   assert(width < BitWidth && "Invalid APInt Truncate request");
-  assert(width && "Can't truncate to 0 bits");
 
   // Can we just losslessly truncate it?
   if (isIntN(width))
@@ -908,7 +902,6 @@ APInt APInt::truncUSat(unsigned width) const {
 // Truncate to new width with signed saturation.
 APInt APInt::truncSSat(unsigned width) const {
   assert(width < BitWidth && "Invalid APInt Truncate request");
-  assert(width && "Can't truncate to 0 bits");
 
   // Can we just losslessly truncate it?
   if (isSignedIntN(width))
@@ -1071,6 +1064,8 @@ void APInt::shlSlowCase(unsigned ShiftAmt) {
 
 // Calculate the rotate amount modulo the bit width.
 static unsigned rotateModulo(unsigned BitWidth, const APInt &rotateAmt) {
+  if (BitWidth == 0)
+    return 0;
   unsigned rotBitWidth = rotateAmt.getBitWidth();
   APInt rot = rotateAmt;
   if (rotBitWidth < BitWidth) {
@@ -1087,6 +1082,8 @@ APInt APInt::rotl(const APInt &rotateAmt) const {
 }
 
 APInt APInt::rotl(unsigned rotateAmt) const {
+  if (BitWidth == 0)
+    return *this;
   rotateAmt %= BitWidth;
   if (rotateAmt == 0)
     return *this;
@@ -1098,6 +1095,8 @@ APInt APInt::rotr(const APInt &rotateAmt) const {
 }
 
 APInt APInt::rotr(unsigned rotateAmt) const {
+  if (BitWidth == 0)
+    return *this;
   rotateAmt %= BitWidth;
   if (rotateAmt == 0)
     return *this;
@@ -2145,7 +2144,7 @@ void APInt::toString(SmallVectorImpl<char> &Str, unsigned Radix,
   }
 
   // First, check for a zero value and just short circuit the logic below.
-  if (*this == 0) {
+  if (isZero()) {
     while (*Prefix) {
       Str.push_back(*Prefix);
       ++Prefix;
@@ -2713,7 +2712,7 @@ APInt llvm::APIntOps::RoundingUDiv(const APInt &A, const APInt &B,
   case APInt::Rounding::UP: {
     APInt Quo, Rem;
     APInt::udivrem(A, B, Quo, Rem);
-    if (Rem == 0)
+    if (Rem.isZero())
       return Quo;
     return Quo + 1;
   }
@@ -2728,7 +2727,7 @@ APInt llvm::APIntOps::RoundingSDiv(const APInt &A, const APInt &B,
   case APInt::Rounding::UP: {
     APInt Quo, Rem;
     APInt::sdivrem(A, B, Quo, Rem);
-    if (Rem == 0)
+    if (Rem.isZero())
       return Quo;
     // This algorithm deals with arbitrary rounding mode used by sdivrem.
     // We want to check whether the non-integer part of the mathematical value
index 74e0c46..0f8f462 100644 (file)
@@ -1422,7 +1422,6 @@ TEST(APIntTest, Log2) {
 #ifdef GTEST_HAS_DEATH_TEST
 #ifndef NDEBUG
 TEST(APIntTest, StringDeath) {
-  EXPECT_DEATH((void)APInt(0, "", 0), "Bitwidth too small");
   EXPECT_DEATH((void)APInt(32, "", 0), "Invalid string length");
   EXPECT_DEATH((void)APInt(32, "0", 0), "Radix should be 2, 8, 10, 16, or 36!");
   EXPECT_DEATH((void)APInt(32, "", 10), "Invalid string length");
@@ -2908,4 +2907,79 @@ TEST(APIntTest, SignbitZeroChecks) {
   EXPECT_FALSE(APInt(8, 1).isNonPositive());
 }
 
+TEST(APIntTest, ZeroWidth) {
+  // Zero width Constructors.
+  auto ZW = APInt::getZeroWidth();
+  EXPECT_EQ(0U, ZW.getBitWidth());
+  EXPECT_EQ(0U, APInt(0, ArrayRef<uint64_t>({0, 1, 2})).getBitWidth());
+  EXPECT_EQ(0U, APInt(0, "0", 10).getBitWidth());
+
+  // Default constructor is single bit wide.
+  EXPECT_EQ(1U, APInt().getBitWidth());
+
+  // Copy ctor (move is down below).
+  APInt ZW2(ZW);
+  EXPECT_EQ(0U, ZW2.getBitWidth());
+  // Assignment
+  ZW = ZW2;
+  EXPECT_EQ(0U, ZW.getBitWidth());
+
+  // Methods like getLowBitsSet work with zero bits.
+  EXPECT_EQ(0U, APInt::getLowBitsSet(0, 0).getBitWidth());
+  EXPECT_EQ(0U, APInt::getSplat(0, ZW).getBitWidth());
+
+  // Logical operators.
+  ZW |= ZW2;
+  ZW &= ZW2;
+  ZW ^= ZW2;
+  ZW |= 42; // These ignore high bits of the literal.
+  ZW &= 42;
+  ZW ^= 42;
+  EXPECT_EQ(1, ZW.isIntN(0));
+
+  // Modulo Arithmetic.  Divide/Rem aren't defined on division by zero, so they
+  // aren't supported.
+  ZW += ZW2;
+  ZW -= ZW2;
+  ZW *= ZW2;
+
+  // Logical Shifts and rotates, the amount must be <= bitwidth.
+  ZW <<= 0;
+  ZW.lshrInPlace(0);
+  (void)ZW.rotl(0);
+  (void)ZW.rotr(0);
+
+  // Comparisons.
+  EXPECT_EQ(1, ZW == ZW);
+  EXPECT_EQ(0, ZW != ZW);
+  EXPECT_EQ(0, ZW.ult(ZW));
+
+  // Mutations.
+  ZW.setBitsWithWrap(0, 0);
+  ZW.setBits(0, 0);
+  ZW.clearAllBits();
+  ZW.flipAllBits();
+
+  // Leading, trailing, ctpop, etc
+  EXPECT_EQ(0U, ZW.countLeadingZeros());
+  EXPECT_EQ(0U, ZW.countLeadingOnes());
+  EXPECT_EQ(0U, ZW.countPopulation());
+  EXPECT_EQ(0U, ZW.reverseBits().getBitWidth());
+  EXPECT_EQ(0U, ZW.getHiBits(0).getBitWidth());
+  EXPECT_EQ(0U, ZW.getLoBits(0).getBitWidth());
+  EXPECT_EQ(0, ZW.zext(4));
+  EXPECT_EQ(0U, APInt(4, 3).trunc(0).getBitWidth());
+
+  SmallString<42> STR;
+  ZW.toStringUnsigned(STR);
+  EXPECT_EQ("0", STR);
+
+  // Move ctor (keep at the end of the method since moves are destructive).
+  APInt MZW1(std::move(ZW));
+  EXPECT_EQ(0U, MZW1.getBitWidth());
+  // Move Assignment
+  MZW1 = std::move(ZW2);
+  EXPECT_EQ(0U, MZW1.getBitWidth());
+}
+
 } // end anonymous namespace