From: River Riddle Date: Sat, 5 Nov 2022 23:35:25 +0000 (-0700) Subject: [mlir][SubElementInterfaces] Prefer calling the derived get if possible X-Git-Tag: upstream/17.0.6~28421 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a782922708af4e80bc9eaba977704420b6c765d9;p=platform%2Fupstream%2Fllvm.git [mlir][SubElementInterfaces] Prefer calling the derived get if possible This allows for better supporting attributes/types that override the default builders. --- diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h index ed387eb..07d246a 100644 --- a/mlir/include/mlir/IR/SubElementInterfaces.h +++ b/mlir/include/mlir/IR/SubElementInterfaces.h @@ -220,6 +220,8 @@ template struct is_tuple : public std::false_type {}; template struct is_tuple> : public std::true_type {}; +template +using has_get_method = decltype(T::get(std::declval()...)); /// This function provides the underlying implementation for the /// SubElementInterface walk method, using the key type of the derived @@ -239,6 +241,23 @@ void walkImmediateSubElementsImpl(T derived, } } +/// This function invokes the proper `get` method for a type `T` with the given +/// values. +template +T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) { + // Prefer a direct `get` method if one exists. + if constexpr (llvm::is_detected::value) { + (void)ctx; + return T::get(std::forward(params)...); + } else if constexpr (llvm::is_detected::value) { + return T::get(ctx, std::forward(params)...); + } else { + // Otherwise, pass to the base get. + return T::Base::get(ctx, std::forward(params)...); + } +} + /// This function provides the underlying implementation for the /// SubElementInterface replace method, using the key type of the derived /// attribute/type to interact with the individual parameters. @@ -260,12 +279,13 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef &replAttrs, if constexpr (is_tuple::value) { return std::apply( [&](auto &&...params) { - return T::Base::get(derived.getContext(), - std::forward(params)...); + return constructSubElementReplacement( + derived.getContext(), + std::forward(params)...); }, newKey); } else { - return T::Base::get(derived.getContext(), newKey); + return constructSubElementReplacement(derived.getContext(), newKey); } } }