From 0f89ef30b749408b705afd7ba55b9f9be8ac8d8f Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 30 Apr 2019 10:31:29 -0700 Subject: [PATCH] Refactor Attribute uniquing to use StorageUniquer instead of being hard coded in the MLIRContext. This allows for attributes to be uniqued similarly to types. This is the second step towards allowing dialects to define attributes. -- PiperOrigin-RevId: 245974705 --- mlir/include/mlir/IR/Attributes.h | 2 - mlir/include/mlir/IR/MLIRContext.h | 4 + mlir/include/mlir/Support/StorageUniquer.h | 38 ++- mlir/lib/IR/AttributeDetail.h | 358 ++++++++++++++++--- mlir/lib/IR/Attributes.cpp | 198 ++++++++++- mlir/lib/IR/MLIRContext.cpp | 528 +---------------------------- mlir/lib/Support/StorageUniquer.cpp | 26 ++ 7 files changed, 582 insertions(+), 572 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 25b7399..5deb2bb 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -38,7 +38,6 @@ class VectorOrTensorType; namespace detail { struct AttributeStorage; -struct UnitAttributeStorage; struct BoolAttributeStorage; struct IntegerAttributeStorage; struct FloatAttributeStorage; @@ -165,7 +164,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) { class UnitAttr : public Attribute { public: using Attribute::Attribute; - using ImplType = detail::UnitAttributeStorage; static UnitAttr get(MLIRContext *context); diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index ed2640d..4b2343a 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -98,6 +98,10 @@ public: /// This should not be used directly. StorageUniquer &getTypeUniquer(); + /// Returns the storage uniquer used for constructing attribute storage + /// instances. This should not be used directly. + StorageUniquer &getAttributeUniquer(); + private: const std::unique_ptr impl; diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h index 8a0c590..2a9bb4a 100644 --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -55,6 +55,12 @@ struct StorageUniquerImpl; /// that builds a unique instance of the derived storage. The arguments to /// this function are an allocator to store any uniqued data and the key /// type for this storage. +/// +/// - Provide a cleanup method: +/// 'void cleanup()' +/// that is called when erasing a storage instance. This should cleanup any +/// fields of the storage as necessary and not attempt to free the memory +/// of the storage itself. class StorageUniquer { public: StorageUniquer(); @@ -114,7 +120,7 @@ public: /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter /// that can be used to initialize a newly inserted storage instance. This - /// overload is used for derived types that have complex storage or uniquing + /// function is used for derived types that have complex storage or uniquing /// constraints. template Storage *getComplex(std::function initFn, unsigned kind, @@ -146,7 +152,7 @@ public: /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter /// that can be used to initialize a newly inserted storage instance. This - /// overload is used for derived types that use no additional storage or + /// function is used for derived types that use no additional storage or /// uniquing outside of the kind. template Storage *getSimple(std::function initFn, unsigned kind) { @@ -159,6 +165,28 @@ public: return static_cast(getImpl(kind, ctorFn)); } + /// Erases a uniqued instance of 'Storage'. This function is used for derived + /// types that have complex storage or uniquing constraints. + template + void eraseComplex(unsigned kind, Args &&... args) { + // Construct a value of the derived key type. + auto derivedKey = getKey(args...); + + // Create a hash of the kind and the derived key. + unsigned hashValue = getHash(kind, derivedKey); + + // Generate an equality function for the derived storage. + std::function isEqual = + [&derivedKey](const BaseStorage *existing) { + return static_cast(*existing) == derivedKey; + }; + + // Attempt to erase the storage instance. + eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) { + static_cast(storage)->cleanup(); + }); + } + private: /// Implementation for getting/creating an instance of a derived type with /// complex storage. @@ -171,6 +199,12 @@ private: BaseStorage *getImpl(unsigned kind, std::function ctorFn); + /// Implementation for erasing an instance of a derived type with complex + /// storage. + void eraseImpl(unsigned kind, unsigned hashValue, + llvm::function_ref isEqual, + std::function cleanupFn); + /// The internal implementation class. std::unique_ptr impl; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index e1da603..89ac240 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -28,33 +28,79 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Support/StorageUniquer.h" #include "llvm/ADT/APFloat.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { namespace detail { - /// Base storage class appearing in an attribute. -struct AttributeStorage { - AttributeStorage(Attribute::Kind kind, bool isOrContainsFunctionCache = false) - : kind(kind), isOrContainsFunctionCache(isOrContainsFunctionCache) {} - - Attribute::Kind kind : 8; +struct AttributeStorage : public StorageUniquer::BaseStorage { + AttributeStorage(bool isOrContainsFunctionCache = false) + : isOrContainsFunctionCache(isOrContainsFunctionCache) {} /// This field is true if this is, or contains, a function attribute. bool isOrContainsFunctionCache : 1; }; -/// An attribute representing a unit value. -struct UnitAttributeStorage : public AttributeStorage { - UnitAttributeStorage() : AttributeStorage(Attribute::Kind::Unit) {} +// A utility class to get, or create, unique instances of attributes within an +// MLIRContext. This class manages all creation and uniquing of attributes. +class AttributeUniquer { +public: + /// Get an uniqued instance of attribute T. This overload is used for + /// derived attributes that have complex storage or uniquing constraints. + template + static typename std::enable_if< + !std::is_same::value, T>::type + get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) { + return ctx->getAttributeUniquer().getComplex( + /*initFn=*/{}, static_cast(kind), + std::forward(args)...); + } + + /// Get an uniqued instance of attribute T. This overload is used for + /// derived attributes that use the AttributeStorage directly and thus need no + /// additional storage or uniquing. + template + static typename std::enable_if< + std::is_same::value, T>::type + get(MLIRContext *ctx, Attribute::Kind kind) { + return ctx->getAttributeUniquer().getSimple( + /*initFn=*/{}, static_cast(kind)); + } + + /// Erase a uniqued instance of attribute T. This overload is used for + /// derived attributes that have complex storage or uniquing constraints. + template + static typename std::enable_if< + !std::is_same::value>::type + erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) { + return ctx->getAttributeUniquer().eraseComplex( + static_cast(kind), std::forward(args)...); + } }; +using AttributeStorageAllocator = StorageUniquer::StorageAllocator; + /// An attribute representing a boolean value. struct BoolAttributeStorage : public AttributeStorage { - BoolAttributeStorage(Type type, bool value) - : AttributeStorage(Attribute::Kind::Bool), type(type), value(value) {} - const Type type; + using KeyTy = std::pair; + + BoolAttributeStorage(Type type, bool value) : type(type), value(value) {} + + /// We only check equality for and hash with the boolean key parameter. + bool operator==(const KeyTy &key) const { return key.second == value; } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_value(key.second); + } + + static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + BoolAttributeStorage(IntegerType::get(1, key.first), key.second); + } + + Type type; bool value; }; @@ -62,14 +108,37 @@ struct BoolAttributeStorage : public AttributeStorage { struct IntegerAttributeStorage final : public AttributeStorage, public llvm::TrailingObjects { + using KeyTy = std::pair; + IntegerAttributeStorage(Type type, size_t numObjects) - : AttributeStorage(Attribute::Kind::Integer), type(type), - numObjects(numObjects) { + : type(type), numObjects(numObjects) { assert((type.isIndex() || type.isa()) && "invalid type"); } - const Type type; - size_t numObjects; + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { + return key == KeyTy(type, getValue()); + } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(key.first, llvm::hash_value(key.second)); + } + + /// Construct a new storage instance. + static IntegerAttributeStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + Type type; + APInt value; + std::tie(type, value) = key; + + auto elements = ArrayRef(value.getRawData(), value.getNumWords()); + auto size = + IntegerAttributeStorage::totalSizeToAlloc(elements.size()); + auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage)); + auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), + result->getTrailingObjects()); + return result; + } /// Returns an APInt representing the stored value. APInt getValue() const { @@ -78,19 +147,47 @@ struct IntegerAttributeStorage final return APInt(type.getIntOrFloatBitWidth(), {getTrailingObjects(), numObjects}); } + + Type type; + size_t numObjects; }; /// An attribute representing a floating point value. struct FloatAttributeStorage final : public AttributeStorage, public llvm::TrailingObjects { + using KeyTy = std::pair; + FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type, size_t numObjects) - : AttributeStorage(Attribute::Kind::Float), semantics(semantics), - type(type.cast()), numObjects(numObjects) {} - const llvm::fltSemantics &semantics; - const FloatType type; - size_t numObjects; + : semantics(semantics), type(type.cast()), + numObjects(numObjects) {} + + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { + return key.first == type && key.second.bitwiseIsEqual(getValue()); + } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(key.first, llvm::hash_value(key.second)); + } + + /// Construct a new storage instance. + static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + const auto &apint = key.second.bitcastToAPInt(); + + // Here one word's bitwidth equals to that of uint64_t. + auto elements = ArrayRef(apint.getRawData(), apint.getNumWords()); + + auto byteSize = + FloatAttributeStorage::totalSizeToAlloc(elements.size()); + auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage)); + auto result = ::new (rawMem) FloatAttributeStorage( + key.second.getSemantics(), key.first, elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), + result->getTrailingObjects()); + return result; + } /// Returns an APFloat representing the stored value. APFloat getValue() const { @@ -98,95 +195,266 @@ struct FloatAttributeStorage final {getTrailingObjects(), numObjects}); return APFloat(semantics, val); } + + const llvm::fltSemantics &semantics; + FloatType type; + size_t numObjects; }; /// An attribute representing a string value. struct StringAttributeStorage : public AttributeStorage { - StringAttributeStorage(StringRef value) - : AttributeStorage(Attribute::Kind::String), value(value) {} + using KeyTy = StringRef; + + StringAttributeStorage(StringRef value) : value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static StringAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + StringAttributeStorage(allocator.copyInto(key)); + } + StringRef value; }; /// An attribute representing an array of other attributes. struct ArrayAttributeStorage : public AttributeStorage { + using KeyTy = ArrayRef; + ArrayAttributeStorage(bool hasFunctionAttr, ArrayRef value) - : AttributeStorage(Attribute::Kind::Array, hasFunctionAttr), - value(value) {} + : AttributeStorage(hasFunctionAttr), value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + // Check to see if any of the elements have a function attr. + bool hasFunctionAttr = llvm::any_of( + key, [](Attribute elt) { return elt.isOrContainsFunction(); }); + + // Initialize the memory using placement new. + return new (allocator.allocate()) + ArrayAttributeStorage(hasFunctionAttr, allocator.copyInto(key)); + } + ArrayRef value; }; // An attribute representing a reference to an affine map. struct AffineMapAttributeStorage : public AttributeStorage { - AffineMapAttributeStorage(AffineMap value) - : AttributeStorage(Attribute::Kind::AffineMap), value(value) {} + using KeyTy = AffineMap; + + AffineMapAttributeStorage(AffineMap value) : value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static AffineMapAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate()) + AffineMapAttributeStorage(key); + } + AffineMap value; }; // An attribute representing a reference to an integer set. struct IntegerSetAttributeStorage : public AttributeStorage { - IntegerSetAttributeStorage(IntegerSet value) - : AttributeStorage(Attribute::Kind::IntegerSet), value(value) {} + using KeyTy = IntegerSet; + + IntegerSetAttributeStorage(IntegerSet value) : value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static IntegerSetAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate()) + IntegerSetAttributeStorage(key); + } + IntegerSet value; }; /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { - TypeAttributeStorage(Type value) - : AttributeStorage(Attribute::Kind::Type), value(value) {} + using KeyTy = Type; + + TypeAttributeStorage(Type value) : value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator, + KeyTy key) { + return new (allocator.allocate()) + TypeAttributeStorage(key); + } + Type value; }; /// An attribute representing a reference to a function. struct FunctionAttributeStorage : public AttributeStorage { + using KeyTy = Function *; + FunctionAttributeStorage(Function *value) - : AttributeStorage(Attribute::Kind::Function, - /*isOrContainsFunctionCache=*/true), - value(value) {} + : AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { return key == value; } + + /// Construct a new storage instance. + static FunctionAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate()) + FunctionAttributeStorage(key); + } + + /// Storage cleanup function. + void cleanup() { + // Null out the function reference in the attribute to avoid dangling + // pointers. + value = nullptr; + } + Function *value; }; /// A base attribute representing a reference to a vector or tensor constant. struct ElementsAttributeStorage : public AttributeStorage { - ElementsAttributeStorage(Attribute::Kind kind, VectorOrTensorType type) - : AttributeStorage(kind), type(type) {} + ElementsAttributeStorage(VectorOrTensorType type) : type(type) {} VectorOrTensorType type; }; /// An attribute representing a reference to a vector or tensor constant, /// inwhich all elements have the same value. struct SplatElementsAttributeStorage : public ElementsAttributeStorage { + using KeyTy = std::pair; + SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt) - : ElementsAttributeStorage(Attribute::Kind::SplatElements, type), - elt(elt) {} + : ElementsAttributeStorage(type), elt(elt) {} + + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { + return key == std::make_pair(type, elt); + } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(key.first, key.second); + } + + /// Construct a new storage instance. + static SplatElementsAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate()) + SplatElementsAttributeStorage(key.first, key.second); + } + Attribute elt; }; /// An attribute representing a reference to a dense vector or tensor object. struct DenseElementsAttributeStorage : public ElementsAttributeStorage { - DenseElementsAttributeStorage(Attribute::Kind kind, VectorOrTensorType type, - ArrayRef data) - : ElementsAttributeStorage(kind, type), data(data) {} + using KeyTy = std::pair>; + + DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef data) + : ElementsAttributeStorage(ty), data(data) {} + + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(key.first, key.second); + } + + /// Construct a new storage instance. + static DenseElementsAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + // If the data buffer is non-empty, we copy it into the allocator. + ArrayRef data = key.second; + if (!data.empty()) { + // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so + // the `readBits` will not fail when it accesses multiples of + // APINT_WORD_SIZE each time. + size_t sizeToAllocate = + llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE); + auto *rawCopy = (char *)allocator.allocate(sizeToAllocate, 64); + std::uninitialized_copy(data.begin(), data.end(), rawCopy); + data = {rawCopy, data.size()}; + } + return new (allocator.allocate()) + DenseElementsAttributeStorage(key.first, data); + } + ArrayRef data; }; /// An attribute representing a reference to a tensor constant with opaque /// content. struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage { + using KeyTy = std::tuple; + OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect, StringRef bytes) - : ElementsAttributeStorage(Attribute::Kind::OpaqueElements, type), - dialect(dialect), bytes(bytes) {} + : ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {} + + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { + return key == std::make_tuple(type, dialect, bytes); + } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(std::get<0>(key), std::get<1>(key), + std::get<2>(key)); + } + + /// Construct a new storage instance. + static OpaqueElementsAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + // TODO(b/131468830): Provide a way to avoid copying content of large opaque + // tensors This will likely require a new reference attribute kind. + return new (allocator.allocate()) + OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key), + allocator.copyInto(std::get<2>(key))); + } + Dialect *dialect; StringRef bytes; }; /// An attribute representing a reference to a sparse vector or tensor object. struct SparseElementsAttributeStorage : public ElementsAttributeStorage { + using KeyTy = + std::tuple; + SparseElementsAttributeStorage(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values) - : ElementsAttributeStorage(Attribute::Kind::SparseElements, type), - indices(indices), values(values) {} + : ElementsAttributeStorage(type), indices(indices), values(values) {} + + /// Key equality and hash functions. + bool operator==(const KeyTy &key) const { + return key == std::make_tuple(type, indices, values); + } + static unsigned hashKey(const KeyTy &key) { + return llvm::hash_combine(std::get<0>(key), std::get<1>(key), + std::get<2>(key)); + } + + /// Construct a new storage instance. + static SparseElementsAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + return new (allocator.allocate()) + SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key), + std::get<2>(key)); + } + DenseIntElementsAttr indices; DenseElementsAttr values; }; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 8e1214d..4df46df 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -26,7 +26,9 @@ using namespace mlir; using namespace mlir::detail; -Attribute::Kind Attribute::getKind() const { return attr->kind; } +Attribute::Kind Attribute::getKind() const { + return static_cast(attr->getKind()); +} bool Attribute::isOrContainsFunction() const { return attr->isOrContainsFunctionCache; @@ -66,6 +68,14 @@ Attribute Attribute::remapFunctionAttrs( } //===----------------------------------------------------------------------===// +// UnitAttr +//===----------------------------------------------------------------------===// + +UnitAttr UnitAttr::get(MLIRContext *context) { + return AttributeUniquer::get(context, Attribute::Kind::Unit); +} + +//===----------------------------------------------------------------------===// // NumericAttr //===----------------------------------------------------------------------===// @@ -91,6 +101,12 @@ bool NumericAttr::kindof(Kind kind) { // BoolAttr //===----------------------------------------------------------------------===// +BoolAttr BoolAttr::get(bool value, MLIRContext *context) { + // Note: The context is also used within the BoolAttrStorage. + return AttributeUniquer::get(context, Attribute::Kind::Bool, + context, value); +} + bool BoolAttr::getValue() const { return static_cast(attr)->value; } Type BoolAttr::getType() const { return static_cast(attr)->type; } @@ -99,6 +115,20 @@ Type BoolAttr::getType() const { return static_cast(attr)->type; } // IntegerAttr //===----------------------------------------------------------------------===// +IntegerAttr IntegerAttr::get(Type type, const APInt &value) { + return AttributeUniquer::get( + type.getContext(), Attribute::Kind::Integer, type, value); +} + +IntegerAttr IntegerAttr::get(Type type, int64_t value) { + // This uses 64 bit APInts by default for index type. + if (type.isIndex()) + return get(type, APInt(64, value)); + + auto intType = type.cast(); + return get(type, APInt(intType.getWidth(), value)); +} + APInt IntegerAttr::getValue() const { return static_cast(attr)->getValue(); } @@ -113,6 +143,44 @@ Type IntegerAttr::getType() const { // FloatAttr //===----------------------------------------------------------------------===// +FloatAttr FloatAttr::get(Type type, const APFloat &value) { + assert(&type.cast().getFloatSemantics() == &value.getSemantics() && + "FloatAttr type doesn't match the type implied by its value"); + return AttributeUniquer::get(type.getContext(), + Attribute::Kind::Float, type, value); +} + +static FloatAttr getFloatAttr(Type type, double value, + llvm::Optional loc) { + if (!type.isa()) { + if (loc) + type.getContext()->emitError(*loc, "expected floating point type"); + return nullptr; + } + + // Treat BF16 as double because it is not supported in LLVM's APFloat. + // TODO(b/121118307): add BF16 support to APFloat? + if (type.isBF16() || type.isF64()) + return FloatAttr::get(type, APFloat(value)); + + // This handles, e.g., F16 because there is no APFloat constructor for it. + bool unused; + APFloat val(value); + val.convert(type.cast().getFloatSemantics(), + APFloat::rmNearestTiesToEven, &unused); + return FloatAttr::get(type, val); +} + +FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { + return getFloatAttr(type, value, loc); +} + +FloatAttr FloatAttr::get(Type type, double value) { + auto res = getFloatAttr(type, value, /*loc=*/llvm::None); + assert(res && "failed to construct float attribute"); + return res; +} + APFloat FloatAttr::getValue() const { return static_cast(attr)->getValue(); } @@ -134,6 +202,11 @@ double FloatAttr::getValueAsDouble() const { // StringAttr //===----------------------------------------------------------------------===// +StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { + return AttributeUniquer::get(context, Attribute::Kind::String, + bytes); +} + StringRef StringAttr::getValue() const { return static_cast(attr)->value; } @@ -142,6 +215,11 @@ StringRef StringAttr::getValue() const { // ArrayAttr //===----------------------------------------------------------------------===// +ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { + return AttributeUniquer::get(context, Attribute::Kind::Array, + value); +} + ArrayRef ArrayAttr::getValue() const { return static_cast(attr)->value; } @@ -150,6 +228,11 @@ ArrayRef ArrayAttr::getValue() const { // AffineMapAttr //===----------------------------------------------------------------------===// +AffineMapAttr AffineMapAttr::get(AffineMap value) { + return AttributeUniquer::get( + value.getResult(0).getContext(), Attribute::Kind::AffineMap, value); +} + AffineMap AffineMapAttr::getValue() const { return static_cast(attr)->value; } @@ -158,6 +241,11 @@ AffineMap AffineMapAttr::getValue() const { // IntegerSetAttr //===----------------------------------------------------------------------===// +IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { + return AttributeUniquer::get( + value.getConstraint(0).getContext(), Attribute::Kind::IntegerSet, value); +} + IntegerSet IntegerSetAttr::getValue() const { return static_cast(attr)->value; } @@ -166,12 +254,29 @@ IntegerSet IntegerSetAttr::getValue() const { // TypeAttr //===----------------------------------------------------------------------===// +TypeAttr TypeAttr::get(Type value, MLIRContext *context) { + return AttributeUniquer::get(context, Attribute::Kind::Type, value); +} + Type TypeAttr::getValue() const { return static_cast(attr)->value; } //===----------------------------------------------------------------------===// // FunctionAttr //===----------------------------------------------------------------------===// +FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) { + assert(value && "Cannot get FunctionAttr for a null function"); + return AttributeUniquer::get(context, Attribute::Kind::Function, + value); +} + +/// This function is used by the internals of the Function class to null out +/// attributes referring to functions that are about to be deleted. +void FunctionAttr::dropFunctionReference(Function *value) { + AttributeUniquer::erase(value->getContext(), + Attribute::Kind::Function, value); +} + Function *FunctionAttr::getValue() const { return static_cast(attr)->value; } @@ -208,6 +313,14 @@ Attribute ElementsAttr::getValue(ArrayRef index) const { // SplatElementsAttr //===----------------------------------------------------------------------===// +SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, + Attribute elt) { + assert(elt.cast().getType() == type.getElementType() && + "value should be of the given type"); + return AttributeUniquer::get( + type.getContext(), Attribute::Kind::SplatElements, type, elt); +} + Attribute SplatElementsAttr::getValue() const { return static_cast(attr)->elt; } @@ -237,6 +350,70 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const { // DenseElementsAttr //===----------------------------------------------------------------------===// +DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, + ArrayRef data) { + assert((type.getSizeInBits() <= data.size() * APInt::APINT_WORD_SIZE) && + "Input data bit size should be larger than that type requires"); + + Attribute::Kind kind; + switch (type.getElementType().getKind()) { + case StandardTypes::BF16: + case StandardTypes::F16: + case StandardTypes::F32: + case StandardTypes::F64: + kind = Attribute::Kind::DenseFPElements; + break; + case StandardTypes::Integer: + kind = Attribute::Kind::DenseIntElements; + break; + default: + llvm_unreachable("unexpected element type"); + } + return AttributeUniquer::get(type.getContext(), kind, type, + data); +} + +DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, + ArrayRef values) { + assert(type.getElementType().isIntOrFloat() && + "expected int or float element type"); + assert(values.size() == type.getNumElements() && + "expected 'values' to contain the same number of elements as 'type'"); + + // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored + // with double semantics. + auto eltType = type.getElementType(); + size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); + + // Compress the attribute values into a character buffer. + SmallVector data(APInt::getNumWords(bitWidth * values.size()) * + APInt::APINT_WORD_SIZE); + APInt intVal; + for (unsigned i = 0, e = values.size(); i < e; ++i) { + switch (eltType.getKind()) { + case StandardTypes::BF16: + case StandardTypes::F16: + case StandardTypes::F32: + case StandardTypes::F64: + assert(eltType == values[i].cast().getType() && + "expected attribute value to have element type"); + intVal = values[i].cast().getValue().bitcastToAPInt(); + break; + case StandardTypes::Integer: + assert(eltType == values[i].cast().getType() && + "expected attribute value to have element type"); + intVal = values[i].cast().getValue(); + break; + default: + llvm_unreachable("unexpected element type"); + } + assert(intVal.getBitWidth() == bitWidth && + "expected value to have same bitwidth as element type"); + writeBits(data.data(), i * bitWidth, intVal); + } + return get(type, data); +} + /// Returns the number of elements held by this attribute. size_t DenseElementsAttr::size() const { return getType().getNumElements(); } @@ -457,6 +634,15 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const { // OpaqueElementsAttr //===----------------------------------------------------------------------===// +OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, + VectorOrTensorType type, + StringRef bytes) { + assert(TensorType::isValidElementType(type.getElementType()) && + "Input element type should be a valid tensor element type"); + return AttributeUniquer::get( + type.getContext(), Attribute::Kind::OpaqueElements, type, dialect, bytes); +} + StringRef OpaqueElementsAttr::getValue() const { return static_cast(attr)->bytes; } @@ -483,6 +669,16 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) { // SparseElementsAttr //===----------------------------------------------------------------------===// +SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, + DenseIntElementsAttr indices, + DenseElementsAttr values) { + assert(indices.getType().getElementType().isInteger(64) && + "expected sparse indices to be 64-bit integer values"); + return AttributeUniquer::get( + type.getContext(), Attribute::Kind::SparseElements, type, indices, + values); +} + DenseIntElementsAttr SparseElementsAttr::getIndices() const { return static_cast(attr)->indices; } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index cad3a28..9760599 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -195,70 +195,6 @@ struct IntegerSetKeyInfo : DenseMapInfo { } }; -struct FloatAttrKeyInfo : DenseMapInfo { - // Float attributes are uniqued based on wrapped APFloat. - using KeyTy = std::pair; - using DenseMapInfo::isEqual; - - static unsigned getHashValue(FloatAttributeStorage *key) { - return getHashValue(KeyTy(key->type, key->getValue())); - } - - static unsigned getHashValue(KeyTy key) { - return hash_combine(key.first, llvm::hash_value(key.second)); - } - - static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) { - if (rhs == getEmptyKey() || rhs == getTombstoneKey()) - return false; - return lhs.first == rhs->type && lhs.second.bitwiseIsEqual(rhs->getValue()); - } -}; - -struct IntegerAttrKeyInfo : DenseMapInfo { - // Integer attributes are uniqued based on wrapped APInt. - using KeyTy = std::pair; - using DenseMapInfo::isEqual; - - static unsigned getHashValue(IntegerAttributeStorage *key) { - return getHashValue(KeyTy(key->type, key->getValue())); - } - - static unsigned getHashValue(KeyTy key) { - return hash_combine(key.first, llvm::hash_value(key.second)); - } - - static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) { - if (rhs == getEmptyKey() || rhs == getTombstoneKey()) - return false; - assert((lhs.first.isIndex() || (lhs.first.isa() && - lhs.first.cast().getWidth() == - lhs.second.getBitWidth())) && - "mismatching integer type and value bitwidth"); - return lhs.first == rhs->type && lhs.second == rhs->getValue(); - } -}; - -struct ArrayAttrKeyInfo : DenseMapInfo { - // Array attributes are uniqued based on their elements. - using KeyTy = ArrayRef; - using DenseMapInfo::isEqual; - - static unsigned getHashValue(ArrayAttributeStorage *key) { - return getHashValue(KeyTy(key->value)); - } - - static unsigned getHashValue(KeyTy key) { - return hash_combine_range(key.begin(), key.end()); - } - - static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) { - if (rhs == getEmptyKey() || rhs == getTombstoneKey()) - return false; - return lhs == rhs->value; - } -}; - struct AttributeListKeyInfo : DenseMapInfo { // Array attributes are uniqued based on their elements. using KeyTy = ArrayRef; @@ -279,51 +215,6 @@ struct AttributeListKeyInfo : DenseMapInfo { } }; -struct DenseElementsAttrInfo : DenseMapInfo { - using KeyTy = std::pair>; - using DenseMapInfo::isEqual; - - static unsigned getHashValue(DenseElementsAttributeStorage *key) { - return getHashValue(KeyTy(key->type, key->data)); - } - - static unsigned getHashValue(KeyTy key) { - return hash_combine( - key.first, hash_combine_range(key.second.begin(), key.second.end())); - } - - static bool isEqual(const KeyTy &lhs, - const DenseElementsAttributeStorage *rhs) { - if (rhs == getEmptyKey() || rhs == getTombstoneKey()) - return false; - return lhs == std::make_pair(rhs->type, rhs->data); - } -}; - -struct OpaqueElementsAttrInfo : DenseMapInfo { - // Opaque element attributes are uniqued based on their dialect, type and - // value. - using KeyTy = std::tuple; - using DenseMapInfo::isEqual; - - static unsigned getHashValue(OpaqueElementsAttributeStorage *key) { - return getHashValue(KeyTy(key->dialect, key->type, key->bytes)); - } - - static unsigned getHashValue(KeyTy key) { - auto bytes = std::get<2>(key); - return hash_combine(std::get<0>(key), std::get<1>(key), - hash_combine_range(bytes.begin(), bytes.end())); - } - - static bool isEqual(const KeyTy &lhs, - const OpaqueElementsAttributeStorage *rhs) { - if (rhs == getEmptyKey() || rhs == getTombstoneKey()) - return false; - return lhs == std::make_tuple(rhs->dialect, rhs->type, rhs->bytes); - } -}; - struct CallSiteLocationKeyInfo : DenseMapInfo { // Call locations are uniqued based on their held concret location // and the caller location. @@ -492,36 +383,15 @@ public: //===--------------------------------------------------------------------===// // Attribute uniquing //===--------------------------------------------------------------------===// + StorageUniquer attributeUniquer; - // Attribute allocator and mutex for thread safety. + // Attribute list allocator and mutex for thread safety. llvm::BumpPtrAllocator attributeAllocator; llvm::sys::SmartRWMutex attributeMutex; - UnitAttributeStorage unitAttr; - BoolAttributeStorage *boolAttrs[2] = {nullptr}; - DenseSet integerAttrs; - DenseSet floatAttrs; - llvm::StringMap stringAttrs; - using ArrayAttrSet = DenseSet; - ArrayAttrSet arrayAttrs; - DenseMap affineMapAttrs; - DenseMap integerSetAttrs; - DenseMap typeAttrs; using AttributeListSet = DenseSet; AttributeListSet attributeLists; - DenseMap functionAttrs; - DenseMap, SplatElementsAttributeStorage *> - splatElementsAttrs; - using DenseElementsAttrSet = - DenseSet; - DenseElementsAttrSet denseElementsAttrs; - using OpaqueElementsAttrSet = - DenseSet; - OpaqueElementsAttrSet opaqueElementsAttrs; - DenseMap, - SparseElementsAttributeStorage *> - sparseElementsAttrs; public: MLIRContextImpl() @@ -985,234 +855,10 @@ const Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx, // Attribute uniquing //===----------------------------------------------------------------------===// -UnitAttr UnitAttr::get(MLIRContext *context) { - return &context->getImpl().unitAttr; -} - -BoolAttr BoolAttr::get(bool value, MLIRContext *context) { - auto &impl = context->getImpl(); - - { // Check for an existing instance in read-only mode. - llvm::sys::SmartScopedReader attributeLock(impl.attributeMutex); - if (auto *result = impl.boolAttrs[value]) - return result; - } - - // Aquire the mutex in write mode so that we can safely construct the new - // instance. - llvm::sys::SmartScopedWriter attributeLock(impl.attributeMutex); - - // Check for an existing instance again here, because another writer thread - // may have already created one. - auto *&result = impl.boolAttrs[value]; - if (result) - return result; - - result = impl.attributeAllocator.Allocate(); - new (result) BoolAttributeStorage(IntegerType::get(1, context), value); - return result; -} - -IntegerAttr IntegerAttr::get(Type type, const APInt &value) { - auto &impl = type.getContext()->getImpl(); - IntegerAttrKeyInfo::KeyTy key({type, value}); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.integerAttrs, key, impl.attributeMutex, [&] { - auto elements = ArrayRef(value.getRawData(), value.getNumWords()); - - auto byteSize = - IntegerAttributeStorage::totalSizeToAlloc(elements.size()); - auto rawMem = impl.attributeAllocator.Allocate( - byteSize, alignof(IntegerAttributeStorage)); - auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), - result->getTrailingObjects()); - return result; - }); -} - -IntegerAttr IntegerAttr::get(Type type, int64_t value) { - // This uses 64 bit APInts by default for index type. - if (type.isIndex()) - return get(type, APInt(64, value)); - - auto intType = type.dyn_cast(); - assert(intType && "expected an integer type for an integer attribute"); - return get(type, APInt(intType.getWidth(), value)); -} - -static FloatAttr getFloatAttr(Type type, double value, - llvm::Optional loc) { - if (!type.isa()) { - if (loc) - type.getContext()->emitError(*loc, "expected floating point type"); - return nullptr; - } - - // Treat BF16 as double because it is not supported in LLVM's APFloat. - // TODO(jpienaar): add BF16 support to APFloat? - if (type.isBF16() || type.isF64()) - return FloatAttr::get(type, APFloat(value)); - - // This handles, e.g., F16 because there is no APFloat constructor for it. - bool unused; - APFloat val(value); - val.convert(type.cast().getFloatSemantics(), - APFloat::rmNearestTiesToEven, &unused); - return FloatAttr::get(type, val); -} - -FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { - return getFloatAttr(type, value, loc); -} - -FloatAttr FloatAttr::get(Type type, double value) { - auto res = getFloatAttr(type, value, /*loc=*/llvm::None); - assert(res && "failed to construct float attribute"); - return res; -} - -FloatAttr FloatAttr::get(Type type, const APFloat &value) { - auto fltType = type.cast(); - assert(&fltType.getFloatSemantics() == &value.getSemantics() && - "FloatAttr type doesn't match the type implied by its value"); - (void)fltType; - auto &impl = type.getContext()->getImpl(); - FloatAttrKeyInfo::KeyTy key({type, value}); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.floatAttrs, key, impl.attributeMutex, [&] { - const auto &apint = value.bitcastToAPInt(); - // Here one word's bitwidth equals to that of uint64_t. - auto elements = ArrayRef(apint.getRawData(), apint.getNumWords()); - - auto byteSize = - FloatAttributeStorage::totalSizeToAlloc(elements.size()); - auto rawMem = impl.attributeAllocator.Allocate( - byteSize, alignof(FloatAttributeStorage)); - auto result = ::new (rawMem) - FloatAttributeStorage(value.getSemantics(), type, elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), - result->getTrailingObjects()); - return result; - }); -} - -StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { - auto &impl = context->getImpl(); - - { // Check for an existing instance in read-only mode. - llvm::sys::SmartScopedReader attributeLock(impl.attributeMutex); - auto it = impl.stringAttrs.find(bytes); - if (it != impl.stringAttrs.end()) - return it->second; - } - - // Aquire the mutex in write mode so that we can safely construct the new - // instance. - llvm::sys::SmartScopedWriter attributeLock(impl.attributeMutex); - - // Check for an existing instance again here, because another writer thread - // may have already created one. - auto it = impl.stringAttrs.insert({bytes, nullptr}).first; - if (it->second) - return it->second; - - auto result = new (impl.attributeAllocator.Allocate()) - StringAttributeStorage(it->first()); - return it->second = result; -} - -ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { - auto &impl = context->getImpl(); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.arrayAttrs, value, impl.attributeMutex, [&] { - auto *result = impl.attributeAllocator.Allocate(); - - // Copy the elements into the bump pointer. - value = copyArrayRefInto(impl.attributeAllocator, value); - - // Check to see if any of the elements have a function attr. - bool hasFunctionAttr = false; - for (auto elt : value) - if (elt.isOrContainsFunction()) { - hasFunctionAttr = true; - break; - } - - // Initialize the memory using placement new. - return new (result) ArrayAttributeStorage(hasFunctionAttr, value); - }); -} - -AffineMapAttr AffineMapAttr::get(AffineMap value) { - auto *context = value.getResult(0).getContext(); - auto &impl = context->getImpl(); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.affineMapAttrs, value, impl.attributeMutex, [&] { - auto result = impl.attributeAllocator.Allocate(); - return new (result) AffineMapAttributeStorage(value); - }); -} - -IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { - auto *context = value.getConstraint(0).getContext(); - auto &impl = context->getImpl(); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.integerSetAttrs, value, impl.attributeMutex, [&] { - auto result = - impl.attributeAllocator.Allocate(); - return new (result) IntegerSetAttributeStorage(value); - }); -} - -TypeAttr TypeAttr::get(Type type, MLIRContext *context) { - auto &impl = context->getImpl(); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.typeAttrs, type, impl.attributeMutex, [&] { - auto result = impl.attributeAllocator.Allocate(); - return new (result) TypeAttributeStorage(type); - }); -} - -FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) { - assert(value && "Cannot get FunctionAttr for a null function"); - auto &impl = context->getImpl(); - - // Safely get or create an attribute instance. - return safeGetOrCreate(impl.functionAttrs, value, impl.attributeMutex, [&] { - auto result = impl.attributeAllocator.Allocate(); - return new (result) FunctionAttributeStorage(value); - }); -} - -/// This function is used by the internals of the Function class to null out -/// attributes referring to functions that are about to be deleted. -void FunctionAttr::dropFunctionReference(Function *value) { - auto &impl = value->getContext()->getImpl(); - - // Aquire the mutex in write mode so that we can safely remove the attribute - // if it exists. - llvm::sys::SmartScopedWriter attributeLock(impl.attributeMutex); - - // Check to see if there was an attribute referring to this function. - auto &functionAttrs = impl.functionAttrs; - - // If not, then we're done. - auto it = functionAttrs.find(value); - if (it == functionAttrs.end()) - return; - - // If so, null out the function reference in the attribute (to avoid dangling - // pointers) and remove the entry from the map so the map doesn't contain - // dangling keys. - it->second->value = nullptr; - functionAttrs.erase(it); +/// Returns the storage uniquer used for constructing attribute storage +/// instances. This should not be used directly. +StorageUniquer &MLIRContext::getAttributeUniquer() { + return getImpl().attributeUniquer; } /// Perform a three-way comparison between the names of the specified @@ -1281,168 +927,6 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef attrs, }); } -// Returns false if the given `attr` is not of the given `type`. -// Note: This function is only intended to be used for assertion. So it's -// possibly allowing invalid cases that are unimplemented. -static bool attrIsOfType(Attribute attr, Type type) { - if (auto floatAttr = attr.dyn_cast()) - return floatAttr.getType() == type; - if (auto intAttr = attr.dyn_cast()) - return intAttr.getType() == type; - if (auto elementsAttr = attr.dyn_cast()) - return elementsAttr.getType() == type; - // TODO: check the other cases - return true; -} - -SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, - Attribute elt) { - auto attr = elt.dyn_cast(); - assert(attr && "expected numeric value"); - assert(attr.getType() == type.getElementType() && - "value should be of the given type"); - (void)attr; - - auto &impl = type.getContext()->getImpl(); - - // Safely get or create an attribute instance. - std::pair key(type, elt); - return safeGetOrCreate( - impl.splatElementsAttrs, key, impl.attributeMutex, [&] { - auto result = - impl.attributeAllocator.Allocate(); - return new (result) SplatElementsAttributeStorage(type, elt); - }); -} - -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, - ArrayRef data) { - auto bitsRequired = type.getSizeInBits(); - (void)bitsRequired; - assert((bitsRequired <= data.size() * APInt::APINT_WORD_SIZE) && - "Input data bit size should be larger than that type requires"); - - auto &impl = type.getContext()->getImpl(); - DenseElementsAttrInfo::KeyTy key({type, data}); - - // Safely get or create an attribute instance. - return safeGetOrCreate( - impl.denseElementsAttrs, key, impl.attributeMutex, [&] { - Attribute::Kind kind; - switch (type.getElementType().getKind()) { - case StandardTypes::BF16: - case StandardTypes::F16: - case StandardTypes::F32: - case StandardTypes::F64: - kind = Attribute::Kind::DenseFPElements; - break; - case StandardTypes::Integer: - kind = Attribute::Kind::DenseIntElements; - break; - default: - llvm_unreachable("unexpected element type"); - } - - // If the data buffer is non-empty, we copy it into the context. - ArrayRef copy; - if (!data.empty()) { - // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so - // the `readBits` will not fail when it accesses multiples of - // APINT_WORD_SIZE each time. - size_t sizeToAllocate = - llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE); - auto *rawCopy = - (char *)impl.attributeAllocator.Allocate(sizeToAllocate, 64); - std::uninitialized_copy(data.begin(), data.end(), rawCopy); - copy = {rawCopy, data.size()}; - } - auto *result = - impl.attributeAllocator.Allocate(); - return new (result) DenseElementsAttributeStorage(kind, type, copy); - }); -} - -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, - ArrayRef values) { - assert(type.getElementType().isIntOrFloat() && - "expected int or float element type"); - assert(values.size() == type.getNumElements() && - "expected 'values' to contain the same number of elements as 'type'"); - - // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored - // with double semantics. - auto eltType = type.getElementType(); - size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); - - // Compress the attribute values into a character buffer. - SmallVector data(APInt::getNumWords(bitWidth * values.size()) * - APInt::APINT_WORD_SIZE); - APInt intVal; - for (unsigned i = 0, e = values.size(); i < e; ++i) { - switch (eltType.getKind()) { - case StandardTypes::BF16: - case StandardTypes::F16: - case StandardTypes::F32: - case StandardTypes::F64: - assert(eltType == values[i].cast().getType() && - "expected attribute value to have element type"); - intVal = values[i].cast().getValue().bitcastToAPInt(); - break; - case StandardTypes::Integer: - assert(eltType == values[i].cast().getType() && - "expected attribute value to have element type"); - intVal = values[i].cast().getValue(); - break; - default: - llvm_unreachable("unexpected element type"); - } - assert(intVal.getBitWidth() == bitWidth && - "expected value to have same bitwidth as element type"); - writeBits(data.data(), i * bitWidth, intVal); - } - return get(type, data); -} - -OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, - VectorOrTensorType type, - StringRef bytes) { - assert(TensorType::isValidElementType(type.getElementType()) && - "Input element type should be a valid tensor element type"); - - auto &impl = type.getContext()->getImpl(); - OpaqueElementsAttrInfo::KeyTy key(dialect, type, bytes); - - return safeGetOrCreate( - impl.opaqueElementsAttrs, key, impl.attributeMutex, [&] { - auto *result = - impl.attributeAllocator.Allocate(); - - // TODO: Provide a way to avoid copying content of large opaque tensors - // This will likely require a new reference attribute kind. - bytes = bytes.copy(impl.attributeAllocator); - return new (result) - OpaqueElementsAttributeStorage(type, dialect, bytes); - }); -} - -SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, - DenseIntElementsAttr indices, - DenseElementsAttr values) { - assert(indices.getType().getElementType().isInteger(64) && - "expected sparse indices to be 64-bit integer values"); - - auto &impl = type.getContext()->getImpl(); - auto key = std::make_tuple(type, indices, values); - - // Safely get or create an attribute instance. - return safeGetOrCreate( - impl.sparseElementsAttrs, key, impl.attributeMutex, [&] { - return new ( - impl.attributeAllocator.Allocate()) - SparseElementsAttributeStorage(type, indices, values); - }); -} - //===----------------------------------------------------------------------===// // AffineMap and AffineExpr uniquing //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp index 14d8f39..6fd55e7 100644 --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -105,6 +105,23 @@ struct StorageUniquerImpl { return result = initializeStorage(kind, ctorFn); } + /// Erase an instance of a complex derived type. + void erase(unsigned kind, unsigned hashValue, + llvm::function_ref isEqual, + llvm::function_ref cleanupFn) { + LookupKey lookupKey{kind, hashValue, isEqual}; + + // Acquire a writer-lock so that we can safely erase the type instance. + llvm::sys::SmartScopedWriter typeLock(mutex); + auto existing = storageTypes.find_as(lookupKey); + if (existing == storageTypes.end()) + return; + + // Cleanup the storage and remove it from the map. + cleanupFn(existing->storage); + storageTypes.erase(existing); + } + //===--------------------------------------------------------------------===// // Instance Storage //===--------------------------------------------------------------------===// @@ -179,3 +196,12 @@ auto StorageUniquer::getImpl( -> BaseStorage * { return impl->getOrCreate(kind, ctorFn); } + +/// Implementation for erasing an instance of a derived type with complex +/// storage. +void StorageUniquer::eraseImpl( + unsigned kind, unsigned hashValue, + llvm::function_ref isEqual, + std::function cleanupFn) { + impl->erase(kind, hashValue, isEqual, cleanupFn); +} -- 2.7.4