namespace mlir {
class AbstractOperation;
-class MLIRContextImpl;
-class Location;
class Dialect;
+class Location;
+class MLIRContextImpl;
+class StorageUniquer;
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
// MLIRContextImpl type.
MLIRContextImpl &getImpl() { return *impl.get(); }
+ /// Returns the storage uniquer used for constructing type storage instances.
+ /// This should not be used directly.
+ StorageUniquer &getTypeUniquer();
+
private:
const std::unique_ptr<MLIRContextImpl> impl;
#define MLIR_IR_TYPE_SUPPORT_H
#include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
+#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
} // end namespace detail
/// Base storage class appearing in a Type.
-class TypeStorage {
+class TypeStorage : public StorageUniquer::BaseStorage {
friend detail::TypeUniquer;
+ friend StorageUniquer;
protected:
/// This constructor is used by derived classes as part of the TypeUniquer.
/// When using this constructor, the initializeTypeInfo function must be
/// invoked afterwards for the storage to be valid.
TypeStorage(unsigned subclassData = 0)
- : dialect(nullptr), kind(0), subclassData(subclassData) {}
+ : dialect(nullptr), subclassData(subclassData) {}
public:
/// Get the dialect that this type is registered to.
assert(dialect && "Malformed type storage object.");
return *dialect;
}
-
- /// Get the kind classification of this type.
- unsigned getKind() const { return kind; }
-
/// Get the subclass data.
unsigned getSubclassData() const { return subclassData; }
void setSubclassData(unsigned val) { subclassData = val; }
private:
- // Constructor used for simple type storage that have no subclass data. This
- // constructor should not be used by derived storage classes.
- TypeStorage(const Dialect &dialect, unsigned kind)
- : dialect(&dialect), kind(kind), subclassData(0) {}
-
- // Initialize an existing type storage with a kind and a context. This is used
- // by the TypeUniquer when initializing a newly constructed derived type
- // storage object.
- void initializeTypeInfo(const Dialect &newDialect, unsigned newKind) {
- dialect = &newDialect;
- kind = newKind;
- }
+ // Set the dialect for this storage instance. This is used by the TypeUniquer
+ // when initializing a newly constructed type storage object.
+ void initializeDialect(const Dialect &newDialect) { dialect = &newDialect; }
/// The registered information for the current type.
const Dialect *dialect;
- /// Classification of the subclass, used for type checking.
- unsigned kind;
-
/// Space for subclasses to store data.
unsigned subclassData;
};
// This is a utility allocator used to allocate memory for instances of derived
// Types.
-class TypeStorageAllocator {
-public:
- /// Copy the specified array of elements into memory managed by our bump
- /// pointer allocator. This assumes the elements are all PODs.
- template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
- if (elements.empty())
- return llvm::None;
- auto result = allocator.Allocate<T>(elements.size());
- std::uninitialized_copy(elements.begin(), elements.end(), result);
- return ArrayRef<T>(result, elements.size());
- }
-
- /// Copy the provided string into memory managed by our bump pointer
- /// allocator.
- StringRef copyInto(StringRef str) {
- auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
- return StringRef(result.data(), str.size());
- }
-
- // Allocate an instance of the provided type.
- template <typename T> T *allocate() { return allocator.Allocate<T>(); }
-
- /// Allocate 'size' bytes of 'alignment' aligned memory.
- void *allocate(size_t size, size_t alignment) {
- return allocator.Allocate(size, alignment);
- }
-
-private:
- /// The raw allocator for type storage objects.
- llvm::BumpPtrAllocator allocator;
-};
+using TypeStorageAllocator = StorageUniquer::StorageAllocator;
//===----------------------------------------------------------------------===//
// TypeUniquer
static typename std::enable_if<
!std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+ // Lookup an instance of this complex storage type.
using ImplType = typename T::ImplType;
-
- // Construct a value of the derived key type.
- auto derivedKey = getKey<ImplType>(args...);
-
- // Create a hash of the kind and the derived key.
- unsigned hashValue = getHash<ImplType>(kind, derivedKey);
-
- // Generate an equality function for the derived storage.
- std::function<bool(const TypeStorage *)> isEqual =
- [&derivedKey](const TypeStorage *existing) {
- return static_cast<const ImplType &>(*existing) == derivedKey;
- };
-
- // Generate a constructor function for the derived storage.
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn =
- [&](TypeStorageAllocator &allocator) {
- TypeStorage *storage = ImplType::construct(allocator, derivedKey);
- storage->initializeTypeInfo(lookupDialectForType<T>(ctx), kind);
- return storage;
- };
-
- // Get an instance for the derived storage.
- return T(getImpl(ctx, kind, hashValue, isEqual, constructorFn));
+ return ctx->getTypeUniquer().getComplex<ImplType>(
+ [&](ImplType *storage) {
+ storage->initializeDialect(lookupDialectForType<T>(ctx));
+ },
+ kind, std::forward<Args>(args)...);
}
/// Get an uniqued instance of a type T. This overload is used for derived
static typename std::enable_if<
std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
get(MLIRContext *ctx, unsigned kind) {
- auto constructorFn = [=](TypeStorageAllocator &allocator) {
- return new (allocator.allocate<DefaultTypeStorage>())
- DefaultTypeStorage(lookupDialectForType<T>(ctx), kind);
- };
- return T(getImpl(ctx, kind, constructorFn));
+ // Lookup an instance of this simple storage type.
+ return ctx->getTypeUniquer().getSimple<TypeStorage>(
+ [&](TypeStorage *storage) {
+ storage->initializeDialect(lookupDialectForType<T>(ctx));
+ },
+ kind);
}
private:
- /// Implementation for getting/creating an instance of a derived type with
- /// complex storage.
- static TypeStorage *
- getImpl(MLIRContext *ctx, unsigned kind, unsigned hashValue,
- llvm::function_ref<bool(const TypeStorage *)> isEqual,
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
-
- /// Implementation for getting/creating an instance of a derived type with
- /// default storage.
- static TypeStorage *
- getImpl(MLIRContext *ctx, unsigned kind,
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
-
/// Get the dialect that the type 'T' was registered with.
template <typename T>
static const Dialect &lookupDialectForType(MLIRContext *ctx) {
/// Get the dialect that registered the type with the provided typeid.
static const Dialect &lookupDialectForType(MLIRContext *ctx,
const TypeID *const typeID);
-
- //===--------------------------------------------------------------------===//
- // Util
- //===--------------------------------------------------------------------===//
-
- /// Utilities for detecting if specific traits hold for a given type 'T'.
- template <typename...> using void_t = void;
- template <class, template <class...> class Op, class... Args>
- struct detector {
- using value_t = std::false_type;
- };
- template <template <class...> class Op, class... Args>
- struct detector<void_t<Op<Args...>>, Op, Args...> {
- using value_t = std::true_type;
- };
- template <template <class...> class Op, class... Args>
- using is_detected = typename detector<void, Op, Args...>::value_t;
-
- //===--------------------------------------------------------------------===//
- // Key Construction
- //===--------------------------------------------------------------------===//
-
- /// Trait to check if ImplTy provides a 'getKey' method with types 'Args'.
- template <typename ImplTy, typename... Args>
- using has_impltype_getkey_t =
- decltype(ImplTy::getKey(std::declval<Args>()...));
-
- /// Used to construct an instance of 'ImplType::KeyTy' if there is an
- /// 'ImplTy::getKey' function for the provided arguments.
- template <typename ImplTy, typename... Args>
- static typename std::enable_if<
- is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
- typename ImplTy::KeyTy>::type
- getKey(Args &&... args) {
- return ImplTy::getKey(args...);
- }
- /// If there is no 'ImplTy::getKey' method, then we try to directly construct
- /// the 'ImplTy::KeyTy' with the provided arguments.
- template <typename ImplTy, typename... Args>
- static typename std::enable_if<
- !is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
- typename ImplTy::KeyTy>::type
- getKey(Args &&... args) {
- return typename ImplTy::KeyTy(args...);
- }
-
- //===--------------------------------------------------------------------===//
- // Key and Kind Hashing
- //===--------------------------------------------------------------------===//
-
- /// Trait to check if ImplType provides a 'hashKey' method for 'T'.
- template <typename ImplType, typename T>
- using has_impltype_hash_t = decltype(ImplType::hashKey(std::declval<T>()));
-
- /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
- /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
- template <typename ImplTy, typename DerivedKey>
- static typename std::enable_if<
- is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
- ::llvm::hash_code>::type
- getHash(unsigned kind, const DerivedKey &derivedKey) {
- return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
- }
- /// If there is no 'ImplTy::hashKey' default to using the
- /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
- template <typename ImplTy, typename DerivedKey>
- static typename std::enable_if<
- !is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
- ::llvm::hash_code>::type
- getHash(unsigned kind, const DerivedKey &derivedKey) {
- return llvm::hash_combine(
- kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
- }
};
} // namespace detail
/// instance of the type within its kind.
/// * The key type must be constructible from the values passed into the
/// detail::TypeUniquer::get call after the type kind.
-/// * The key type must have a llvm::DenseMapInfo specialization for
-/// hashing.
+/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
+/// storage class must define a hashing method:
+/// 'static unsigned hashKey(const KeyTy &)'
///
/// - Provide a method, 'bool operator==(const KeyTy &) const', to
/// compare the storage instance against an instance of the key type.
--- /dev/null
+//===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
+#define MLIR_SUPPORT_STORAGEUNIQUER_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+
+namespace mlir {
+namespace detail {
+struct StorageUniquerImpl;
+} // namespace detail
+
+/// A utility class to get, or create instances of storage classes. These
+/// storage classes must respect the following constraints:
+/// - Derive from StorageUniquer::BaseStorage.
+/// - Provide an unsigned 'kind' value to be used as part of the unique'ing
+/// process.
+///
+/// For non-parametric storage classes, i.e. those that are solely uniqued by
+/// their kind, nothing else is needed. Instances of these classes can be
+/// queried with 'getSimple'.
+///
+/// Otherwise, the parametric storage classes may be queried with 'getComplex',
+/// and must respect the following:
+/// - Define a type alias, KeyTy, to a type that uniquely identifies the
+/// instance of the storage class within its kind.
+/// * The key type must be constructible from the values passed into the
+/// getComplex call after the kind.
+/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
+/// storage class must define a hashing method:
+/// 'static unsigned hashKey(const KeyTy &)'
+///
+/// - Provide a method, 'bool operator==(const KeyTy &) const', to
+/// compare the storage instance against an instance of the key type.
+///
+/// - Provide a construction method:
+/// 'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
+/// that builds a unique instance of the derived storage. The arguments to
+/// this function are an allocator to store any uniqued data and the key
+/// type for this storage.
+class StorageUniquer {
+public:
+ StorageUniquer();
+ ~StorageUniquer();
+
+ /// This class acts as the base storage that all storage classes must derived
+ /// from.
+ class BaseStorage {
+ public:
+ /// Get the kind classification of this storage.
+ unsigned getKind() const { return kind; }
+
+ protected:
+ BaseStorage() : kind(0) {}
+
+ private:
+ /// Allow access to the kind field.
+ friend detail::StorageUniquerImpl;
+
+ /// Classification of the subclass, used for type checking.
+ unsigned kind;
+ };
+
+ /// This is a utility allocator used to allocate memory for instances of
+ /// derived types.
+ class StorageAllocator {
+ public:
+ /// Copy the specified array of elements into memory managed by our bump
+ /// pointer allocator. This assumes the elements are all PODs.
+ template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
+ if (elements.empty())
+ return llvm::None;
+ auto result = allocator.Allocate<T>(elements.size());
+ std::uninitialized_copy(elements.begin(), elements.end(), result);
+ return ArrayRef<T>(result, elements.size());
+ }
+
+ /// Copy the provided string into memory managed by our bump pointer
+ /// allocator.
+ StringRef copyInto(StringRef str) {
+ auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
+ return StringRef(result.data(), str.size());
+ }
+
+ /// Allocate an instance of the provided type.
+ template <typename T> T *allocate() { return allocator.Allocate<T>(); }
+
+ /// Allocate 'size' bytes of 'alignment' aligned memory.
+ void *allocate(size_t size, size_t alignment) {
+ return allocator.Allocate(size, alignment);
+ }
+
+ private:
+ /// The raw allocator for type storage objects.
+ llvm::BumpPtrAllocator allocator;
+ };
+
+ /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
+ /// that can be used to initialize a newly inserted storage instance. This
+ /// overload is used for derived types that have complex storage or uniquing
+ /// constraints.
+ template <typename Storage, typename... Args>
+ Storage *getComplex(std::function<void(Storage *)> initFn, unsigned kind,
+ Args &&... args) {
+ // Construct a value of the derived key type.
+ auto derivedKey = getKey<Storage>(args...);
+
+ // Create a hash of the kind and the derived key.
+ unsigned hashValue = getHash<Storage>(kind, derivedKey);
+
+ // Generate an equality function for the derived storage.
+ std::function<bool(const BaseStorage *)> isEqual =
+ [&derivedKey](const BaseStorage *existing) {
+ return static_cast<const Storage &>(*existing) == derivedKey;
+ };
+
+ // Generate a constructor function for the derived storage.
+ std::function<BaseStorage *(StorageAllocator &)> ctorFn =
+ [&](StorageAllocator &allocator) {
+ auto *storage = Storage::construct(allocator, derivedKey);
+ if (initFn)
+ initFn(storage);
+ return storage;
+ };
+
+ // Get an instance for the derived storage.
+ return static_cast<Storage *>(getImpl(kind, hashValue, isEqual, ctorFn));
+ }
+
+ /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
+ /// that can be used to initialize a newly inserted storage instance. This
+ /// overload is used for derived types that use no additional storage or
+ /// uniquing outside of the kind.
+ template <typename Storage>
+ Storage *getSimple(std::function<void(Storage *)> initFn, unsigned kind) {
+ auto ctorFn = [&](StorageAllocator &allocator) {
+ auto *storage = new (allocator.allocate<Storage>()) Storage();
+ if (initFn)
+ initFn(storage);
+ return storage;
+ };
+ return static_cast<Storage *>(getImpl(kind, ctorFn));
+ }
+
+private:
+ /// Implementation for getting/creating an instance of a derived type with
+ /// complex storage.
+ BaseStorage *getImpl(unsigned kind, unsigned hashValue,
+ llvm::function_ref<bool(const BaseStorage *)> isEqual,
+ std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+
+ /// Implementation for getting/creating an instance of a derived type with
+ /// default storage.
+ BaseStorage *getImpl(unsigned kind,
+ std::function<BaseStorage *(StorageAllocator &)> ctorFn);
+
+ /// The internal implementation class.
+ std::unique_ptr<detail::StorageUniquerImpl> impl;
+
+ //===--------------------------------------------------------------------===//
+ // Util
+ //===--------------------------------------------------------------------===//
+
+ /// Utilities for detecting if specific traits hold for a given type 'T'.
+ template <typename...> using void_t = void;
+ template <class, template <class...> class Op, class... Args>
+ struct detector {
+ using value_t = std::false_type;
+ };
+ template <template <class...> class Op, class... Args>
+ struct detector<void_t<Op<Args...>>, Op, Args...> {
+ using value_t = std::true_type;
+ };
+ template <template <class...> class Op, class... Args>
+ using is_detected = typename detector<void, Op, Args...>::value_t;
+
+ //===--------------------------------------------------------------------===//
+ // Key Construction
+ //===--------------------------------------------------------------------===//
+
+ /// Trait to check if ImplTy provides a 'getKey' method with types 'Args'.
+ template <typename ImplTy, typename... Args>
+ using has_impltype_getkey_t =
+ decltype(ImplTy::getKey(std::declval<Args>()...));
+
+ /// Used to construct an instance of 'ImplTy::KeyTy' if there is an
+ /// 'ImplTy::getKey' function for the provided arguments.
+ template <typename ImplTy, typename... Args>
+ static typename std::enable_if<
+ is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
+ typename ImplTy::KeyTy>::type
+ getKey(Args &&... args) {
+ return ImplTy::getKey(args...);
+ }
+ /// If there is no 'ImplTy::getKey' method, then we try to directly construct
+ /// the 'ImplTy::KeyTy' with the provided arguments.
+ template <typename ImplTy, typename... Args>
+ static typename std::enable_if<
+ !is_detected<has_impltype_getkey_t, ImplTy, Args...>::value,
+ typename ImplTy::KeyTy>::type
+ getKey(Args &&... args) {
+ return typename ImplTy::KeyTy(args...);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Key and Kind Hashing
+ //===--------------------------------------------------------------------===//
+
+ /// Trait to check if ImplTy provides a 'hashKey' method for 'T'.
+ template <typename ImplTy, typename T>
+ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
+
+ /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
+ /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
+ template <typename ImplTy, typename DerivedKey>
+ static typename std::enable_if<
+ is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
+ ::llvm::hash_code>::type
+ getHash(unsigned kind, const DerivedKey &derivedKey) {
+ return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
+ }
+ /// If there is no 'ImplTy::hashKey' default to using the
+ /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
+ template <typename ImplTy, typename DerivedKey>
+ static typename std::enable_if<
+ !is_detected<has_impltype_hash_t, ImplTy, DerivedKey>::value,
+ ::llvm::hash_code>::type
+ getHash(unsigned kind, const DerivedKey &derivedKey) {
+ return llvm::hash_combine(
+ kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
+ }
+};
+} // end namespace mlir
+
+#endif
return lhs == std::make_pair(rhs->getLocations(), rhs->metadata);
}
};
-
-/// This is the implementation of the TypeUniquer class.
-struct TypeUniquerImpl {
- /// A lookup key for derived instances of TypeStorage objects.
- struct TypeLookupKey {
- /// The known derived kind for the storage.
- unsigned kind;
-
- /// The known hash value of the key.
- unsigned hashValue;
-
- /// An equality function for comparing with an existing storage instance.
- llvm::function_ref<bool(const TypeStorage *)> isEqual;
- };
-
- /// A utility wrapper object representing a hashed storage object. This class
- /// contains a storage object and an existing computed hash value.
- struct HashedStorageType {
- unsigned hashValue;
- TypeStorage *storage;
- };
-
- /// Get or create an instance of a complex derived type.
- TypeStorage *getOrCreate(
- unsigned kind, unsigned hashValue,
- llvm::function_ref<bool(const TypeStorage *)> isEqual,
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
- TypeLookupKey lookupKey{kind, hashValue, isEqual};
-
- { // Check for an existing instance in read-only mode.
- llvm::sys::SmartScopedReader<true> typeLock(typeMutex);
- auto it = storageTypes.find_as(lookupKey);
- if (it != storageTypes.end())
- return it->storage;
- }
-
- // Aquire a writer-lock so that we can safely create the new type instance.
- llvm::sys::SmartScopedWriter<true> typeLock(typeMutex);
-
- // Check for an existing instance again here, because another writer thread
- // may have already created one.
- auto existing = storageTypes.insert_as({}, lookupKey);
- if (!existing.second)
- return existing.first->storage;
-
- // Otherwise, construct and initialize the derived storage for this type
- // instance.
- TypeStorage *storage = constructorFn(allocator);
- *existing.first = HashedStorageType{hashValue, storage};
- return storage;
- }
-
- /// Get or create an instance of a simple derived type.
- TypeStorage *getOrCreate(
- unsigned kind,
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
- return safeGetOrCreate(simpleTypes, kind, typeMutex,
- [&] { return constructorFn(allocator); });
- }
-
- //===--------------------------------------------------------------------===//
- // Instance Storage
- //===--------------------------------------------------------------------===//
-
- /// Storage info for derived TypeStorage objects.
- struct StorageKeyInfo : DenseMapInfo<HashedStorageType> {
- static HashedStorageType getEmptyKey() {
- return HashedStorageType{0, DenseMapInfo<TypeStorage *>::getEmptyKey()};
- }
- static HashedStorageType getTombstoneKey() {
- return HashedStorageType{0,
- DenseMapInfo<TypeStorage *>::getTombstoneKey()};
- }
-
- static unsigned getHashValue(const HashedStorageType &key) {
- return key.hashValue;
- }
- static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; }
-
- static bool isEqual(const HashedStorageType &lhs,
- const HashedStorageType &rhs) {
- return lhs.storage == rhs.storage;
- }
- static bool isEqual(const TypeLookupKey &lhs,
- const HashedStorageType &rhs) {
- if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
- return false;
- // If the lookup kind matches the kind of the storage, then invoke the
- // equality function on the lookup key.
- return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
- }
- };
-
- // Unique types with specific hashing or storage constraints.
- using StorageTypeSet = llvm::DenseSet<HashedStorageType, StorageKeyInfo>;
- StorageTypeSet storageTypes;
-
- // Unique types with just the kind.
- DenseMap<unsigned, TypeStorage *> simpleTypes;
-
- // Allocator to use when constructing derived type instances.
- TypeStorageAllocator allocator;
-
- // A mutex to keep type uniquing thread-safe.
- llvm::sys::SmartRWMutex<true> typeMutex;
-};
} // end anonymous namespace.
namespace mlir {
//===--------------------------------------------------------------------===//
// Type uniquing
//===--------------------------------------------------------------------===//
- TypeUniquerImpl typeUniquer;
+ StorageUniquer typeUniquer;
//===--------------------------------------------------------------------===//
// Attribute uniquing
// Type uniquing
//===----------------------------------------------------------------------===//
-/// Implementation for getting/creating an instance of a derived type with
-/// complex storage.
-TypeStorage *TypeUniquer::getImpl(
- MLIRContext *ctx, unsigned kind, unsigned hashValue,
- llvm::function_ref<bool(const TypeStorage *)> isEqual,
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
- return ctx->getImpl().typeUniquer.getOrCreate(kind, hashValue, isEqual,
- constructorFn);
-}
-
-/// Implementation for getting/creating an instance of a derived type with
-/// default storage.
-TypeStorage *TypeUniquer::getImpl(
- MLIRContext *ctx, unsigned kind,
- std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
- return ctx->getImpl().typeUniquer.getOrCreate(kind, constructorFn);
-}
+/// Returns the storage unqiuer used for constructing type storage instances.
+/// This should not be used directly.
+StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
/// Get the dialect that registered the type with the provided typeid.
const Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx,
add_llvm_library(MLIRSupport
FileUtilities.cpp
+ StorageUniquer.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
--- /dev/null
+//===- StorageUniquer.cpp - Common Storage Class Uniquer --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Support/StorageUniquer.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/RWMutex.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace mlir {
+namespace detail {
+/// This is the implementation of the StorageUniquer class.
+struct StorageUniquerImpl {
+ using BaseStorage = StorageUniquer::BaseStorage;
+ using StorageAllocator = StorageUniquer::StorageAllocator;
+
+ /// A lookup key for derived instances of storage objects.
+ struct LookupKey {
+ /// The known derived kind for the storage.
+ unsigned kind;
+
+ /// The known hash value of the key.
+ unsigned hashValue;
+
+ /// An equality function for comparing with an existing storage instance.
+ llvm::function_ref<bool(const BaseStorage *)> isEqual;
+ };
+
+ /// A utility wrapper object representing a hashed storage object. This class
+ /// contains a storage object and an existing computed hash value.
+ struct HashedStorage {
+ unsigned hashValue;
+ BaseStorage *storage;
+ };
+
+ /// Get or create an instance of a complex derived type.
+ BaseStorage *
+ getOrCreate(unsigned kind, unsigned hashValue,
+ llvm::function_ref<bool(const BaseStorage *)> isEqual,
+ llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+ LookupKey lookupKey{kind, hashValue, isEqual};
+
+ // Check for an existing instance in read-only mode.
+ {
+ llvm::sys::SmartScopedReader<true> typeLock(mutex);
+ auto it = storageTypes.find_as(lookupKey);
+ if (it != storageTypes.end())
+ return it->storage;
+ }
+
+ // Acquire a writer-lock so that we can safely create the new type instance.
+ llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+
+ // Check for an existing instance again here, because another writer thread
+ // may have already created one.
+ auto existing = storageTypes.insert_as({}, lookupKey);
+ if (!existing.second)
+ return existing.first->storage;
+
+ // Otherwise, construct and initialize the derived storage for this type
+ // instance.
+ BaseStorage *storage = initializeStorage(kind, ctorFn);
+ *existing.first = HashedStorage{hashValue, storage};
+ return storage;
+ }
+
+ /// Get or create an instance of a simple derived type.
+ BaseStorage *
+ getOrCreate(unsigned kind,
+ llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+ // Check for an existing instance in read-only mode.
+ {
+ llvm::sys::SmartScopedReader<true> typeLock(mutex);
+ auto it = simpleTypes.find(kind);
+ if (it != simpleTypes.end())
+ return it->second;
+ }
+
+ // Acquire a writer-lock so that we can safely create the new type instance.
+ llvm::sys::SmartScopedWriter<true> typeLock(mutex);
+
+ // Check for an existing instance again here, because another writer thread
+ // may have already created one.
+ auto &result = simpleTypes[kind];
+ if (result)
+ return result;
+
+ // Otherwise, create and return a new storage instance.
+ return result = initializeStorage(kind, ctorFn);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Instance Storage
+ //===--------------------------------------------------------------------===//
+
+ /// Utility to create and initialize a storage instance.
+ BaseStorage *initializeStorage(
+ unsigned kind,
+ llvm::function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+ BaseStorage *storage = ctorFn(allocator);
+ storage->kind = kind;
+ return storage;
+ }
+
+ /// Storage info for derived TypeStorage objects.
+ struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
+ static HashedStorage getEmptyKey() {
+ return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+ }
+ static HashedStorage getTombstoneKey() {
+ return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+ }
+
+ static unsigned getHashValue(const HashedStorage &key) {
+ return key.hashValue;
+ }
+ static unsigned getHashValue(LookupKey key) { return key.hashValue; }
+
+ static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
+ return lhs.storage == rhs.storage;
+ }
+ static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
+ if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
+ return false;
+ // If the lookup kind matches the kind of the storage, then invoke the
+ // equality function on the lookup key.
+ return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
+ }
+ };
+
+ // Unique types with specific hashing or storage constraints.
+ using StorageTypeSet = llvm::DenseSet<HashedStorage, StorageKeyInfo>;
+ StorageTypeSet storageTypes;
+
+ // Unique types with just the kind.
+ DenseMap<unsigned, BaseStorage *> simpleTypes;
+
+ // Allocator to use when constructing derived type instances.
+ StorageUniquer::StorageAllocator allocator;
+
+ // A mutex to keep type uniquing thread-safe.
+ llvm::sys::SmartRWMutex<true> mutex;
+};
+} // end namespace detail
+} // namespace mlir
+
+StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
+StorageUniquer::~StorageUniquer() {}
+
+/// Implementation for getting/creating an instance of a derived type with
+/// complex storage.
+auto StorageUniquer::getImpl(
+ unsigned kind, unsigned hashValue,
+ llvm::function_ref<bool(const BaseStorage *)> isEqual,
+ std::function<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
+ return impl->getOrCreate(kind, hashValue, isEqual, ctorFn);
+}
+
+/// Implementation for getting/creating an instance of a derived type with
+/// default storage.
+auto StorageUniquer::getImpl(
+ unsigned kind, std::function<BaseStorage *(StorageAllocator &)> ctorFn)
+ -> BaseStorage * {
+ return impl->getOrCreate(kind, ctorFn);
+}