[mlir][SubElementInterfaces] Prefer calling the derived get if possible
authorRiver Riddle <riddleriver@gmail.com>
Sat, 5 Nov 2022 23:35:25 +0000 (16:35 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Sat, 5 Nov 2022 23:35:25 +0000 (16:35 -0700)
This allows for better supporting attributes/types that override the
default builders.

mlir/include/mlir/IR/SubElementInterfaces.h

index ed387eb..07d246a 100644 (file)
@@ -220,6 +220,8 @@ template <typename T>
 struct is_tuple : public std::false_type {};
 template <typename... Ts>
 struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
+template <typename T, typename... Ts>
+using has_get_method = decltype(T::get(std::declval<Ts>()...));
 
 /// This function provides the underlying implementation for the
 /// SubElementInterface walk method, using the key type of the derived
@@ -239,6 +241,23 @@ void walkImmediateSubElementsImpl(T derived,
   }
 }
 
+/// This function invokes the proper `get` method for  a type `T` with the given
+/// values.
+template <typename T, typename... Ts>
+T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
+  // Prefer a direct `get` method if one exists.
+  if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
+    (void)ctx;
+    return T::get(std::forward<Ts>(params)...);
+  } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
+                                         Ts...>::value) {
+    return T::get(ctx, std::forward<Ts>(params)...);
+  } else {
+    // Otherwise, pass to the base get.
+    return T::Base::get(ctx, std::forward<Ts>(params)...);
+  }
+}
+
 /// This function provides the underlying implementation for the
 /// SubElementInterface replace method, using the key type of the derived
 /// attribute/type to interact with the individual parameters.
@@ -260,12 +279,13 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
     if constexpr (is_tuple<decltype(key)>::value) {
       return std::apply(
           [&](auto &&...params) {
-            return T::Base::get(derived.getContext(),
-                                std::forward<decltype(params)>(params)...);
+            return constructSubElementReplacement<T>(
+                derived.getContext(),
+                std::forward<decltype(params)>(params)...);
           },
           newKey);
     } else {
-      return T::Base::get(derived.getContext(), newKey);
+      return constructSubElementReplacement<T>(derived.getContext(), newKey);
     }
   }
 }