From e6872ce7b7ada90af37677f9a31691e00b03da7d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 6 Jun 2019 15:55:17 -0700 Subject: [PATCH] Simplify DenseElementsAttr by rounding up the storage of odd bit widths to 8-bits. This removes the requirement that the underlying buffer be aligned to 64 bits which opens the door for several optimizations in the future, e.g. detecting splat. PiperOrigin-RevId: 251944922 --- mlir/include/mlir/IR/Attributes.h | 6 +-- mlir/lib/IR/AttributeDetail.h | 18 ++----- mlir/lib/IR/Attributes.cpp | 99 ++++++++++++++++++++------------------- 3 files changed, 59 insertions(+), 64 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index b7b300b..cedd181 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -537,12 +537,12 @@ public: ArrayRef getRawData() const; - /// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is - /// expected to be a 64-bit aligned storage address. + /// Writes value to the bit position `bitPos` in array `rawData`. If the + /// bitwidth of `value` is not 1, then `bitPos` must be 8-bit aligned. static void writeBits(char *rawData, size_t bitPos, APInt value); /// Reads the next `bitWidth` bits from the bit position `bitPos` in array - /// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address. + /// `rawData`. If `bitWidth` is not 1, then `bitPos` must be 8-bit aligned. static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth); /// Method for support type inquiry through isa, cast and dyn_cast. diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 09a7bd7..664251a 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -361,8 +361,9 @@ struct SplatElementsAttributeStorage : public AttributeStorage { struct DenseElementsAttributeStorage : public AttributeStorage { using KeyTy = std::pair>; - DenseElementsAttributeStorage(Type ty, ArrayRef data) - : AttributeStorage(ty), data(data) {} + DenseElementsAttributeStorage(Type ty, ArrayRef data, + bool isSplat = false) + : AttributeStorage(ty), data(data), isSplat(isSplat) {} /// Key equality and hash functions. bool operator==(const KeyTy &key) const { @@ -373,22 +374,13 @@ struct DenseElementsAttributeStorage : public AttributeStorage { static DenseElementsAttributeStorage * construct(AttributeStorageAllocator &allocator, KeyTy key) { // If the data buffer is non-empty, we copy it into the allocator. - ArrayRef data = key.second; - if (!data.empty()) { - // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so - // the `readBits` will not fail when it accesses multiples of - // APINT_WORD_SIZE each time. - size_t sizeToAllocate = - llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE); - auto *rawCopy = (char *)allocator.allocate(sizeToAllocate, 64); - std::uninitialized_copy(data.begin(), data.end(), rawCopy); - data = {rawCopy, data.size()}; - } + ArrayRef data = allocator.copyInto(key.second); return new (allocator.allocate()) DenseElementsAttributeStorage(key.first, data); } ArrayRef data; + bool isSplat; }; /// An attribute representing a reference to a tensor constant with opaque diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 1e902fb..0e2b5a1c 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -500,6 +500,25 @@ static size_t getDenseElementBitwidth(Type eltType) { return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); } +/// Get the bitwidth of a dense element type within the buffer. +/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. +static size_t getDenseElementStorageWidth(size_t origWidth) { + return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); +} + +/// Set a bit to a specific value. +static void setBit(char *rawData, size_t bitPos, bool value) { + if (value) + rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); + else + rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); +} + +/// Return the value of the specified bit. +static bool getBit(const char *rawData, size_t bitPos) { + return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; +} + /// Constructs a new iterator. DenseElementsAttr::RawElementIterator::RawElementIterator( DenseElementsAttr attr, size_t index) @@ -508,7 +527,8 @@ DenseElementsAttr::RawElementIterator::RawElementIterator( /// Accesses the raw APInt value at this iterator position. APInt DenseElementsAttr::RawElementIterator::operator*() const { - return readBits(rawData, index * bitWidth, bitWidth); + return readBits(rawData, index * getDenseElementStorageWidth(bitWidth), + bitWidth); } //===----------------------------------------------------------------------===// @@ -544,14 +564,12 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, assert(static_cast(values.size()) == type.getNumElements() && "expected 'values' to contain the same number of elements as 'type'"); - // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored - // with double semantics. auto eltType = type.getElementType(); - size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); + size_t bitWidth = getDenseElementBitwidth(eltType); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); // Compress the attribute values into a character buffer. - SmallVector data(APInt::getNumWords(bitWidth * values.size()) * - APInt::APINT_WORD_SIZE); + SmallVector data(storageBitWidth * type.getNumElements()); APInt intVal; for (unsigned i = 0, e = values.size(); i < e; ++i) { switch (eltType.getKind()) { @@ -573,7 +591,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, } assert(intVal.getBitWidth() == bitWidth && "expected value to have same bitwidth as element type"); - writeBits(data.data(), i * bitWidth, intVal); + writeBits(data.data(), i * storageBitWidth, intVal); } return get(type, data); } @@ -608,8 +626,9 @@ Attribute DenseElementsAttr::getValue(ArrayRef index) const { // Return the element stored at the 1D index. auto elementType = getType().getElementType(); size_t bitWidth = getDenseElementBitwidth(elementType); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); APInt rawValueData = - readBits(getRawData().data(), valueIndex * bitWidth, bitWidth); + readBits(getRawData().data(), valueIndex * storageBitWidth, bitWidth); // Convert the raw value data to an attribute value. switch (getKind()) { @@ -677,60 +696,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, "expected 'values' to contain the same number of elements as 'type'"); size_t bitWidth = getDenseElementBitwidth(type.getElementType()); - std::vector elementData(APInt::getNumWords(bitWidth * values.size()) * - APInt::APINT_WORD_SIZE); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); + std::vector elementData(bitWidth * values.size()); for (unsigned i = 0, e = values.size(); i != e; ++i) { assert(values[i].getBitWidth() == bitWidth); - writeBits(elementData.data(), i * bitWidth, values[i]); + writeBits(elementData.data(), i * storageBitWidth, values[i]); } return get(type, elementData); } -/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is -/// expected to be a 64-bit aligned storage address. +/// Writes value to the bit position `bitPos` in array `rawData`. void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) { size_t bitWidth = value.getBitWidth(); // If the bitwidth is 1 we just toggle the specific bit. - if (bitWidth == 1) { - auto *rawIntData = reinterpret_cast(rawData); - if (value.isOneValue()) - APInt::tcSetBit(rawIntData, bitPos); - else - APInt::tcClearBit(rawIntData, bitPos); - return; - } - - // If the bit position and width are byte aligned, write the storage directly - // to the data. - if ((bitWidth % 8) == 0 && (bitPos % 8) == 0) { - std::copy_n(reinterpret_cast(value.getRawData()), - bitWidth / 8, rawData + (bitPos / 8)); - return; - } + if (bitWidth == 1) + return setBit(rawData, bitPos, value.isOneValue()); - // Otherwise, convert the raw data into an APInt and insert the value at the - // specified bit position. - size_t totalWords = APInt::getNumWords((bitPos % 64) + bitWidth); - llvm::MutableArrayRef rawIntData( - reinterpret_cast(rawData) + (bitPos / 64), totalWords); - APInt tempStorage(totalWords * 64, rawIntData); - tempStorage.insertBits(value, bitPos % 64); - - // Copy the value back to the raw data. - std::copy_n(tempStorage.getRawData(), rawIntData.size(), rawIntData.data()); + // Otherwise, the bit position is guaranteed to be byte aligned. + assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); + std::copy_n(reinterpret_cast(value.getRawData()), + llvm::divideCeil(bitWidth, CHAR_BIT), + rawData + (bitPos / CHAR_BIT)); } /// Reads the next `bitWidth` bits from the bit position `bitPos` in array -/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address. +/// `rawData`. APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos, size_t bitWidth) { - // Reinterpret the raw data as a uint64_t word array and extract the value - // starting at 'bitPos'. + // Handle a boolean bit position. + if (bitWidth == 1) + return APInt(1, getBit(rawData, bitPos) ? 1 : 0); + + // Otherwise, the bit position must be 8-bit aligned. + assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); APInt result(bitWidth, 0); - const uint64_t *intData = reinterpret_cast(rawData); - APInt::tcExtract(const_cast(result.getRawData()), - result.getNumWords(), intData, bitWidth, bitPos); + std::copy_n(rawData + (bitPos / CHAR_BIT), + llvm::divideCeil(bitWidth, CHAR_BIT), + (char *)(result.getRawData())); return result; } @@ -772,6 +775,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, Type newElementType, llvm::SmallVectorImpl &data) { size_t bitWidth = getDenseElementBitwidth(newElementType); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); ShapedType newArrayType; if (inType.isa()) @@ -783,14 +787,13 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, else assert(newArrayType && "Unhandled tensor type"); - data.resize(APInt::getNumWords(bitWidth * inType.getNumElements()) * - APInt::APINT_WORD_SIZE); + data.resize(storageBitWidth * inType.getNumElements()); uint64_t elementIdx = 0; for (auto value : attr) { auto newInt = mapping(value); assert(newInt.getBitWidth() == bitWidth); - attr.writeBits(data.data(), elementIdx * bitWidth, newInt); + attr.writeBits(data.data(), elementIdx * storageBitWidth, newInt); ++elementIdx; } -- 2.7.4