Refactor Attribute uniquing to use StorageUniquer instead of being hard coded...
authorRiver Riddle <riverriddle@google.com>
Tue, 30 Apr 2019 17:31:29 +0000 (10:31 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:22:50 +0000 (08:22 -0700)
--

PiperOrigin-RevId: 245974705

mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Support/StorageUniquer.cpp

index 25b7399..5deb2bb 100644 (file)
@@ -38,7 +38,6 @@ class VectorOrTensorType;
 namespace detail {
 
 struct AttributeStorage;
-struct UnitAttributeStorage;
 struct BoolAttributeStorage;
 struct IntegerAttributeStorage;
 struct FloatAttributeStorage;
@@ -165,7 +164,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
 class UnitAttr : public Attribute {
 public:
   using Attribute::Attribute;
-  using ImplType = detail::UnitAttributeStorage;
 
   static UnitAttr get(MLIRContext *context);
 
index ed2640d..4b2343a 100644 (file)
@@ -98,6 +98,10 @@ public:
   /// This should not be used directly.
   StorageUniquer &getTypeUniquer();
 
+  /// Returns the storage uniquer used for constructing attribute storage
+  /// instances. This should not be used directly.
+  StorageUniquer &getAttributeUniquer();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
index 8a0c590..2a9bb4a 100644 (file)
@@ -55,6 +55,12 @@ struct StorageUniquerImpl;
 ///      that builds a unique instance of the derived storage. The arguments to
 ///      this function are an allocator to store any uniqued data and the key
 ///      type for this storage.
+///
+///    - Provide a cleanup method:
+///        'void cleanup()'
+///      that is called when erasing a storage instance. This should cleanup any
+///      fields of the storage as necessary and not attempt to free the memory
+///      of the storage itself.
 class StorageUniquer {
 public:
   StorageUniquer();
@@ -114,7 +120,7 @@ public:
 
   /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
   /// that can be used to initialize a newly inserted storage instance. This
-  /// overload is used for derived types that have complex storage or uniquing
+  /// function is used for derived types that have complex storage or uniquing
   /// constraints.
   template <typename Storage, typename... Args>
   Storage *getComplex(std::function<void(Storage *)> initFn, unsigned kind,
@@ -146,7 +152,7 @@ public:
 
   /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
   /// that can be used to initialize a newly inserted storage instance. This
-  /// overload is used for derived types that use no additional storage or
+  /// function is used for derived types that use no additional storage or
   /// uniquing outside of the kind.
   template <typename Storage>
   Storage *getSimple(std::function<void(Storage *)> initFn, unsigned kind) {
@@ -159,6 +165,28 @@ public:
     return static_cast<Storage *>(getImpl(kind, ctorFn));
   }
 
+  /// Erases a uniqued instance of 'Storage'. This function is used for derived
+  /// types that have complex storage or uniquing constraints.
+  template <typename Storage, typename... Args>
+  void eraseComplex(unsigned kind, Args &&... args) {
+    // Construct a value of the derived key type.
+    auto derivedKey = getKey<Storage>(args...);
+
+    // Create a hash of the kind and the derived key.
+    unsigned hashValue = getHash<Storage>(kind, derivedKey);
+
+    // Generate an equality function for the derived storage.
+    std::function<bool(const BaseStorage *)> isEqual =
+        [&derivedKey](const BaseStorage *existing) {
+          return static_cast<const Storage &>(*existing) == derivedKey;
+        };
+
+    // Attempt to erase the storage instance.
+    eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) {
+      static_cast<Storage *>(storage)->cleanup();
+    });
+  }
+
 private:
   /// Implementation for getting/creating an instance of a derived type with
   /// complex storage.
@@ -171,6 +199,12 @@ private:
   BaseStorage *getImpl(unsigned kind,
                        std::function<BaseStorage *(StorageAllocator &)> ctorFn);
 
+  /// Implementation for erasing an instance of a derived type with complex
+  /// storage.
+  void eraseImpl(unsigned kind, unsigned hashValue,
+                 llvm::function_ref<bool(const BaseStorage *)> isEqual,
+                 std::function<void(BaseStorage *)> cleanupFn);
+
   /// The internal implementation class.
   std::unique_ptr<detail::StorageUniquerImpl> impl;
 
index e1da603..89ac240 100644 (file)
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/StorageUniquer.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
 namespace detail {
-
 /// Base storage class appearing in an attribute.
-struct AttributeStorage {
-  AttributeStorage(Attribute::Kind kind, bool isOrContainsFunctionCache = false)
-      : kind(kind), isOrContainsFunctionCache(isOrContainsFunctionCache) {}
-
-  Attribute::Kind kind : 8;
+struct AttributeStorage : public StorageUniquer::BaseStorage {
+  AttributeStorage(bool isOrContainsFunctionCache = false)
+      : isOrContainsFunctionCache(isOrContainsFunctionCache) {}
 
   /// This field is true if this is, or contains, a function attribute.
   bool isOrContainsFunctionCache : 1;
 };
 
-/// An attribute representing a unit value.
-struct UnitAttributeStorage : public AttributeStorage {
-  UnitAttributeStorage() : AttributeStorage(Attribute::Kind::Unit) {}
+// A utility class to get, or create, unique instances of attributes within an
+// MLIRContext. This class manages all creation and uniquing of attributes.
+class AttributeUniquer {
+public:
+  /// Get an uniqued instance of attribute T. This overload is used for
+  /// derived attributes that have complex storage or uniquing constraints.
+  template <typename T, typename... Args>
+  static typename std::enable_if<
+      !std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
+  get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+    return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
+        /*initFn=*/{}, static_cast<unsigned>(kind),
+        std::forward<Args>(args)...);
+  }
+
+  /// Get an uniqued instance of attribute T. This overload is used for
+  /// derived attributes that use the AttributeStorage directly and thus need no
+  /// additional storage or uniquing.
+  template <typename T, typename... Args>
+  static typename std::enable_if<
+      std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
+  get(MLIRContext *ctx, Attribute::Kind kind) {
+    return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
+        /*initFn=*/{}, static_cast<unsigned>(kind));
+  }
+
+  /// Erase a uniqued instance of attribute T. This overload is used for
+  /// derived attributes that have complex storage or uniquing constraints.
+  template <typename T, typename... Args>
+  static typename std::enable_if<
+      !std::is_same<typename T::ImplType, AttributeStorage>::value>::type
+  erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+    return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
+        static_cast<unsigned>(kind), std::forward<Args>(args)...);
+  }
 };
 
+using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
+
 /// An attribute representing a boolean value.
 struct BoolAttributeStorage : public AttributeStorage {
-  BoolAttributeStorage(Type type, bool value)
-      : AttributeStorage(Attribute::Kind::Bool), type(type), value(value) {}
-  const Type type;
+  using KeyTy = std::pair<MLIRContext *, bool>;
+
+  BoolAttributeStorage(Type type, bool value) : type(type), value(value) {}
+
+  /// We only check equality for and hash with the boolean key parameter.
+  bool operator==(const KeyTy &key) const { return key.second == value; }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_value(key.second);
+  }
+
+  static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    return new (allocator.allocate<BoolAttributeStorage>())
+        BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
+  }
+
+  Type type;
   bool value;
 };
 
@@ -62,14 +108,37 @@ struct BoolAttributeStorage : public AttributeStorage {
 struct IntegerAttributeStorage final
     : public AttributeStorage,
       public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
+  using KeyTy = std::pair<Type, APInt>;
+
   IntegerAttributeStorage(Type type, size_t numObjects)
-      : AttributeStorage(Attribute::Kind::Integer), type(type),
-        numObjects(numObjects) {
+      : type(type), numObjects(numObjects) {
     assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
   }
 
-  const Type type;
-  size_t numObjects;
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(type, getValue());
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, llvm::hash_value(key.second));
+  }
+
+  /// Construct a new storage instance.
+  static IntegerAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    Type type;
+    APInt value;
+    std::tie(type, value) = key;
+
+    auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
+    auto size =
+        IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
+    auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
+    auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
+    std::uninitialized_copy(elements.begin(), elements.end(),
+                            result->getTrailingObjects<uint64_t>());
+    return result;
+  }
 
   /// Returns an APInt representing the stored value.
   APInt getValue() const {
@@ -78,19 +147,47 @@ struct IntegerAttributeStorage final
     return APInt(type.getIntOrFloatBitWidth(),
                  {getTrailingObjects<uint64_t>(), numObjects});
   }
+
+  Type type;
+  size_t numObjects;
 };
 
 /// An attribute representing a floating point value.
 struct FloatAttributeStorage final
     : public AttributeStorage,
       public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
+  using KeyTy = std::pair<Type, APFloat>;
+
   FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
                         size_t numObjects)
-      : AttributeStorage(Attribute::Kind::Float), semantics(semantics),
-        type(type.cast<FloatType>()), numObjects(numObjects) {}
-  const llvm::fltSemantics &semantics;
-  const FloatType type;
-  size_t numObjects;
+      : semantics(semantics), type(type.cast<FloatType>()),
+        numObjects(numObjects) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key.first == type && key.second.bitwiseIsEqual(getValue());
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, llvm::hash_value(key.second));
+  }
+
+  /// Construct a new storage instance.
+  static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                          const KeyTy &key) {
+    const auto &apint = key.second.bitcastToAPInt();
+
+    // Here one word's bitwidth equals to that of uint64_t.
+    auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
+
+    auto byteSize =
+        FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
+    auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
+    auto result = ::new (rawMem) FloatAttributeStorage(
+        key.second.getSemantics(), key.first, elements.size());
+    std::uninitialized_copy(elements.begin(), elements.end(),
+                            result->getTrailingObjects<uint64_t>());
+    return result;
+  }
 
   /// Returns an APFloat representing the stored value.
   APFloat getValue() const {
@@ -98,95 +195,266 @@ struct FloatAttributeStorage final
                      {getTrailingObjects<uint64_t>(), numObjects});
     return APFloat(semantics, val);
   }
+
+  const llvm::fltSemantics &semantics;
+  FloatType type;
+  size_t numObjects;
 };
 
 /// An attribute representing a string value.
 struct StringAttributeStorage : public AttributeStorage {
-  StringAttributeStorage(StringRef value)
-      : AttributeStorage(Attribute::Kind::String), value(value) {}
+  using KeyTy = StringRef;
+
+  StringAttributeStorage(StringRef value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    return new (allocator.allocate<StringAttributeStorage>())
+        StringAttributeStorage(allocator.copyInto(key));
+  }
+
   StringRef value;
 };
 
 /// An attribute representing an array of other attributes.
 struct ArrayAttributeStorage : public AttributeStorage {
+  using KeyTy = ArrayRef<Attribute>;
+
   ArrayAttributeStorage(bool hasFunctionAttr, ArrayRef<Attribute> value)
-      : AttributeStorage(Attribute::Kind::Array, hasFunctionAttr),
-        value(value) {}
+      : AttributeStorage(hasFunctionAttr), value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                          const KeyTy &key) {
+    // Check to see if any of the elements have a function attr.
+    bool hasFunctionAttr = llvm::any_of(
+        key, [](Attribute elt) { return elt.isOrContainsFunction(); });
+
+    // Initialize the memory using placement new.
+    return new (allocator.allocate<ArrayAttributeStorage>())
+        ArrayAttributeStorage(hasFunctionAttr, allocator.copyInto(key));
+  }
+
   ArrayRef<Attribute> value;
 };
 
 // An attribute representing a reference to an affine map.
 struct AffineMapAttributeStorage : public AttributeStorage {
-  AffineMapAttributeStorage(AffineMap value)
-      : AttributeStorage(Attribute::Kind::AffineMap), value(value) {}
+  using KeyTy = AffineMap;
+
+  AffineMapAttributeStorage(AffineMap value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static AffineMapAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<AffineMapAttributeStorage>())
+        AffineMapAttributeStorage(key);
+  }
+
   AffineMap value;
 };
 
 // An attribute representing a reference to an integer set.
 struct IntegerSetAttributeStorage : public AttributeStorage {
-  IntegerSetAttributeStorage(IntegerSet value)
-      : AttributeStorage(Attribute::Kind::IntegerSet), value(value) {}
+  using KeyTy = IntegerSet;
+
+  IntegerSetAttributeStorage(IntegerSet value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static IntegerSetAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<IntegerSetAttributeStorage>())
+        IntegerSetAttributeStorage(key);
+  }
+
   IntegerSet value;
 };
 
 /// An attribute representing a reference to a type.
 struct TypeAttributeStorage : public AttributeStorage {
-  TypeAttributeStorage(Type value)
-      : AttributeStorage(Attribute::Kind::Type), value(value) {}
+  using KeyTy = Type;
+
+  TypeAttributeStorage(Type value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator,
+                                         KeyTy key) {
+    return new (allocator.allocate<TypeAttributeStorage>())
+        TypeAttributeStorage(key);
+  }
+
   Type value;
 };
 
 /// An attribute representing a reference to a function.
 struct FunctionAttributeStorage : public AttributeStorage {
+  using KeyTy = Function *;
+
   FunctionAttributeStorage(Function *value)
-      : AttributeStorage(Attribute::Kind::Function,
-                         /*isOrContainsFunctionCache=*/true),
-        value(value) {}
+      : AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static FunctionAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<FunctionAttributeStorage>())
+        FunctionAttributeStorage(key);
+  }
+
+  /// Storage cleanup function.
+  void cleanup() {
+    // Null out the function reference in the attribute to avoid dangling
+    // pointers.
+    value = nullptr;
+  }
+
   Function *value;
 };
 
 /// A base attribute representing a reference to a vector or tensor constant.
 struct ElementsAttributeStorage : public AttributeStorage {
-  ElementsAttributeStorage(Attribute::Kind kind, VectorOrTensorType type)
-      : AttributeStorage(kind), type(type) {}
+  ElementsAttributeStorage(VectorOrTensorType type) : type(type) {}
   VectorOrTensorType type;
 };
 
 /// An attribute representing a reference to a vector or tensor constant,
 /// inwhich all elements have the same value.
 struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
+  using KeyTy = std::pair<VectorOrTensorType, Attribute>;
+
   SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt)
-      : ElementsAttributeStorage(Attribute::Kind::SplatElements, type),
-        elt(elt) {}
+      : ElementsAttributeStorage(type), elt(elt) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == std::make_pair(type, elt);
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, key.second);
+  }
+
+  /// Construct a new storage instance.
+  static SplatElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<SplatElementsAttributeStorage>())
+        SplatElementsAttributeStorage(key.first, key.second);
+  }
+
   Attribute elt;
 };
 
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
-  DenseElementsAttributeStorage(Attribute::Kind kind, VectorOrTensorType type,
-                                ArrayRef<char> data)
-      : ElementsAttributeStorage(kind, type), data(data) {}
+  using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
+
+  DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef<char> data)
+      : ElementsAttributeStorage(ty), data(data) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.first, key.second);
+  }
+
+  /// Construct a new storage instance.
+  static DenseElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    // If the data buffer is non-empty, we copy it into the allocator.
+    ArrayRef<char> data = key.second;
+    if (!data.empty()) {
+      // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so
+      // the `readBits` will not fail when it accesses multiples of
+      // APINT_WORD_SIZE each time.
+      size_t sizeToAllocate =
+          llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE);
+      auto *rawCopy = (char *)allocator.allocate(sizeToAllocate, 64);
+      std::uninitialized_copy(data.begin(), data.end(), rawCopy);
+      data = {rawCopy, data.size()};
+    }
+    return new (allocator.allocate<DenseElementsAttributeStorage>())
+        DenseElementsAttributeStorage(key.first, data);
+  }
+
   ArrayRef<char> data;
 };
 
 /// An attribute representing a reference to a tensor constant with opaque
 /// content.
 struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
+  using KeyTy = std::tuple<VectorOrTensorType, Dialect *, StringRef>;
+
   OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect,
                                  StringRef bytes)
-      : ElementsAttributeStorage(Attribute::Kind::OpaqueElements, type),
-        dialect(dialect), bytes(bytes) {}
+      : ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == std::make_tuple(type, dialect, bytes);
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                              std::get<2>(key));
+  }
+
+  /// Construct a new storage instance.
+  static OpaqueElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    // TODO(b/131468830): Provide a way to avoid copying content of large opaque
+    // tensors This will likely require a new reference attribute kind.
+    return new (allocator.allocate<OpaqueElementsAttributeStorage>())
+        OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
+                                       allocator.copyInto(std::get<2>(key)));
+  }
+
   Dialect *dialect;
   StringRef bytes;
 };
 
 /// An attribute representing a reference to a sparse vector or tensor object.
 struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
+  using KeyTy =
+      std::tuple<VectorOrTensorType, DenseIntElementsAttr, DenseElementsAttr>;
+
   SparseElementsAttributeStorage(VectorOrTensorType type,
                                  DenseIntElementsAttr indices,
                                  DenseElementsAttr values)
-      : ElementsAttributeStorage(Attribute::Kind::SparseElements, type),
-        indices(indices), values(values) {}
+      : ElementsAttributeStorage(type), indices(indices), values(values) {}
+
+  /// Key equality and hash functions.
+  bool operator==(const KeyTy &key) const {
+    return key == std::make_tuple(type, indices, values);
+  }
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                              std::get<2>(key));
+  }
+
+  /// Construct a new storage instance.
+  static SparseElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<SparseElementsAttributeStorage>())
+        SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
+                                       std::get<2>(key));
+  }
+
   DenseIntElementsAttr indices;
   DenseElementsAttr values;
 };
index 8e1214d..4df46df 100644 (file)
@@ -26,7 +26,9 @@
 using namespace mlir;
 using namespace mlir::detail;
 
-Attribute::Kind Attribute::getKind() const { return attr->kind; }
+Attribute::Kind Attribute::getKind() const {
+  return static_cast<Kind>(attr->getKind());
+}
 
 bool Attribute::isOrContainsFunction() const {
   return attr->isOrContainsFunctionCache;
@@ -66,6 +68,14 @@ Attribute Attribute::remapFunctionAttrs(
 }
 
 //===----------------------------------------------------------------------===//
+// UnitAttr
+//===----------------------------------------------------------------------===//
+
+UnitAttr UnitAttr::get(MLIRContext *context) {
+  return AttributeUniquer::get<UnitAttr>(context, Attribute::Kind::Unit);
+}
+
+//===----------------------------------------------------------------------===//
 // NumericAttr
 //===----------------------------------------------------------------------===//
 
@@ -91,6 +101,12 @@ bool NumericAttr::kindof(Kind kind) {
 // BoolAttr
 //===----------------------------------------------------------------------===//
 
+BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
+  // Note: The context is also used within the BoolAttrStorage.
+  return AttributeUniquer::get<BoolAttr>(context, Attribute::Kind::Bool,
+                                         context, value);
+}
+
 bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 
 Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
@@ -99,6 +115,20 @@ Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
 // IntegerAttr
 //===----------------------------------------------------------------------===//
 
+IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
+  return AttributeUniquer::get<IntegerAttr>(
+      type.getContext(), Attribute::Kind::Integer, type, value);
+}
+
+IntegerAttr IntegerAttr::get(Type type, int64_t value) {
+  // This uses 64 bit APInts by default for index type.
+  if (type.isIndex())
+    return get(type, APInt(64, value));
+
+  auto intType = type.cast<IntegerType>();
+  return get(type, APInt(intType.getWidth(), value));
+}
+
 APInt IntegerAttr::getValue() const {
   return static_cast<ImplType *>(attr)->getValue();
 }
@@ -113,6 +143,44 @@ Type IntegerAttr::getType() const {
 // FloatAttr
 //===----------------------------------------------------------------------===//
 
+FloatAttr FloatAttr::get(Type type, const APFloat &value) {
+  assert(&type.cast<FloatType>().getFloatSemantics() == &value.getSemantics() &&
+         "FloatAttr type doesn't match the type implied by its value");
+  return AttributeUniquer::get<FloatAttr>(type.getContext(),
+                                          Attribute::Kind::Float, type, value);
+}
+
+static FloatAttr getFloatAttr(Type type, double value,
+                              llvm::Optional<Location> loc) {
+  if (!type.isa<FloatType>()) {
+    if (loc)
+      type.getContext()->emitError(*loc, "expected floating point type");
+    return nullptr;
+  }
+
+  // Treat BF16 as double because it is not supported in LLVM's APFloat.
+  // TODO(b/121118307): add BF16 support to APFloat?
+  if (type.isBF16() || type.isF64())
+    return FloatAttr::get(type, APFloat(value));
+
+  // This handles, e.g., F16 because there is no APFloat constructor for it.
+  bool unused;
+  APFloat val(value);
+  val.convert(type.cast<FloatType>().getFloatSemantics(),
+              APFloat::rmNearestTiesToEven, &unused);
+  return FloatAttr::get(type, val);
+}
+
+FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
+  return getFloatAttr(type, value, loc);
+}
+
+FloatAttr FloatAttr::get(Type type, double value) {
+  auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
+  assert(res && "failed to construct float attribute");
+  return res;
+}
+
 APFloat FloatAttr::getValue() const {
   return static_cast<ImplType *>(attr)->getValue();
 }
@@ -134,6 +202,11 @@ double FloatAttr::getValueAsDouble() const {
 // StringAttr
 //===----------------------------------------------------------------------===//
 
+StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+  return AttributeUniquer::get<StringAttr>(context, Attribute::Kind::String,
+                                           bytes);
+}
+
 StringRef StringAttr::getValue() const {
   return static_cast<ImplType *>(attr)->value;
 }
@@ -142,6 +215,11 @@ StringRef StringAttr::getValue() const {
 // ArrayAttr
 //===----------------------------------------------------------------------===//
 
+ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+  return AttributeUniquer::get<ArrayAttr>(context, Attribute::Kind::Array,
+                                          value);
+}
+
 ArrayRef<Attribute> ArrayAttr::getValue() const {
   return static_cast<ImplType *>(attr)->value;
 }
@@ -150,6 +228,11 @@ ArrayRef<Attribute> ArrayAttr::getValue() const {
 // AffineMapAttr
 //===----------------------------------------------------------------------===//
 
+AffineMapAttr AffineMapAttr::get(AffineMap value) {
+  return AttributeUniquer::get<AffineMapAttr>(
+      value.getResult(0).getContext(), Attribute::Kind::AffineMap, value);
+}
+
 AffineMap AffineMapAttr::getValue() const {
   return static_cast<ImplType *>(attr)->value;
 }
@@ -158,6 +241,11 @@ AffineMap AffineMapAttr::getValue() const {
 // IntegerSetAttr
 //===----------------------------------------------------------------------===//
 
+IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
+  return AttributeUniquer::get<IntegerSetAttr>(
+      value.getConstraint(0).getContext(), Attribute::Kind::IntegerSet, value);
+}
+
 IntegerSet IntegerSetAttr::getValue() const {
   return static_cast<ImplType *>(attr)->value;
 }
@@ -166,12 +254,29 @@ IntegerSet IntegerSetAttr::getValue() const {
 // TypeAttr
 //===----------------------------------------------------------------------===//
 
+TypeAttr TypeAttr::get(Type value, MLIRContext *context) {
+  return AttributeUniquer::get<TypeAttr>(context, Attribute::Kind::Type, value);
+}
+
 Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 
 //===----------------------------------------------------------------------===//
 // FunctionAttr
 //===----------------------------------------------------------------------===//
 
+FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) {
+  assert(value && "Cannot get FunctionAttr for a null function");
+  return AttributeUniquer::get<FunctionAttr>(context, Attribute::Kind::Function,
+                                             value);
+}
+
+/// This function is used by the internals of the Function class to null out
+/// attributes referring to functions that are about to be deleted.
+void FunctionAttr::dropFunctionReference(Function *value) {
+  AttributeUniquer::erase<FunctionAttr>(value->getContext(),
+                                        Attribute::Kind::Function, value);
+}
+
 Function *FunctionAttr::getValue() const {
   return static_cast<ImplType *>(attr)->value;
 }
@@ -208,6 +313,14 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
 // SplatElementsAttr
 //===----------------------------------------------------------------------===//
 
+SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
+                                         Attribute elt) {
+  assert(elt.cast<NumericAttr>().getType() == type.getElementType() &&
+         "value should be of the given type");
+  return AttributeUniquer::get<SplatElementsAttr>(
+      type.getContext(), Attribute::Kind::SplatElements, type, elt);
+}
+
 Attribute SplatElementsAttr::getValue() const {
   return static_cast<ImplType *>(attr)->elt;
 }
@@ -237,6 +350,70 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const {
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
 
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
+                                         ArrayRef<char> data) {
+  assert((type.getSizeInBits() <= data.size() * APInt::APINT_WORD_SIZE) &&
+         "Input data bit size should be larger than that type requires");
+
+  Attribute::Kind kind;
+  switch (type.getElementType().getKind()) {
+  case StandardTypes::BF16:
+  case StandardTypes::F16:
+  case StandardTypes::F32:
+  case StandardTypes::F64:
+    kind = Attribute::Kind::DenseFPElements;
+    break;
+  case StandardTypes::Integer:
+    kind = Attribute::Kind::DenseIntElements;
+    break;
+  default:
+    llvm_unreachable("unexpected element type");
+  }
+  return AttributeUniquer::get<DenseElementsAttr>(type.getContext(), kind, type,
+                                                  data);
+}
+
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
+                                         ArrayRef<Attribute> values) {
+  assert(type.getElementType().isIntOrFloat() &&
+         "expected int or float element type");
+  assert(values.size() == type.getNumElements() &&
+         "expected 'values' to contain the same number of elements as 'type'");
+
+  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+  // with double semantics.
+  auto eltType = type.getElementType();
+  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
+
+  // Compress the attribute values into a character buffer.
+  SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
+                            APInt::APINT_WORD_SIZE);
+  APInt intVal;
+  for (unsigned i = 0, e = values.size(); i < e; ++i) {
+    switch (eltType.getKind()) {
+    case StandardTypes::BF16:
+    case StandardTypes::F16:
+    case StandardTypes::F32:
+    case StandardTypes::F64:
+      assert(eltType == values[i].cast<FloatAttr>().getType() &&
+             "expected attribute value to have element type");
+      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
+      break;
+    case StandardTypes::Integer:
+      assert(eltType == values[i].cast<IntegerAttr>().getType() &&
+             "expected attribute value to have element type");
+      intVal = values[i].cast<IntegerAttr>().getValue();
+      break;
+    default:
+      llvm_unreachable("unexpected element type");
+    }
+    assert(intVal.getBitWidth() == bitWidth &&
+           "expected value to have same bitwidth as element type");
+    writeBits(data.data(), i * bitWidth, intVal);
+  }
+  return get(type, data);
+}
+
 /// Returns the number of elements held by this attribute.
 size_t DenseElementsAttr::size() const { return getType().getNumElements(); }
 
@@ -457,6 +634,15 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
 // OpaqueElementsAttr
 //===----------------------------------------------------------------------===//
 
+OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
+                                           VectorOrTensorType type,
+                                           StringRef bytes) {
+  assert(TensorType::isValidElementType(type.getElementType()) &&
+         "Input element type should be a valid tensor element type");
+  return AttributeUniquer::get<OpaqueElementsAttr>(
+      type.getContext(), Attribute::Kind::OpaqueElements, type, dialect, bytes);
+}
+
 StringRef OpaqueElementsAttr::getValue() const {
   return static_cast<ImplType *>(attr)->bytes;
 }
@@ -483,6 +669,16 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
 // SparseElementsAttr
 //===----------------------------------------------------------------------===//
 
+SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
+                                           DenseIntElementsAttr indices,
+                                           DenseElementsAttr values) {
+  assert(indices.getType().getElementType().isInteger(64) &&
+         "expected sparse indices to be 64-bit integer values");
+  return AttributeUniquer::get<SparseElementsAttr>(
+      type.getContext(), Attribute::Kind::SparseElements, type, indices,
+      values);
+}
+
 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
   return static_cast<ImplType *>(attr)->indices;
 }
index cad3a28..9760599 100644 (file)
@@ -195,70 +195,6 @@ struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
   }
 };
 
-struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
-  // Float attributes are uniqued based on wrapped APFloat.
-  using KeyTy = std::pair<Type, APFloat>;
-  using DenseMapInfo<FloatAttributeStorage *>::isEqual;
-
-  static unsigned getHashValue(FloatAttributeStorage *key) {
-    return getHashValue(KeyTy(key->type, key->getValue()));
-  }
-
-  static unsigned getHashValue(KeyTy key) {
-    return hash_combine(key.first, llvm::hash_value(key.second));
-  }
-
-  static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
-    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
-      return false;
-    return lhs.first == rhs->type && lhs.second.bitwiseIsEqual(rhs->getValue());
-  }
-};
-
-struct IntegerAttrKeyInfo : DenseMapInfo<IntegerAttributeStorage *> {
-  // Integer attributes are uniqued based on wrapped APInt.
-  using KeyTy = std::pair<Type, APInt>;
-  using DenseMapInfo<IntegerAttributeStorage *>::isEqual;
-
-  static unsigned getHashValue(IntegerAttributeStorage *key) {
-    return getHashValue(KeyTy(key->type, key->getValue()));
-  }
-
-  static unsigned getHashValue(KeyTy key) {
-    return hash_combine(key.first, llvm::hash_value(key.second));
-  }
-
-  static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) {
-    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
-      return false;
-    assert((lhs.first.isIndex() || (lhs.first.isa<IntegerType>() &&
-                                    lhs.first.cast<IntegerType>().getWidth() ==
-                                        lhs.second.getBitWidth())) &&
-           "mismatching integer type and value bitwidth");
-    return lhs.first == rhs->type && lhs.second == rhs->getValue();
-  }
-};
-
-struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
-  // Array attributes are uniqued based on their elements.
-  using KeyTy = ArrayRef<Attribute>;
-  using DenseMapInfo<ArrayAttributeStorage *>::isEqual;
-
-  static unsigned getHashValue(ArrayAttributeStorage *key) {
-    return getHashValue(KeyTy(key->value));
-  }
-
-  static unsigned getHashValue(KeyTy key) {
-    return hash_combine_range(key.begin(), key.end());
-  }
-
-  static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) {
-    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
-      return false;
-    return lhs == rhs->value;
-  }
-};
-
 struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
   // Array attributes are uniqued based on their elements.
   using KeyTy = ArrayRef<NamedAttribute>;
@@ -279,51 +215,6 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
   }
 };
 
-struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
-  using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
-  using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
-
-  static unsigned getHashValue(DenseElementsAttributeStorage *key) {
-    return getHashValue(KeyTy(key->type, key->data));
-  }
-
-  static unsigned getHashValue(KeyTy key) {
-    return hash_combine(
-        key.first, hash_combine_range(key.second.begin(), key.second.end()));
-  }
-
-  static bool isEqual(const KeyTy &lhs,
-                      const DenseElementsAttributeStorage *rhs) {
-    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
-      return false;
-    return lhs == std::make_pair(rhs->type, rhs->data);
-  }
-};
-
-struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
-  // Opaque element attributes are uniqued based on their dialect, type and
-  // value.
-  using KeyTy = std::tuple<Dialect *, VectorOrTensorType, StringRef>;
-  using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
-
-  static unsigned getHashValue(OpaqueElementsAttributeStorage *key) {
-    return getHashValue(KeyTy(key->dialect, key->type, key->bytes));
-  }
-
-  static unsigned getHashValue(KeyTy key) {
-    auto bytes = std::get<2>(key);
-    return hash_combine(std::get<0>(key), std::get<1>(key),
-                        hash_combine_range(bytes.begin(), bytes.end()));
-  }
-
-  static bool isEqual(const KeyTy &lhs,
-                      const OpaqueElementsAttributeStorage *rhs) {
-    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
-      return false;
-    return lhs == std::make_tuple(rhs->dialect, rhs->type, rhs->bytes);
-  }
-};
-
 struct CallSiteLocationKeyInfo : DenseMapInfo<CallSiteLocationStorage *> {
   // Call locations are uniqued based on their held concret location
   // and the caller location.
@@ -492,36 +383,15 @@ public:
   //===--------------------------------------------------------------------===//
   // Attribute uniquing
   //===--------------------------------------------------------------------===//
+  StorageUniquer attributeUniquer;
 
-  // Attribute allocator and mutex for thread safety.
+  // Attribute list allocator and mutex for thread safety.
   llvm::BumpPtrAllocator attributeAllocator;
   llvm::sys::SmartRWMutex<true> attributeMutex;
 
-  UnitAttributeStorage unitAttr;
-  BoolAttributeStorage *boolAttrs[2] = {nullptr};
-  DenseSet<IntegerAttributeStorage *, IntegerAttrKeyInfo> integerAttrs;
-  DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
-  llvm::StringMap<StringAttributeStorage *> stringAttrs;
-  using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
-  ArrayAttrSet arrayAttrs;
-  DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
-  DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
-  DenseMap<Type, TypeAttributeStorage *> typeAttrs;
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
-  DenseMap<Function *, FunctionAttributeStorage *> functionAttrs;
-  DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
-      splatElementsAttrs;
-  using DenseElementsAttrSet =
-      DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
-  DenseElementsAttrSet denseElementsAttrs;
-  using OpaqueElementsAttrSet =
-      DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
-  OpaqueElementsAttrSet opaqueElementsAttrs;
-  DenseMap<std::tuple<Type, Attribute, Attribute>,
-           SparseElementsAttributeStorage *>
-      sparseElementsAttrs;
 
 public:
   MLIRContextImpl()
@@ -985,234 +855,10 @@ const Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx,
 // Attribute uniquing
 //===----------------------------------------------------------------------===//
 
-UnitAttr UnitAttr::get(MLIRContext *context) {
-  return &context->getImpl().unitAttr;
-}
-
-BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  { // Check for an existing instance in read-only mode.
-    llvm::sys::SmartScopedReader<true> attributeLock(impl.attributeMutex);
-    if (auto *result = impl.boolAttrs[value])
-      return result;
-  }
-
-  // Aquire the mutex in write mode so that we can safely construct the new
-  // instance.
-  llvm::sys::SmartScopedWriter<true> attributeLock(impl.attributeMutex);
-
-  // Check for an existing instance again here, because another writer thread
-  // may have already created one.
-  auto *&result = impl.boolAttrs[value];
-  if (result)
-    return result;
-
-  result = impl.attributeAllocator.Allocate<BoolAttributeStorage>();
-  new (result) BoolAttributeStorage(IntegerType::get(1, context), value);
-  return result;
-}
-
-IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
-  auto &impl = type.getContext()->getImpl();
-  IntegerAttrKeyInfo::KeyTy key({type, value});
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.integerAttrs, key, impl.attributeMutex, [&] {
-    auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
-
-    auto byteSize =
-        IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
-    auto rawMem = impl.attributeAllocator.Allocate(
-        byteSize, alignof(IntegerAttributeStorage));
-    auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
-    std::uninitialized_copy(elements.begin(), elements.end(),
-                            result->getTrailingObjects<uint64_t>());
-    return result;
-  });
-}
-
-IntegerAttr IntegerAttr::get(Type type, int64_t value) {
-  // This uses 64 bit APInts by default for index type.
-  if (type.isIndex())
-    return get(type, APInt(64, value));
-
-  auto intType = type.dyn_cast<IntegerType>();
-  assert(intType && "expected an integer type for an integer attribute");
-  return get(type, APInt(intType.getWidth(), value));
-}
-
-static FloatAttr getFloatAttr(Type type, double value,
-                              llvm::Optional<Location> loc) {
-  if (!type.isa<FloatType>()) {
-    if (loc)
-      type.getContext()->emitError(*loc, "expected floating point type");
-    return nullptr;
-  }
-
-  // Treat BF16 as double because it is not supported in LLVM's APFloat.
-  // TODO(jpienaar): add BF16 support to APFloat?
-  if (type.isBF16() || type.isF64())
-    return FloatAttr::get(type, APFloat(value));
-
-  // This handles, e.g., F16 because there is no APFloat constructor for it.
-  bool unused;
-  APFloat val(value);
-  val.convert(type.cast<FloatType>().getFloatSemantics(),
-              APFloat::rmNearestTiesToEven, &unused);
-  return FloatAttr::get(type, val);
-}
-
-FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
-  return getFloatAttr(type, value, loc);
-}
-
-FloatAttr FloatAttr::get(Type type, double value) {
-  auto res = getFloatAttr(type, value, /*loc=*/llvm::None);
-  assert(res && "failed to construct float attribute");
-  return res;
-}
-
-FloatAttr FloatAttr::get(Type type, const APFloat &value) {
-  auto fltType = type.cast<FloatType>();
-  assert(&fltType.getFloatSemantics() == &value.getSemantics() &&
-         "FloatAttr type doesn't match the type implied by its value");
-  (void)fltType;
-  auto &impl = type.getContext()->getImpl();
-  FloatAttrKeyInfo::KeyTy key({type, value});
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.floatAttrs, key, impl.attributeMutex, [&] {
-    const auto &apint = value.bitcastToAPInt();
-    // Here one word's bitwidth equals to that of uint64_t.
-    auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
-
-    auto byteSize =
-        FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
-    auto rawMem = impl.attributeAllocator.Allocate(
-        byteSize, alignof(FloatAttributeStorage));
-    auto result = ::new (rawMem)
-        FloatAttributeStorage(value.getSemantics(), type, elements.size());
-    std::uninitialized_copy(elements.begin(), elements.end(),
-                            result->getTrailingObjects<uint64_t>());
-    return result;
-  });
-}
-
-StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  { // Check for an existing instance in read-only mode.
-    llvm::sys::SmartScopedReader<true> attributeLock(impl.attributeMutex);
-    auto it = impl.stringAttrs.find(bytes);
-    if (it != impl.stringAttrs.end())
-      return it->second;
-  }
-
-  // Aquire the mutex in write mode so that we can safely construct the new
-  // instance.
-  llvm::sys::SmartScopedWriter<true> attributeLock(impl.attributeMutex);
-
-  // Check for an existing instance again here, because another writer thread
-  // may have already created one.
-  auto it = impl.stringAttrs.insert({bytes, nullptr}).first;
-  if (it->second)
-    return it->second;
-
-  auto result = new (impl.attributeAllocator.Allocate<StringAttributeStorage>())
-      StringAttributeStorage(it->first());
-  return it->second = result;
-}
-
-ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.arrayAttrs, value, impl.attributeMutex, [&] {
-    auto *result = impl.attributeAllocator.Allocate<ArrayAttributeStorage>();
-
-    // Copy the elements into the bump pointer.
-    value = copyArrayRefInto(impl.attributeAllocator, value);
-
-    // Check to see if any of the elements have a function attr.
-    bool hasFunctionAttr = false;
-    for (auto elt : value)
-      if (elt.isOrContainsFunction()) {
-        hasFunctionAttr = true;
-        break;
-      }
-
-    // Initialize the memory using placement new.
-    return new (result) ArrayAttributeStorage(hasFunctionAttr, value);
-  });
-}
-
-AffineMapAttr AffineMapAttr::get(AffineMap value) {
-  auto *context = value.getResult(0).getContext();
-  auto &impl = context->getImpl();
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.affineMapAttrs, value, impl.attributeMutex, [&] {
-    auto result = impl.attributeAllocator.Allocate<AffineMapAttributeStorage>();
-    return new (result) AffineMapAttributeStorage(value);
-  });
-}
-
-IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
-  auto *context = value.getConstraint(0).getContext();
-  auto &impl = context->getImpl();
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.integerSetAttrs, value, impl.attributeMutex, [&] {
-    auto result =
-        impl.attributeAllocator.Allocate<IntegerSetAttributeStorage>();
-    return new (result) IntegerSetAttributeStorage(value);
-  });
-}
-
-TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.typeAttrs, type, impl.attributeMutex, [&] {
-    auto result = impl.attributeAllocator.Allocate<TypeAttributeStorage>();
-    return new (result) TypeAttributeStorage(type);
-  });
-}
-
-FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) {
-  assert(value && "Cannot get FunctionAttr for a null function");
-  auto &impl = context->getImpl();
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(impl.functionAttrs, value, impl.attributeMutex, [&] {
-    auto result = impl.attributeAllocator.Allocate<FunctionAttributeStorage>();
-    return new (result) FunctionAttributeStorage(value);
-  });
-}
-
-/// This function is used by the internals of the Function class to null out
-/// attributes referring to functions that are about to be deleted.
-void FunctionAttr::dropFunctionReference(Function *value) {
-  auto &impl = value->getContext()->getImpl();
-
-  // Aquire the mutex in write mode so that we can safely remove the attribute
-  // if it exists.
-  llvm::sys::SmartScopedWriter<true> attributeLock(impl.attributeMutex);
-
-  // Check to see if there was an attribute referring to this function.
-  auto &functionAttrs = impl.functionAttrs;
-
-  // If not, then we're done.
-  auto it = functionAttrs.find(value);
-  if (it == functionAttrs.end())
-    return;
-
-  // If so, null out the function reference in the attribute (to avoid dangling
-  // pointers) and remove the entry from the map so the map doesn't contain
-  // dangling keys.
-  it->second->value = nullptr;
-  functionAttrs.erase(it);
+/// Returns the storage uniquer used for constructing attribute storage
+/// instances. This should not be used directly.
+StorageUniquer &MLIRContext::getAttributeUniquer() {
+  return getImpl().attributeUniquer;
 }
 
 /// Perform a three-way comparison between the names of the specified
@@ -1281,168 +927,6 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
   });
 }
 
-// Returns false if the given `attr` is not of the given `type`.
-// Note: This function is only intended to be used for assertion. So it's
-// possibly allowing invalid cases that are unimplemented.
-static bool attrIsOfType(Attribute attr, Type type) {
-  if (auto floatAttr = attr.dyn_cast<FloatAttr>())
-    return floatAttr.getType() == type;
-  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
-    return intAttr.getType() == type;
-  if (auto elementsAttr = attr.dyn_cast<ElementsAttr>())
-    return elementsAttr.getType() == type;
-  // TODO: check the other cases
-  return true;
-}
-
-SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
-                                         Attribute elt) {
-  auto attr = elt.dyn_cast<NumericAttr>();
-  assert(attr && "expected numeric value");
-  assert(attr.getType() == type.getElementType() &&
-         "value should be of the given type");
-  (void)attr;
-
-  auto &impl = type.getContext()->getImpl();
-
-  // Safely get or create an attribute instance.
-  std::pair<Type, Attribute> key(type, elt);
-  return safeGetOrCreate(
-      impl.splatElementsAttrs, key, impl.attributeMutex, [&] {
-        auto result =
-            impl.attributeAllocator.Allocate<SplatElementsAttributeStorage>();
-        return new (result) SplatElementsAttributeStorage(type, elt);
-      });
-}
-
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
-                                         ArrayRef<char> data) {
-  auto bitsRequired = type.getSizeInBits();
-  (void)bitsRequired;
-  assert((bitsRequired <= data.size() * APInt::APINT_WORD_SIZE) &&
-         "Input data bit size should be larger than that type requires");
-
-  auto &impl = type.getContext()->getImpl();
-  DenseElementsAttrInfo::KeyTy key({type, data});
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(
-      impl.denseElementsAttrs, key, impl.attributeMutex, [&] {
-        Attribute::Kind kind;
-        switch (type.getElementType().getKind()) {
-        case StandardTypes::BF16:
-        case StandardTypes::F16:
-        case StandardTypes::F32:
-        case StandardTypes::F64:
-          kind = Attribute::Kind::DenseFPElements;
-          break;
-        case StandardTypes::Integer:
-          kind = Attribute::Kind::DenseIntElements;
-          break;
-        default:
-          llvm_unreachable("unexpected element type");
-        }
-
-        // If the data buffer is non-empty, we copy it into the context.
-        ArrayRef<char> copy;
-        if (!data.empty()) {
-          // Rounding up the allocate size to multiples of APINT_WORD_SIZE, so
-          // the `readBits` will not fail when it accesses multiples of
-          // APINT_WORD_SIZE each time.
-          size_t sizeToAllocate =
-              llvm::alignTo(data.size(), APInt::APINT_WORD_SIZE);
-          auto *rawCopy =
-              (char *)impl.attributeAllocator.Allocate(sizeToAllocate, 64);
-          std::uninitialized_copy(data.begin(), data.end(), rawCopy);
-          copy = {rawCopy, data.size()};
-        }
-        auto *result =
-            impl.attributeAllocator.Allocate<DenseElementsAttributeStorage>();
-        return new (result) DenseElementsAttributeStorage(kind, type, copy);
-      });
-}
-
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
-                                         ArrayRef<Attribute> values) {
-  assert(type.getElementType().isIntOrFloat() &&
-         "expected int or float element type");
-  assert(values.size() == type.getNumElements() &&
-         "expected 'values' to contain the same number of elements as 'type'");
-
-  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-  // with double semantics.
-  auto eltType = type.getElementType();
-  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-
-  // Compress the attribute values into a character buffer.
-  SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
-                            APInt::APINT_WORD_SIZE);
-  APInt intVal;
-  for (unsigned i = 0, e = values.size(); i < e; ++i) {
-    switch (eltType.getKind()) {
-    case StandardTypes::BF16:
-    case StandardTypes::F16:
-    case StandardTypes::F32:
-    case StandardTypes::F64:
-      assert(eltType == values[i].cast<FloatAttr>().getType() &&
-             "expected attribute value to have element type");
-      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
-      break;
-    case StandardTypes::Integer:
-      assert(eltType == values[i].cast<IntegerAttr>().getType() &&
-             "expected attribute value to have element type");
-      intVal = values[i].cast<IntegerAttr>().getValue();
-      break;
-    default:
-      llvm_unreachable("unexpected element type");
-    }
-    assert(intVal.getBitWidth() == bitWidth &&
-           "expected value to have same bitwidth as element type");
-    writeBits(data.data(), i * bitWidth, intVal);
-  }
-  return get(type, data);
-}
-
-OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
-                                           VectorOrTensorType type,
-                                           StringRef bytes) {
-  assert(TensorType::isValidElementType(type.getElementType()) &&
-         "Input element type should be a valid tensor element type");
-
-  auto &impl = type.getContext()->getImpl();
-  OpaqueElementsAttrInfo::KeyTy key(dialect, type, bytes);
-
-  return safeGetOrCreate(
-      impl.opaqueElementsAttrs, key, impl.attributeMutex, [&] {
-        auto *result =
-            impl.attributeAllocator.Allocate<OpaqueElementsAttributeStorage>();
-
-        // TODO: Provide a way to avoid copying content of large opaque tensors
-        // This will likely require a new reference attribute kind.
-        bytes = bytes.copy(impl.attributeAllocator);
-        return new (result)
-            OpaqueElementsAttributeStorage(type, dialect, bytes);
-      });
-}
-
-SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
-                                           DenseIntElementsAttr indices,
-                                           DenseElementsAttr values) {
-  assert(indices.getType().getElementType().isInteger(64) &&
-         "expected sparse indices to be 64-bit integer values");
-
-  auto &impl = type.getContext()->getImpl();
-  auto key = std::make_tuple(type, indices, values);
-
-  // Safely get or create an attribute instance.
-  return safeGetOrCreate(
-      impl.sparseElementsAttrs, key, impl.attributeMutex, [&] {
-        return new (
-            impl.attributeAllocator.Allocate<SparseElementsAttributeStorage>())
-            SparseElementsAttributeStorage(type, indices, values);
-      });
-}
-
 //===----------------------------------------------------------------------===//
 // AffineMap and AffineExpr uniquing
 //===----------------------------------------------------------------------===//
index 14d8f39..6fd55e7 100644 (file)
@@ -105,6 +105,23 @@ struct StorageUniquerImpl {
     return result = initializeStorage(kind, ctorFn);
   }
 
+  /// Erase an instance of a complex derived type.
+  void erase(unsigned kind, unsigned hashValue,
+             llvm::function_ref<bool(const BaseStorage *)> isEqual,
+             llvm::function_ref<void(BaseStorage *)> cleanupFn) {
+    LookupKey lookupKey{kind, hashValue, isEqual};
+
+    // Acquire a writer-lock so that we can safely erase the type instance.
+    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+    auto existing = storageTypes.find_as(lookupKey);
+    if (existing == storageTypes.end())
+      return;
+
+    // Cleanup the storage and remove it from the map.
+    cleanupFn(existing->storage);
+    storageTypes.erase(existing);
+  }
+
   //===--------------------------------------------------------------------===//
   // Instance Storage
   //===--------------------------------------------------------------------===//
@@ -179,3 +196,12 @@ auto StorageUniquer::getImpl(
     -> BaseStorage * {
   return impl->getOrCreate(kind, ctorFn);
 }
+
+/// Implementation for erasing an instance of a derived type with complex
+/// storage.
+void StorageUniquer::eraseImpl(
+    unsigned kind, unsigned hashValue,
+    llvm::function_ref<bool(const BaseStorage *)> isEqual,
+    std::function<void(BaseStorage *)> cleanupFn) {
+  impl->erase(kind, hashValue, isEqual, cleanupFn);
+}