Simplify DenseElementsAttr by rounding up the storage of odd bit widths to 8-bits...
authorRiver Riddle <riverriddle@google.com>
Thu, 6 Jun 2019 22:55:17 +0000 (15:55 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:21:43 +0000 (16:21 -0700)
PiperOrigin-RevId: 251944922

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp

index b7b300b..cedd181 100644 (file)
@@ -537,12 +537,12 @@ public:
 
   ArrayRef<char> getRawData() const;
 
-  /// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
-  /// expected to be a 64-bit aligned storage address.
+  /// Writes value to the bit position `bitPos` in array `rawData`. If the
+  /// bitwidth of `value` is not 1, then `bitPos` must be 8-bit aligned.
   static void writeBits(char *rawData, size_t bitPos, APInt value);
 
   /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
-  /// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
+  /// `rawData`. If `bitWidth` is not 1, then `bitPos` must be 8-bit aligned.
   static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth);
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
index 09a7bd7..664251a 100644 (file)
@@ -361,8 +361,9 @@ struct SplatElementsAttributeStorage : public AttributeStorage {
 struct DenseElementsAttributeStorage : public AttributeStorage {
   using KeyTy = std::pair<Type, ArrayRef<char>>;
 
-  DenseElementsAttributeStorage(Type ty, ArrayRef<char> data)
-      : AttributeStorage(ty), data(data) {}
+  DenseElementsAttributeStorage(Type ty, ArrayRef<char> data,
+                                bool isSplat = false)
+      : AttributeStorage(ty), data(data), isSplat(isSplat) {}
 
   /// Key equality and hash functions.
   bool operator==(const KeyTy &key) const {
@@ -373,22 +374,13 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
   static DenseElementsAttributeStorage *
   construct(AttributeStorageAllocator &allocator, KeyTy key) {
     // If the data buffer is non-empty, we copy it into the allocator.
-    ArrayRef<char> data = key.second;
-    if (!data.empty()) {
-      // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so
-      // the `readBits` will not fail when it accesses multiples of
-      // APINT_WORD_SIZE each time.
-      size_t sizeToAllocate =
-          llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE);
-      auto *rawCopy = (char *)allocator.allocate(sizeToAllocate, 64);
-      std::uninitialized_copy(data.begin(), data.end(), rawCopy);
-      data = {rawCopy, data.size()};
-    }
+    ArrayRef<char> data = allocator.copyInto(key.second);
     return new (allocator.allocate<DenseElementsAttributeStorage>())
         DenseElementsAttributeStorage(key.first, data);
   }
 
   ArrayRef<char> data;
+  bool isSplat;
 };
 
 /// An attribute representing a reference to a tensor constant with opaque
index 1e902fb..0e2b5a1 100644 (file)
@@ -500,6 +500,25 @@ static size_t getDenseElementBitwidth(Type eltType) {
   return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
 }
 
+/// Get the bitwidth of a dense element type within the buffer.
+/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
+static size_t getDenseElementStorageWidth(size_t origWidth) {
+  return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+}
+
+/// Set a bit to a specific value.
+static void setBit(char *rawData, size_t bitPos, bool value) {
+  if (value)
+    rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
+  else
+    rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
+}
+
+/// Return the value of the specified bit.
+static bool getBit(const char *rawData, size_t bitPos) {
+  return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
+}
+
 /// Constructs a new iterator.
 DenseElementsAttr::RawElementIterator::RawElementIterator(
     DenseElementsAttr attr, size_t index)
@@ -508,7 +527,8 @@ DenseElementsAttr::RawElementIterator::RawElementIterator(
 
 /// Accesses the raw APInt value at this iterator position.
 APInt DenseElementsAttr::RawElementIterator::operator*() const {
-  return readBits(rawData, index * bitWidth, bitWidth);
+  return readBits(rawData, index * getDenseElementStorageWidth(bitWidth),
+                  bitWidth);
 }
 
 //===----------------------------------------------------------------------===//
@@ -544,14 +564,12 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
          "expected 'values' to contain the same number of elements as 'type'");
 
-  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-  // with double semantics.
   auto eltType = type.getElementType();
-  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
+  size_t bitWidth = getDenseElementBitwidth(eltType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   // Compress the attribute values into a character buffer.
-  SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
-                            APInt::APINT_WORD_SIZE);
+  SmallVector<char, 8> data(storageBitWidth * type.getNumElements());
   APInt intVal;
   for (unsigned i = 0, e = values.size(); i < e; ++i) {
     switch (eltType.getKind()) {
@@ -573,7 +591,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
     }
     assert(intVal.getBitWidth() == bitWidth &&
            "expected value to have same bitwidth as element type");
-    writeBits(data.data(), i * bitWidth, intVal);
+    writeBits(data.data(), i * storageBitWidth, intVal);
   }
   return get(type, data);
 }
@@ -608,8 +626,9 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   // Return the element stored at the 1D index.
   auto elementType = getType().getElementType();
   size_t bitWidth = getDenseElementBitwidth(elementType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
   APInt rawValueData =
-      readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
+      readBits(getRawData().data(), valueIndex * storageBitWidth, bitWidth);
 
   // Convert the raw value data to an attribute value.
   switch (getKind()) {
@@ -677,60 +696,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
          "expected 'values' to contain the same number of elements as 'type'");
 
   size_t bitWidth = getDenseElementBitwidth(type.getElementType());
-  std::vector<char> elementData(APInt::getNumWords(bitWidth * values.size()) *
-                                APInt::APINT_WORD_SIZE);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+  std::vector<char> elementData(bitWidth * values.size());
   for (unsigned i = 0, e = values.size(); i != e; ++i) {
     assert(values[i].getBitWidth() == bitWidth);
-    writeBits(elementData.data(), i * bitWidth, values[i]);
+    writeBits(elementData.data(), i * storageBitWidth, values[i]);
   }
   return get(type, elementData);
 }
 
-/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
-/// expected to be a 64-bit aligned storage address.
+/// Writes value to the bit position `bitPos` in array `rawData`.
 void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
   size_t bitWidth = value.getBitWidth();
 
   // If the bitwidth is 1 we just toggle the specific bit.
-  if (bitWidth == 1) {
-    auto *rawIntData = reinterpret_cast<uint64_t *>(rawData);
-    if (value.isOneValue())
-      APInt::tcSetBit(rawIntData, bitPos);
-    else
-      APInt::tcClearBit(rawIntData, bitPos);
-    return;
-  }
-
-  // If the bit position and width are byte aligned, write the storage directly
-  // to the data.
-  if ((bitWidth % 8) == 0 && (bitPos % 8) == 0) {
-    std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
-                bitWidth / 8, rawData + (bitPos / 8));
-    return;
-  }
+  if (bitWidth == 1)
+    return setBit(rawData, bitPos, value.isOneValue());
 
-  // Otherwise, convert the raw data into an APInt and insert the value at the
-  // specified bit position.
-  size_t totalWords = APInt::getNumWords((bitPos % 64) + bitWidth);
-  llvm::MutableArrayRef<uint64_t> rawIntData(
-      reinterpret_cast<uint64_t *>(rawData) + (bitPos / 64), totalWords);
-  APInt tempStorage(totalWords * 64, rawIntData);
-  tempStorage.insertBits(value, bitPos % 64);
-
-  // Copy the value back to the raw data.
-  std::copy_n(tempStorage.getRawData(), rawIntData.size(), rawIntData.data());
+  // Otherwise, the bit position is guaranteed to be byte aligned.
+  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
+  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
+              llvm::divideCeil(bitWidth, CHAR_BIT),
+              rawData + (bitPos / CHAR_BIT));
 }
 
 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
-/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
+/// `rawData`.
 APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
                                   size_t bitWidth) {
-  // Reinterpret the raw data as a uint64_t word array and extract the value
-  // starting at 'bitPos'.
+  // Handle a boolean bit position.
+  if (bitWidth == 1)
+    return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
+
+  // Otherwise, the bit position must be 8-bit aligned.
+  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
   APInt result(bitWidth, 0);
-  const uint64_t *intData = reinterpret_cast<const uint64_t *>(rawData);
-  APInt::tcExtract(const_cast<uint64_t *>(result.getRawData()),
-                   result.getNumWords(), intData, bitWidth, bitPos);
+  std::copy_n(rawData + (bitPos / CHAR_BIT),
+              llvm::divideCeil(bitWidth, CHAR_BIT),
+              (char *)(result.getRawData()));
   return result;
 }
 
@@ -772,6 +775,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
                                 Type newElementType,
                                 llvm::SmallVectorImpl<char> &data) {
   size_t bitWidth = getDenseElementBitwidth(newElementType);
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   ShapedType newArrayType;
   if (inType.isa<RankedTensorType>())
@@ -783,14 +787,13 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
   else
     assert(newArrayType && "Unhandled tensor type");
 
-  data.resize(APInt::getNumWords(bitWidth * inType.getNumElements()) *
-              APInt::APINT_WORD_SIZE);
+  data.resize(storageBitWidth * inType.getNumElements());
 
   uint64_t elementIdx = 0;
   for (auto value : attr) {
     auto newInt = mapping(value);
     assert(newInt.getBitWidth() == bitWidth);
-    attr.writeBits(data.data(), elementIdx * bitWidth, newInt);
+    attr.writeBits(data.data(), elementIdx * storageBitWidth, newInt);
     ++elementIdx;
   }