Simplify TypeUniquer/AttributeUniquer to not require multiple overloads when...
authorRiver Riddle <riverriddle@google.com>
Thu, 2 May 2019 18:27:58 +0000 (11:27 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:24:50 +0000 (08:24 -0700)
--

PiperOrigin-RevId: 246356767

mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/IR/AttributeDetail.h

index f174d6d..fc25ea2 100644 (file)
@@ -104,34 +104,14 @@ namespace detail {
 // MLIRContext. This class manages all creation and uniquing of types.
 class TypeUniquer {
 public:
-  /// Get an uniqued instance of a type T. This overload is used for derived
-  /// types that have complex storage or uniquing constraints.
+  /// Get an uniqued instance of a type T.
   template <typename T, typename... Args>
-  static typename std::enable_if<
-      !std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
-  get(MLIRContext *ctx, unsigned kind, Args &&... args) {
-    // Lookup an instance of this complex storage type.
-    using ImplType = typename T::ImplType;
-    return ctx->getTypeUniquer().getComplex<ImplType>(
-        [&](ImplType *storage) {
-          storage->initializeDialect(lookupDialectForType<T>(ctx));
-        },
-        kind, std::forward<Args>(args)...);
-  }
-
-  /// Get an uniqued instance of a type T. This overload is used for derived
-  /// types that use the DefaultTypeStorage and thus need no additional storage
-  /// or uniquing.
-  template <typename T, typename... Args>
-  static typename std::enable_if<
-      std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
-  get(MLIRContext *ctx, unsigned kind) {
-    // Lookup an instance of this simple storage type.
-    return ctx->getTypeUniquer().getSimple<TypeStorage>(
+  static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+    return ctx->getTypeUniquer().get<typename T::ImplType>(
         [&](TypeStorage *storage) {
           storage->initializeDialect(lookupDialectForType<T>(ctx));
         },
-        kind);
+        kind, std::forward<Args>(args)...);
   }
 
 private:
index 2a9bb4a..5b408f3 100644 (file)
@@ -122,11 +122,12 @@ public:
   /// that can be used to initialize a newly inserted storage instance. This
   /// function is used for derived types that have complex storage or uniquing
   /// constraints.
-  template <typename Storage, typename... Args>
-  Storage *getComplex(std::function<void(Storage *)> initFn, unsigned kind,
-                      Args &&... args) {
+  template <typename Storage, typename Arg, typename... Args>
+  Storage *get(std::function<void(Storage *)> initFn, unsigned kind, Arg &&arg,
+               Args &&... args) {
     // Construct a value of the derived key type.
-    auto derivedKey = getKey<Storage>(args...);
+    auto derivedKey =
+        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
 
     // Create a hash of the kind and the derived key.
     unsigned hashValue = getHash<Storage>(kind, derivedKey);
@@ -155,7 +156,7 @@ public:
   /// function is used for derived types that use no additional storage or
   /// uniquing outside of the kind.
   template <typename Storage>
-  Storage *getSimple(std::function<void(Storage *)> initFn, unsigned kind) {
+  Storage *get(std::function<void(Storage *)> initFn, unsigned kind) {
     auto ctorFn = [&](StorageAllocator &allocator) {
       auto *storage = new (allocator.allocate<Storage>()) Storage();
       if (initFn)
@@ -167,10 +168,11 @@ public:
 
   /// Erases a uniqued instance of 'Storage'. This function is used for derived
   /// types that have complex storage or uniquing constraints.
-  template <typename Storage, typename... Args>
-  void eraseComplex(unsigned kind, Args &&... args) {
+  template <typename Storage, typename Arg, typename... Args>
+  void erase(unsigned kind, Arg &&arg, Args &&... args) {
     // Construct a value of the derived key type.
-    auto derivedKey = getKey<Storage>(args...);
+    auto derivedKey =
+        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
 
     // Create a hash of the kind and the derived key.
     unsigned hashValue = getHash<Storage>(kind, derivedKey);
index 3db5ec8..bb4d8de 100644 (file)
@@ -67,46 +67,24 @@ struct AttributeStorage : public StorageUniquer::BaseStorage {
 // MLIRContext. This class manages all creation and uniquing of attributes.
 class AttributeUniquer {
 public:
-  /// Get an uniqued instance of attribute T. This overload is used for
-  /// derived attributes that have complex storage or uniquing constraints.
+  /// Get an uniqued instance of attribute T.
   template <typename T, typename... Args>
-  static typename std::enable_if<
-      !std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
-  get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
-    return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
-        getInitFn(ctx), static_cast<unsigned>(kind),
-        std::forward<Args>(args)...);
-  }
-
-  /// Get an uniqued instance of attribute T. This overload is used for
-  /// derived attributes that use the AttributeStorage directly and thus need no
-  /// additional storage or uniquing.
-  template <typename T, typename... Args>
-  static typename std::enable_if<
-      std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
-  get(MLIRContext *ctx, Attribute::Kind kind) {
-    return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
-        getInitFn(ctx), static_cast<unsigned>(kind));
+  static T get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+    return ctx->getAttributeUniquer().get<typename T::ImplType>(
+        [ctx](AttributeStorage *storage) {
+          // If the attribute did not provide a type, then default to NoneType.
+          if (!storage->getType())
+            storage->setType(NoneType::get(ctx));
+        },
+        static_cast<unsigned>(kind), std::forward<Args>(args)...);
   }
 
-  /// Erase a uniqued instance of attribute T. This overload is used for
-  /// derived attributes that have complex storage or uniquing constraints.
+  /// Erase a uniqued instance of attribute T.
   template <typename T, typename... Args>
-  static typename std::enable_if<
-      !std::is_same<typename T::ImplType, AttributeStorage>::value>::type
-  erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
-    return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
+  static void erase(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
+    return ctx->getAttributeUniquer().erase<typename T::ImplType>(
         static_cast<unsigned>(kind), std::forward<Args>(args)...);
   }
-
-  /// Generate a functor to initialize a new attribute storage instance.
-  static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx) {
-    return [ctx](AttributeStorage *storage) {
-      // If the attribute did not provide a type, then default to NoneType.
-      if (!storage->getType())
-        storage->setType(NoneType::get(ctx));
-    };
-  }
 };
 
 using AttributeStorageAllocator = StorageUniquer::StorageAllocator;