Refactor the generic storage object uniquing functionality from TypeUniquer into...
authorRiver Riddle <riverriddle@google.com>
Fri, 26 Apr 2019 04:01:21 +0000 (21:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:17:08 +0000 (08:17 -0700)
--

PiperOrigin-RevId: 245358744

mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/include/mlir/Support/StorageUniquer.h [new file with mode: 0644]
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Support/CMakeLists.txt
mlir/lib/Support/StorageUniquer.cpp [new file with mode: 0644]

index 85eb62d..ed2640d 100644 (file)
 
 namespace mlir {
 class AbstractOperation;
-class MLIRContextImpl;
-class Location;
 class Dialect;
+class Location;
+class MLIRContextImpl;
+class StorageUniquer;
 
 /// MLIRContext is the top-level object for a collection of MLIR modules.  It
 /// holds immortal uniqued objects like types, and the tables used to unique
@@ -93,6 +94,10 @@ public:
   // MLIRContextImpl type.
   MLIRContextImpl &getImpl() { return *impl.get(); }
 
+  /// Returns the storage uniquer used for constructing type storage instances.
+  /// This should not be used directly.
+  StorageUniquer &getTypeUniquer();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
index 3c6c877..f174d6d 100644 (file)
@@ -23,9 +23,7 @@
 #define MLIR_IR_TYPE_SUPPORT_H
 
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
+#include "mlir/Support/StorageUniquer.h"
 #include "llvm/ADT/StringRef.h"
 #include <memory>
 
@@ -51,15 +49,16 @@ class TypeUniquer;
 } // end namespace detail
 
 /// Base storage class appearing in a Type.
-class TypeStorage {
+class TypeStorage : public StorageUniquer::BaseStorage {
   friend detail::TypeUniquer;
+  friend StorageUniquer;
 
 protected:
   /// This constructor is used by derived classes as part of the TypeUniquer.
   /// When using this constructor, the initializeTypeInfo function must be
   /// invoked afterwards for the storage to be valid.
   TypeStorage(unsigned subclassData = 0)
-      : dialect(nullptr), kind(0), subclassData(subclassData) {}
+      : dialect(nullptr), subclassData(subclassData) {}
 
 public:
   /// Get the dialect that this type is registered to.
@@ -67,10 +66,6 @@ public:
     assert(dialect && "Malformed type storage object.");
     return *dialect;
   }
-
-  /// Get the kind classification of this type.
-  unsigned getKind() const { return kind; }
-
   /// Get the subclass data.
   unsigned getSubclassData() const { return subclassData; }
 
@@ -78,25 +73,13 @@ public:
   void setSubclassData(unsigned val) { subclassData = val; }
 
 private:
-  // Constructor used for simple type storage that have no subclass data. This
-  // constructor should not be used by derived storage classes.
-  TypeStorage(const Dialect &dialect, unsigned kind)
-      : dialect(&dialect), kind(kind), subclassData(0) {}
-
-  // Initialize an existing type storage with a kind and a context. This is used
-  // by the TypeUniquer when initializing a newly constructed derived type
-  // storage object.
-  void initializeTypeInfo(const Dialect &newDialect, unsigned newKind) {
-    dialect = &newDialect;
-    kind = newKind;
-  }
+  // Set the dialect for this storage instance. This is used by the TypeUniquer
+  // when initializing a newly constructed type storage object.
+  void initializeDialect(const Dialect &newDialect) { dialect = &newDialect; }
 
   /// The registered information for the current type.
   const Dialect *dialect;
 
-  /// Classification of the subclass, used for type checking.
-  unsigned kind;
-
   /// Space for subclasses to store data.
   unsigned subclassData;
 };
@@ -111,37 +94,7 @@ using DefaultTypeStorage = TypeStorage;
 
 // This is a utility allocator used to allocate memory for instances of derived
 // Types.
-class TypeStorageAllocator {
-public:
-  /// Copy the specified array of elements into memory managed by our bump
-  /// pointer allocator.  This assumes the elements are all PODs.
-  template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
-    if (elements.empty())
-      return llvm::None;
-    auto result = allocator.Allocate<T>(elements.size());
-    std::uninitialized_copy(elements.begin(), elements.end(), result);
-    return ArrayRef<T>(result, elements.size());
-  }
-
-  /// Copy the provided string into memory managed by our bump pointer
-  /// allocator.
-  StringRef copyInto(StringRef str) {
-    auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
-    return StringRef(result.data(), str.size());
-  }
-
-  // Allocate an instance of the provided type.
-  template <typename T> T *allocate() { return allocator.Allocate<T>(); }
-
-  /// Allocate 'size' bytes of 'alignment' aligned memory.
-  void *allocate(size_t size, size_t alignment) {
-    return allocator.Allocate(size, alignment);
-  }
-
-private:
-  /// The raw allocator for type storage objects.
-  llvm::BumpPtrAllocator allocator;
-};
+using TypeStorageAllocator = StorageUniquer::StorageAllocator;
 
 //===----------------------------------------------------------------------===//
 // TypeUniquer
@@ -157,30 +110,13 @@ public:
   static typename std::enable_if<
       !std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
   get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+    // Lookup an instance of this complex storage type.
     using ImplType = typename T::ImplType;
-
-    // Construct a value of the derived key type.
-    auto derivedKey = getKey<ImplType>(args...);
-
-    // Create a hash of the kind and the derived key.
-    unsigned hashValue = getHash<ImplType>(kind, derivedKey);
-
-    // Generate an equality function for the derived storage.
-    std::function<bool(const TypeStorage *)> isEqual =
-        [&derivedKey](const TypeStorage *existing) {
-          return static_cast<const ImplType &>(*existing) == derivedKey;
-        };
-
-    // Generate a constructor function for the derived storage.
-    std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn =
-        [&](TypeStorageAllocator &allocator) {
-          TypeStorage *storage = ImplType::construct(allocator, derivedKey);
-          storage->initializeTypeInfo(lookupDialectForType<T>(ctx), kind);
-          return storage;
-        };
-
-    // Get an instance for the derived storage.
-    return T(getImpl(ctx, kind, hashValue, isEqual, constructorFn));
+    return ctx->getTypeUniquer().getComplex<ImplType>(
+        [&](ImplType *storage) {
+          storage->initializeDialect(lookupDialectForType<T>(ctx));
+        },
+        kind, std::forward<Args>(args)...);
   }
 
   /// Get an uniqued instance of a type T. This overload is used for derived
@@ -190,27 +126,15 @@ public:
   static typename std::enable_if<
       std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
   get(MLIRContext *ctx, unsigned kind) {
-    auto constructorFn = [=](TypeStorageAllocator &allocator) {
-      return new (allocator.allocate<DefaultTypeStorage>())
-          DefaultTypeStorage(lookupDialectForType<T>(ctx), kind);
-    };
-    return T(getImpl(ctx, kind, constructorFn));
+    // Lookup an instance of this simple storage type.
+    return ctx->getTypeUniquer().getSimple<TypeStorage>(
+        [&](TypeStorage *storage) {
+          storage->initializeDialect(lookupDialectForType<T>(ctx));
+        },
+        kind);
   }
 
 private:
-  /// Implementation for getting/creating an instance of a derived type with
-  /// complex storage.
-  static TypeStorage *
-  getImpl(MLIRContext *ctx, unsigned kind, unsigned hashValue,
-          llvm::function_ref<bool(const TypeStorage *)> isEqual,
-          std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
-
-  /// Implementation for getting/creating an instance of a derived type with
-  /// default storage.
-  static TypeStorage *
-  getImpl(MLIRContext *ctx, unsigned kind,
-          std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
-
   /// Get the dialect that the type 'T' was registered with.
   template <typename T>
   static const Dialect &lookupDialectForType(MLIRContext *ctx) {
@@ -220,79 +144,6 @@ private:
   /// Get the dialect that registered the type with the provided typeid.
   static const Dialect &lookupDialectForType(MLIRContext *ctx,
                                              const TypeID *const typeID);
-
-  //===--------------------------------------------------------------------===//
-  // Util
-  //===--------------------------------------------------------------------===//
-
-  /// Utilities for detecting if specific traits hold for a given type 'T'.
-  template <typename...> using void_t = void;
-  template <class, template <class...> class Op, class... Args>
-  struct detector {
-    using value_t = std::false_type;
-  };
-  template <template <class...> class Op, class... Args>
-  struct detector<void_t<Op<Args...>>, Op, Args...> {
-    using value_t = std::true_type;
-  };
-  template <template <class...> class Op, class... Args>
-  using is_detected = typename detector<void, Op, Args...>::value_t;
-
-  //===--------------------------------------------------------------------===//
-  // Key Construction
-  //===--------------------------------------------------------------------===//
-
-  /// Trait to check if ImplTy provides a 'getKey' method with types 'Args'.
-  template <typename ImplTy, typename... Args>
-  using has_impltype_getkey_t =
-      decltype(ImplTy::getKey(std::declval<Args>()...));
-
-  /// Used to construct an instance of 'ImplType::KeyTy' if there is an
-  /// 'ImplTy::getKey' function for the provided arguments.
-  template <typename ImplTy, typename... Args>
-  static typename std::enable_if<
-      is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
-      typename ImplTy::KeyTy>::type
-  getKey(Args &&... args) {
-    return ImplTy::getKey(args...);
-  }
-  /// If there is no 'ImplTy::getKey' method, then we try to directly construct
-  /// the 'ImplTy::KeyTy' with the provided arguments.
-  template <typename ImplTy, typename... Args>
-  static typename std::enable_if<
-      !is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
-      typename ImplTy::KeyTy>::type
-  getKey(Args &&... args) {
-    return typename ImplTy::KeyTy(args...);
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Key and Kind Hashing
-  //===--------------------------------------------------------------------===//
-
-  /// Trait to check if ImplType provides a 'hashKey' method for 'T'.
-  template <typename ImplType, typename T>
-  using has_impltype_hash_t = decltype(ImplType::hashKey(std::declval<T>()));
-
-  /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
-  /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
-  template <typename ImplTy, typename DerivedKey>
-  static typename std::enable_if<
-      is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
-      ::llvm::hash_code>::type
-  getHash(unsigned kind, const DerivedKey &derivedKey) {
-    return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
-  }
-  /// If there is no 'ImplTy::hashKey' default to using the
-  /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
-  template <typename ImplTy, typename DerivedKey>
-  static typename std::enable_if<
-      !is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
-      ::llvm::hash_code>::type
-  getHash(unsigned kind, const DerivedKey &derivedKey) {
-    return llvm::hash_combine(
-        kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
-  }
 };
 } // namespace detail
 
index 951259b..e18b131 100644 (file)
@@ -80,8 +80,9 @@ struct OpaqueTypeStorage;
 ///      instance of the type within its kind.
 ///      * The key type must be constructible from the values passed into the
 ///        detail::TypeUniquer::get call after the type kind.
-///      * The key type must have a llvm::DenseMapInfo specialization for
-///        hashing.
+///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
+///        storage class must define a hashing method:
+///         'static unsigned hashKey(const KeyTy &)'
 ///
 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
 ///      compare the storage instance against an instance of the key type.
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
new file mode 100644 (file)
index 0000000..8a0c590
--- /dev/null
@@ -0,0 +1,252 @@
+//===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
+#define MLIR_SUPPORT_STORAGEUNIQUER_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+
+namespace mlir {
+namespace detail {
+struct StorageUniquerImpl;
+} // namespace detail
+
+/// A utility class to get, or create instances of storage classes. These
+/// storage classes must respect the following constraints:
+///    - Derive from StorageUniquer::BaseStorage.
+///    - Provide an unsigned 'kind' value to be used as part of the unique'ing
+///      process.
+///
+/// For non-parametric storage classes, i.e. those that are solely uniqued by
+/// their kind, nothing else is needed. Instances of these classes can be
+/// queried with 'getSimple'.
+///
+/// Otherwise, the parametric storage classes may be queried with 'getComplex',
+/// and must respect the following:
+///    - Define a type alias, KeyTy, to a type that uniquely identifies the
+///      instance of the storage class within its kind.
+///      * The key type must be constructible from the values passed into the
+///        getComplex call after the kind.
+///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
+///        storage class must define a hashing method:
+///         'static unsigned hashKey(const KeyTy &)'
+///
+///    - Provide a method, 'bool operator==(const KeyTy &) const', to
+///      compare the storage instance against an instance of the key type.
+///
+///    - Provide a construction method:
+///        'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
+///      that builds a unique instance of the derived storage. The arguments to
+///      this function are an allocator to store any uniqued data and the key
+///      type for this storage.
+class StorageUniquer {
+public:
+  StorageUniquer();
+  ~StorageUniquer();
+
+  /// This class acts as the base storage that all storage classes must derived
+  /// from.
+  class BaseStorage {
+  public:
+    /// Get the kind classification of this storage.
+    unsigned getKind() const { return kind; }
+
+  protected:
+    BaseStorage() : kind(0) {}
+
+  private:
+    /// Allow access to the kind field.
+    friend detail::StorageUniquerImpl;
+
+    /// Classification of the subclass, used for type checking.
+    unsigned kind;
+  };
+
+  /// This is a utility allocator used to allocate memory for instances of
+  /// derived types.
+  class StorageAllocator {
+  public:
+    /// Copy the specified array of elements into memory managed by our bump
+    /// pointer allocator.  This assumes the elements are all PODs.
+    template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
+      if (elements.empty())
+        return llvm::None;
+      auto result = allocator.Allocate<T>(elements.size());
+      std::uninitialized_copy(elements.begin(), elements.end(), result);
+      return ArrayRef<T>(result, elements.size());
+    }
+
+    /// Copy the provided string into memory managed by our bump pointer
+    /// allocator.
+    StringRef copyInto(StringRef str) {
+      auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
+      return StringRef(result.data(), str.size());
+    }
+
+    /// Allocate an instance of the provided type.
+    template <typename T> T *allocate() { return allocator.Allocate<T>(); }
+
+    /// Allocate 'size' bytes of 'alignment' aligned memory.
+    void *allocate(size_t size, size_t alignment) {
+      return allocator.Allocate(size, alignment);
+    }
+
+  private:
+    /// The raw allocator for type storage objects.
+    llvm::BumpPtrAllocator allocator;
+  };
+
+  /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
+  /// that can be used to initialize a newly inserted storage instance. This
+  /// overload is used for derived types that have complex storage or uniquing
+  /// constraints.
+  template <typename Storage, typename... Args>
+  Storage *getComplex(std::function<void(Storage *)> initFn, unsigned kind,
+                      Args &&... args) {
+    // Construct a value of the derived key type.
+    auto derivedKey = getKey<Storage>(args...);
+
+    // Create a hash of the kind and the derived key.
+    unsigned hashValue = getHash<Storage>(kind, derivedKey);
+
+    // Generate an equality function for the derived storage.
+    std::function<bool(const BaseStorage *)> isEqual =
+        [&derivedKey](const BaseStorage *existing) {
+          return static_cast<const Storage &>(*existing) == derivedKey;
+        };
+
+    // Generate a constructor function for the derived storage.
+    std::function<BaseStorage *(StorageAllocator &)> ctorFn =
+        [&](StorageAllocator &allocator) {
+          auto *storage = Storage::construct(allocator, derivedKey);
+          if (initFn)
+            initFn(storage);
+          return storage;
+        };
+
+    // Get an instance for the derived storage.
+    return static_cast<Storage *>(getImpl(kind, hashValue, isEqual, ctorFn));
+  }
+
+  /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
+  /// that can be used to initialize a newly inserted storage instance. This
+  /// overload is used for derived types that use no additional storage or
+  /// uniquing outside of the kind.
+  template <typename Storage>
+  Storage *getSimple(std::function<void(Storage *)> initFn, unsigned kind) {
+    auto ctorFn = [&](StorageAllocator &allocator) {
+      auto *storage = new (allocator.allocate<Storage>()) Storage();
+      if (initFn)
+        initFn(storage);
+      return storage;
+    };
+    return static_cast<Storage *>(getImpl(kind, ctorFn));
+  }
+
+private:
+  /// Implementation for getting/creating an instance of a derived type with
+  /// complex storage.
+  BaseStorage *getImpl(unsigned kind, unsigned hashValue,
+                       llvm::function_ref<bool(const BaseStorage *)> isEqual,
+                       std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+
+  /// Implementation for getting/creating an instance of a derived type with
+  /// default storage.
+  BaseStorage *getImpl(unsigned kind,
+                       std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+
+  /// The internal implementation class.
+  std::unique_ptr<detail::StorageUniquerImpl> impl;
+
+  //===--------------------------------------------------------------------===//
+  // Util
+  //===--------------------------------------------------------------------===//
+
+  /// Utilities for detecting if specific traits hold for a given type 'T'.
+  template <typename...> using void_t = void;
+  template <class, template <class...> class Op, class... Args>
+  struct detector {
+    using value_t = std::false_type;
+  };
+  template <template <class...> class Op, class... Args>
+  struct detector<void_t<Op<Args...>>, Op, Args...> {
+    using value_t = std::true_type;
+  };
+  template <template <class...> class Op, class... Args>
+  using is_detected = typename detector<void, Op, Args...>::value_t;
+
+  //===--------------------------------------------------------------------===//
+  // Key Construction
+  //===--------------------------------------------------------------------===//
+
+  /// Trait to check if ImplTy provides a 'getKey' method with types 'Args'.
+  template <typename ImplTy, typename... Args>
+  using has_impltype_getkey_t =
+      decltype(ImplTy::getKey(std::declval<Args>()...));
+
+  /// Used to construct an instance of 'ImplTy::KeyTy' if there is an
+  /// 'ImplTy::getKey' function for the provided arguments.
+  template <typename ImplTy, typename... Args>
+  static typename std::enable_if<
+      is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
+      typename ImplTy::KeyTy>::type
+  getKey(Args &&... args) {
+    return ImplTy::getKey(args...);
+  }
+  /// If there is no 'ImplTy::getKey' method, then we try to directly construct
+  /// the 'ImplTy::KeyTy' with the provided arguments.
+  template <typename ImplTy, typename... Args>
+  static typename std::enable_if<
+      !is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
+      typename ImplTy::KeyTy>::type
+  getKey(Args &&... args) {
+    return typename ImplTy::KeyTy(args...);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Key and Kind Hashing
+  //===--------------------------------------------------------------------===//
+
+  /// Trait to check if ImplTy provides a 'hashKey' method for 'T'.
+  template <typename ImplTy, typename T>
+  using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
+
+  /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
+  /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
+  template <typename ImplTy, typename DerivedKey>
+  static typename std::enable_if<
+      is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
+      ::llvm::hash_code>::type
+  getHash(unsigned kind, const DerivedKey &derivedKey) {
+    return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
+  }
+  /// If there is no 'ImplTy::hashKey' default to using the
+  /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
+  template <typename ImplTy, typename DerivedKey>
+  static typename std::enable_if<
+      !is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
+      ::llvm::hash_code>::type
+  getHash(unsigned kind, const DerivedKey &derivedKey) {
+    return llvm::hash_combine(
+        kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
+  }
+};
+} // end namespace mlir
+
+#endif
index c0df0e0..e630914 100644 (file)
@@ -363,112 +363,6 @@ struct FusedLocKeyInfo : DenseMapInfo<FusedLocationStorage *> {
     return lhs == std::make_pair(rhs->getLocations(), rhs->metadata);
   }
 };
-
-/// This is the implementation of the TypeUniquer class.
-struct TypeUniquerImpl {
-  /// A lookup key for derived instances of TypeStorage objects.
-  struct TypeLookupKey {
-    /// The known derived kind for the storage.
-    unsigned kind;
-
-    /// The known hash value of the key.
-    unsigned hashValue;
-
-    /// An equality function for comparing with an existing storage instance.
-    llvm::function_ref<bool(const TypeStorage *)> isEqual;
-  };
-
-  /// A utility wrapper object representing a hashed storage object. This class
-  /// contains a storage object and an existing computed hash value.
-  struct HashedStorageType {
-    unsigned hashValue;
-    TypeStorage *storage;
-  };
-
-  /// Get or create an instance of a complex derived type.
-  TypeStorage *getOrCreate(
-      unsigned kind, unsigned hashValue,
-      llvm::function_ref<bool(const TypeStorage *)> isEqual,
-      std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
-    TypeLookupKey lookupKey{kind, hashValue, isEqual};
-
-    { // Check for an existing instance in read-only mode.
-      llvm::sys::SmartScopedReader<true> typeLock(typeMutex);
-      auto it = storageTypes.find_as(lookupKey);
-      if (it != storageTypes.end())
-        return it->storage;
-    }
-
-    // Aquire a writer-lock so that we can safely create the new type instance.
-    llvm::sys::SmartScopedWriter<true> typeLock(typeMutex);
-
-    // Check for an existing instance again here, because another writer thread
-    // may have already created one.
-    auto existing = storageTypes.insert_as({}, lookupKey);
-    if (!existing.second)
-      return existing.first->storage;
-
-    // Otherwise, construct and initialize the derived storage for this type
-    // instance.
-    TypeStorage *storage = constructorFn(allocator);
-    *existing.first = HashedStorageType{hashValue, storage};
-    return storage;
-  }
-
-  /// Get or create an instance of a simple derived type.
-  TypeStorage *getOrCreate(
-      unsigned kind,
-      std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
-    return safeGetOrCreate(simpleTypes, kind, typeMutex,
-                           [&] { return constructorFn(allocator); });
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Instance Storage
-  //===--------------------------------------------------------------------===//
-
-  /// Storage info for derived TypeStorage objects.
-  struct StorageKeyInfo : DenseMapInfo<HashedStorageType> {
-    static HashedStorageType getEmptyKey() {
-      return HashedStorageType{0, DenseMapInfo<TypeStorage *>::getEmptyKey()};
-    }
-    static HashedStorageType getTombstoneKey() {
-      return HashedStorageType{0,
-                               DenseMapInfo<TypeStorage *>::getTombstoneKey()};
-    }
-
-    static unsigned getHashValue(const HashedStorageType &key) {
-      return key.hashValue;
-    }
-    static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; }
-
-    static bool isEqual(const HashedStorageType &lhs,
-                        const HashedStorageType &rhs) {
-      return lhs.storage == rhs.storage;
-    }
-    static bool isEqual(const TypeLookupKey &lhs,
-                        const HashedStorageType &rhs) {
-      if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
-        return false;
-      // If the lookup kind matches the kind of the storage, then invoke the
-      // equality function on the lookup key.
-      return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
-    }
-  };
-
-  // Unique types with specific hashing or storage constraints.
-  using StorageTypeSet = llvm::DenseSet<HashedStorageType, StorageKeyInfo>;
-  StorageTypeSet storageTypes;
-
-  // Unique types with just the kind.
-  DenseMap<unsigned, TypeStorage *> simpleTypes;
-
-  // Allocator to use when constructing derived type instances.
-  TypeStorageAllocator allocator;
-
-  // A mutex to keep type uniquing thread-safe.
-  llvm::sys::SmartRWMutex<true> typeMutex;
-};
 } // end anonymous namespace.
 
 namespace mlir {
@@ -570,7 +464,7 @@ public:
   //===--------------------------------------------------------------------===//
   // Type uniquing
   //===--------------------------------------------------------------------===//
-  TypeUniquerImpl typeUniquer;
+  StorageUniquer typeUniquer;
 
   //===--------------------------------------------------------------------===//
   // Attribute uniquing
@@ -953,23 +847,9 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
 // Type uniquing
 //===----------------------------------------------------------------------===//
 
-/// Implementation for getting/creating an instance of a derived type with
-/// complex storage.
-TypeStorage *TypeUniquer::getImpl(
-    MLIRContext *ctx, unsigned kind, unsigned hashValue,
-    llvm::function_ref<bool(const TypeStorage *)> isEqual,
-    std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
-  return ctx->getImpl().typeUniquer.getOrCreate(kind, hashValue, isEqual,
-                                                constructorFn);
-}
-
-/// Implementation for getting/creating an instance of a derived type with
-/// default storage.
-TypeStorage *TypeUniquer::getImpl(
-    MLIRContext *ctx, unsigned kind,
-    std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
-  return ctx->getImpl().typeUniquer.getOrCreate(kind, constructorFn);
-}
+/// Returns the storage unqiuer used for constructing type storage instances.
+/// This should not be used directly.
+StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
 /// Get the dialect that registered the type with the provided typeid.
 const Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx,
index 2d69b4d..97da45b 100644 (file)
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRSupport
   FileUtilities.cpp
+  StorageUniquer.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
new file mode 100644 (file)
index 0000000..14d8f39
--- /dev/null
@@ -0,0 +1,181 @@
+//===- StorageUniquer.cpp - Common Storage Class Uniquer --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Support/StorageUniquer.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RWMutex.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace mlir {
+namespace detail {
+/// This is the implementation of the StorageUniquer class.
+struct StorageUniquerImpl {
+  using BaseStorage = StorageUniquer::BaseStorage;
+  using StorageAllocator = StorageUniquer::StorageAllocator;
+
+  /// A lookup key for derived instances of storage objects.
+  struct LookupKey {
+    /// The known derived kind for the storage.
+    unsigned kind;
+
+    /// The known hash value of the key.
+    unsigned hashValue;
+
+    /// An equality function for comparing with an existing storage instance.
+    llvm::function_ref<bool(const BaseStorage *)> isEqual;
+  };
+
+  /// A utility wrapper object representing a hashed storage object. This class
+  /// contains a storage object and an existing computed hash value.
+  struct HashedStorage {
+    unsigned hashValue;
+    BaseStorage *storage;
+  };
+
+  /// Get or create an instance of a complex derived type.
+  BaseStorage *
+  getOrCreate(unsigned kind, unsigned hashValue,
+              llvm::function_ref<bool(const BaseStorage *)> isEqual,
+              llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    LookupKey lookupKey{kind, hashValue, isEqual};
+
+    // Check for an existing instance in read-only mode.
+    {
+      llvm::sys::SmartScopedReader<true> typeLock(mutex);
+      auto it = storageTypes.find_as(lookupKey);
+      if (it != storageTypes.end())
+        return it->storage;
+    }
+
+    // Acquire a writer-lock so that we can safely create the new type instance.
+    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+
+    // Check for an existing instance again here, because another writer thread
+    // may have already created one.
+    auto existing = storageTypes.insert_as({}, lookupKey);
+    if (!existing.second)
+      return existing.first->storage;
+
+    // Otherwise, construct and initialize the derived storage for this type
+    // instance.
+    BaseStorage *storage = initializeStorage(kind, ctorFn);
+    *existing.first = HashedStorage{hashValue, storage};
+    return storage;
+  }
+
+  /// Get or create an instance of a simple derived type.
+  BaseStorage *
+  getOrCreate(unsigned kind,
+              llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    // Check for an existing instance in read-only mode.
+    {
+      llvm::sys::SmartScopedReader<true> typeLock(mutex);
+      auto it = simpleTypes.find(kind);
+      if (it != simpleTypes.end())
+        return it->second;
+    }
+
+    // Acquire a writer-lock so that we can safely create the new type instance.
+    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+
+    // Check for an existing instance again here, because another writer thread
+    // may have already created one.
+    auto &result = simpleTypes[kind];
+    if (result)
+      return result;
+
+    // Otherwise, create and return a new storage instance.
+    return result = initializeStorage(kind, ctorFn);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Instance Storage
+  //===--------------------------------------------------------------------===//
+
+  /// Utility to create and initialize a storage instance.
+  BaseStorage *initializeStorage(
+      unsigned kind,
+      llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    BaseStorage *storage = ctorFn(allocator);
+    storage->kind = kind;
+    return storage;
+  }
+
+  /// Storage info for derived TypeStorage objects.
+  struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
+    static HashedStorage getEmptyKey() {
+      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+    }
+    static HashedStorage getTombstoneKey() {
+      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+    }
+
+    static unsigned getHashValue(const HashedStorage &key) {
+      return key.hashValue;
+    }
+    static unsigned getHashValue(LookupKey key) { return key.hashValue; }
+
+    static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
+      return lhs.storage == rhs.storage;
+    }
+    static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
+      if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
+        return false;
+      // If the lookup kind matches the kind of the storage, then invoke the
+      // equality function on the lookup key.
+      return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
+    }
+  };
+
+  // Unique types with specific hashing or storage constraints.
+  using StorageTypeSet = llvm::DenseSet<HashedStorage, StorageKeyInfo>;
+  StorageTypeSet storageTypes;
+
+  // Unique types with just the kind.
+  DenseMap<unsigned, BaseStorage *> simpleTypes;
+
+  // Allocator to use when constructing derived type instances.
+  StorageUniquer::StorageAllocator allocator;
+
+  // A mutex to keep type uniquing thread-safe.
+  llvm::sys::SmartRWMutex<true> mutex;
+};
+} // end namespace detail
+} // namespace mlir
+
+StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
+StorageUniquer::~StorageUniquer() {}
+
+/// Implementation for getting/creating an instance of a derived type with
+/// complex storage.
+auto StorageUniquer::getImpl(
+    unsigned kind, unsigned hashValue,
+    llvm::function_ref<bool(const BaseStorage *)> isEqual,
+    std::function<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
+  return impl->getOrCreate(kind, hashValue, isEqual, ctorFn);
+}
+
+/// Implementation for getting/creating an instance of a derived type with
+/// default storage.
+auto StorageUniquer::getImpl(
+    unsigned kind, std::function<BaseStorage *(StorageAllocator &)> ctorFn)
+    -> BaseStorage * {
+  return impl->getOrCreate(kind, ctorFn);
+}