From 61c3b5df3887ffbe59642af687deb1c2fc34003b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 7 Jun 2019 12:08:36 -0700 Subject: [PATCH] NFC: Cleanup the grouping of DenseElementsAttr 'get' methods, and move the bit write/read functions to static functions in Attributes.cpp. PiperOrigin-RevId: 252094145 --- mlir/include/mlir/IR/Attributes.h | 41 ++++----- mlir/lib/IR/Attributes.cpp | 176 +++++++++++++++++++------------------- 2 files changed, 109 insertions(+), 108 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 08a5f33..77caaba 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -502,6 +502,11 @@ class DenseElementsAttr public: using Base::Base; + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr) { + return attr.getKind() == StandardAttributes::DenseElements; + } + /// Constructs a dense elements attribute from an array of element values. /// Each element attribute value is expected to be an element of 'type'. /// 'type' must be a vector or tensor with static shape. @@ -528,6 +533,13 @@ public: /// Overload of the above 'get' method that is specialized for boolean values. static DenseElementsAttr get(ShapedType type, ArrayRef values); + //===--------------------------------------------------------------------===// + // Value Querying + //===--------------------------------------------------------------------===// + + /// Return the raw storage data held by this attribute. + ArrayRef getRawData() const; + /// Returns the number of elements held by this attribute. size_t size() const; @@ -538,6 +550,10 @@ public: /// Return the held element values as Attributes in 'values'. void getValues(SmallVectorImpl &values) const; + //===--------------------------------------------------------------------===// + // Mutation Utilities + //===--------------------------------------------------------------------===// + /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. @@ -557,21 +573,6 @@ public: mapValues(Type newElementType, llvm::function_ref mapping) const; - ArrayRef getRawData() const; - - /// Writes value to the bit position `bitPos` in array `rawData`. If the - /// bitwidth of `value` is not 1, then `bitPos` must be 8-bit aligned. - static void writeBits(char *rawData, size_t bitPos, APInt value); - - /// Reads the next `bitWidth` bits from the bit position `bitPos` in array - /// `rawData`. If `bitWidth` is not 1, then `bitPos` must be 8-bit aligned. - static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth); - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr) { - return attr.getKind() == StandardAttributes::DenseElements; - } - protected: /// A utility iterator that allows walking over the internal raw APInt values. class RawElementIterator @@ -622,6 +623,11 @@ protected: return RawElementIterator(*this, size()); } + /// Constructs a dense elements attribute from an array of raw APInt values. + /// Each APInt value is expected to have the same bitwidth as the element type + /// of 'type'. 'type' must be a vector or tensor with static shape. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Get or create a new dense elements attribute instance with the given raw /// data buffer. 'type' must be a vector or tensor with static shape. static DenseElementsAttr getRaw(ShapedType type, ArrayRef data); @@ -630,11 +636,6 @@ protected: /// integer or floating-point type. static DenseElementsAttr getRawIntOrFloat(ShapedType type, ArrayRef data, bool isInt); - - /// Constructs a dense elements attribute from an array of raw APInt values. - /// Each APInt value is expected to have the same bitwidth as the element type - /// of 'type'. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr get(ShapedType type, ArrayRef values); }; /// An attribute that represents a reference to a dense integer vector or tensor diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 2e91ac6..a5da323 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -516,6 +516,37 @@ static bool getBit(const char *rawData, size_t bitPos) { return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; } +/// Writes value to the bit position `bitPos` in array `rawData`. +static void writeBits(char *rawData, size_t bitPos, APInt value) { + size_t bitWidth = value.getBitWidth(); + + // If the bitwidth is 1 we just toggle the specific bit. + if (bitWidth == 1) + return setBit(rawData, bitPos, value.isOneValue()); + + // Otherwise, the bit position is guaranteed to be byte aligned. + assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); + std::copy_n(reinterpret_cast(value.getRawData()), + llvm::divideCeil(bitWidth, CHAR_BIT), + rawData + (bitPos / CHAR_BIT)); +} + +/// Reads the next `bitWidth` bits from the bit position `bitPos` in array +/// `rawData`. +static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { + // Handle a boolean bit position. + if (bitWidth == 1) + return APInt(1, getBit(rawData, bitPos) ? 1 : 0); + + // Otherwise, the bit position must be 8-bit aligned. + assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); + APInt result(bitWidth, 0); + std::copy_n(rawData + (bitPos / CHAR_BIT), + llvm::divideCeil(bitWidth, CHAR_BIT), + (char *)(result.getRawData())); + return result; +} + /// Constructs a new iterator. DenseElementsAttr::RawElementIterator::RawElementIterator( DenseElementsAttr attr, size_t index) @@ -532,39 +563,6 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const { // DenseElementsAttr //===----------------------------------------------------------------------===// -DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, - ArrayRef data) { - assert((static_cast(type.getSizeInBits()) <= - data.size() * APInt::APINT_WORD_SIZE) && - "Input data bit size should be larger than that type requires"); - assert((type.isa() || type.isa()) && - "type must be ranked tensor or vector"); - assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), StandardAttributes::DenseElements, type, - data); -} - -/// Overload of the raw 'get' method that asserts that the given type is of -/// integer type. -DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, - ArrayRef data, - bool isInt) { - assert(isInt ? type.getElementType().isa() - : type.getElementType().isa()); - return getRaw(type, data); -} - -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(type.getNumElements() == static_cast(values.size())); - assert(type.getElementType().isInteger(1)); - - std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); - for (int i = 0, e = values.size(); i != e; ++i) - writeBits(buff.data(), i, llvm::APInt(1, values[i])); - return getRaw(type, buff); -} - DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrFloat() && @@ -604,6 +602,62 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, return getRaw(type, data); } +DenseElementsAttr DenseElementsAttr::get(ShapedType type, + ArrayRef values) { + assert(type.getNumElements() == static_cast(values.size())); + assert(type.getElementType().isInteger(1)); + + std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); + for (int i = 0, e = values.size(); i != e; ++i) + writeBits(buff.data(), i, llvm::APInt(1, values[i])); + return getRaw(type, buff); +} + +// Constructs a dense elements attribute from an array of raw APInt values. +// Each APInt value is expected to have the same bitwidth as the element type +// of 'type'. +DenseElementsAttr DenseElementsAttr::get(ShapedType type, + ArrayRef values) { + assert(static_cast(values.size()) == type.getNumElements() && + "expected 'values' to contain the same number of elements as 'type'"); + + size_t bitWidth = getDenseElementBitwidth(type.getElementType()); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); + std::vector elementData(bitWidth * values.size()); + for (unsigned i = 0, e = values.size(); i != e; ++i) { + assert(values[i].getBitWidth() == bitWidth); + writeBits(elementData.data(), i * storageBitWidth, values[i]); + } + return getRaw(type, elementData); +} + +DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, + ArrayRef data) { + assert((static_cast(type.getSizeInBits()) <= + data.size() * APInt::APINT_WORD_SIZE) && + "Input data bit size should be larger than that type requires"); + assert((type.isa() || type.isa()) && + "type must be ranked tensor or vector"); + assert(type.hasStaticShape() && "type must have static shape"); + return Base::get(type.getContext(), StandardAttributes::DenseElements, type, + data); +} + +/// Overload of the 'getRaw' method that asserts that the given type is of +/// integer type. +DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, + ArrayRef data, + bool isInt) { + assert(isInt ? type.getElementType().isa() + : type.getElementType().isa()); + return getRaw(type, data); +} + +/// Return the raw storage data held by this attribute. +ArrayRef DenseElementsAttr::getRawData() const { + return static_cast(impl)->data; +} + /// Returns the number of elements held by this attribute. size_t DenseElementsAttr::size() const { return getType().getNumElements(); } @@ -700,60 +754,6 @@ DenseElementsAttr DenseElementsAttr::mapValues( return cast().mapValues(newElementType, mapping); } -ArrayRef DenseElementsAttr::getRawData() const { - return static_cast(impl)->data; -} - -// Constructs a dense elements attribute from an array of raw APInt values. -// Each APInt value is expected to have the same bitwidth as the element type -// of 'type'. -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(static_cast(values.size()) == type.getNumElements() && - "expected 'values' to contain the same number of elements as 'type'"); - - size_t bitWidth = getDenseElementBitwidth(type.getElementType()); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - std::vector elementData(bitWidth * values.size()); - for (unsigned i = 0, e = values.size(); i != e; ++i) { - assert(values[i].getBitWidth() == bitWidth); - writeBits(elementData.data(), i * storageBitWidth, values[i]); - } - return getRaw(type, elementData); -} - -/// Writes value to the bit position `bitPos` in array `rawData`. -void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) { - size_t bitWidth = value.getBitWidth(); - - // If the bitwidth is 1 we just toggle the specific bit. - if (bitWidth == 1) - return setBit(rawData, bitPos, value.isOneValue()); - - // Otherwise, the bit position is guaranteed to be byte aligned. - assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); - std::copy_n(reinterpret_cast(value.getRawData()), - llvm::divideCeil(bitWidth, CHAR_BIT), - rawData + (bitPos / CHAR_BIT)); -} - -/// Reads the next `bitWidth` bits from the bit position `bitPos` in array -/// `rawData`. -APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos, - size_t bitWidth) { - // Handle a boolean bit position. - if (bitWidth == 1) - return APInt(1, getBit(rawData, bitPos) ? 1 : 0); - - // Otherwise, the bit position must be 8-bit aligned. - assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); - APInt result(bitWidth, 0); - std::copy_n(rawData + (bitPos / CHAR_BIT), - llvm::divideCeil(bitWidth, CHAR_BIT), - (char *)(result.getRawData())); - return result; -} - //===----------------------------------------------------------------------===// // DenseIntElementsAttr //===----------------------------------------------------------------------===// @@ -794,7 +794,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, for (auto value : attr) { auto newInt = mapping(value); assert(newInt.getBitWidth() == bitWidth); - attr.writeBits(data.data(), elementIdx * storageBitWidth, newInt); + writeBits(data.data(), elementIdx * storageBitWidth, newInt); ++elementIdx; } -- 2.7.4