[llvm][STLExtras] Add various type_trait utilities currently present in MLIR
authorRiver Riddle <riddleriver@gmail.com>
Tue, 14 Apr 2020 21:52:52 +0000 (14:52 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 14 Apr 2020 22:14:40 +0000 (15:14 -0700)
This revision moves several type_trait utilities from MLIR into LLVM. Namely, this revision adds:
is_detected - This matches the experimental std::is_detected
is_invocable - This matches the c++17 std::is_invocable
function_traits - A utility traits class for getting the argument and result types of a callable type

Differential Revision: https://reviews.llvm.org/D78059

llvm/include/llvm/ADT/STLExtras.h
llvm/unittests/ADT/CMakeLists.txt
llvm/unittests/ADT/TypeTraitsTest.cpp [new file with mode: 0644]
mlir/include/mlir/ADT/TypeSwitch.h
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Support/STLExtras.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/IR/SymbolTable.cpp

index 215bc5a..53457f0 100644 (file)
@@ -75,6 +75,79 @@ template <typename T> struct make_const_ref {
       typename std::add_const<T>::type>::type;
 };
 
+/// Utilities for detecting if a given trait holds for some set of arguments
+/// 'Args'. For example, the given trait could be used to detect if a given type
+/// has a copy assignment operator:
+///   template<class T>
+///   using has_copy_assign_t = decltype(std::declval<T&>()
+///                                                 = std::declval<const T&>());
+///   bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
+namespace detail {
+template <typename...> using void_t = void;
+template <class, template <class...> class Op, class... Args> struct detector {
+  using value_t = std::false_type;
+};
+template <template <class...> class Op, class... Args>
+struct detector<void_t<Op<Args...>>, Op, Args...> {
+  using value_t = std::true_type;
+};
+} // end namespace detail
+
+template <template <class...> class Op, class... Args>
+using is_detected = typename detail::detector<void, Op, Args...>::value_t;
+
+/// Check if a Callable type can be invoked with the given set of arg types.
+namespace detail {
+template <typename Callable, typename... Args>
+using is_invocable =
+    decltype(std::declval<Callable &>()(std::declval<Args>()...));
+} // namespace detail
+
+template <typename Callable, typename... Args>
+using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
+
+/// This class provides various trait information about a callable object.
+///   * To access the number of arguments: Traits::num_args
+///   * To access the type of an argument: Traits::arg_t<i>
+///   * To access the type of the result:  Traits::result_t
+template <typename T, bool isClass = std::is_class<T>::value>
+struct function_traits : public function_traits<decltype(&T::operator())> {};
+
+/// Overload for class function types.
+template <typename ClassType, typename ReturnType, typename... Args>
+struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
+  /// The number of arguments to this function.
+  enum { num_args = sizeof...(Args) };
+
+  /// The result type of this function.
+  using result_t = ReturnType;
+
+  /// The type of an argument to this function.
+  template <size_t i>
+  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
+};
+/// Overload for class function types.
+template <typename ClassType, typename ReturnType, typename... Args>
+struct function_traits<ReturnType (ClassType::*)(Args...), false>
+    : function_traits<ReturnType (ClassType::*)(Args...) const> {};
+/// Overload for non-class function types.
+template <typename ReturnType, typename... Args>
+struct function_traits<ReturnType (*)(Args...), false> {
+  /// The number of arguments to this function.
+  enum { num_args = sizeof...(Args) };
+
+  /// The result type of this function.
+  using result_t = ReturnType;
+
+  /// The type of an argument to this function.
+  template <size_t i>
+  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
+};
+/// Overload for non-class function type references.
+template <typename ReturnType, typename... Args>
+struct function_traits<ReturnType (&)(Args...), false>
+    : public function_traits<ReturnType (*)(Args...)> {};
+
 //===----------------------------------------------------------------------===//
 //     Extra additions to <functional>
 //===----------------------------------------------------------------------===//
index 771d160..b6a1484 100644 (file)
@@ -73,6 +73,7 @@ add_llvm_unittest(ADTTests
   TinyPtrVectorTest.cpp
   TripleTest.cpp
   TwineTest.cpp
+  TypeTraitsTest.cpp
   WaymarkingTest.cpp
   )
 
diff --git a/llvm/unittests/ADT/TypeTraitsTest.cpp b/llvm/unittests/ADT/TypeTraitsTest.cpp
new file mode 100644 (file)
index 0000000..d38505c
--- /dev/null
@@ -0,0 +1,80 @@
+//===- TypeTraitsTest.cpp - type_traits unit tests ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+// function_traits
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check a callable type of the form `bool(const int &)`.
+template <typename CallableT> struct CheckFunctionTraits {
+  static_assert(
+      std::is_same<typename function_traits<CallableT>::result_t, bool>::value,
+      "expected result_t to be `bool`");
+  static_assert(
+      std::is_same<typename function_traits<CallableT>::template arg_t<0>,
+                   const int &>::value,
+      "expected arg_t<0> to be `const int &`");
+  static_assert(function_traits<CallableT>::num_args == 1,
+                "expected num_args to be 1");
+};
+
+/// Test function pointers.
+using FuncType = bool (*)(const int &);
+struct CheckFunctionPointer : CheckFunctionTraits<FuncType> {};
+
+static bool func(const int &v);
+struct CheckFunctionPointer2 : CheckFunctionTraits<decltype(&func)> {};
+
+/// Test method pointers.
+struct Foo {
+  bool func(const int &v);
+};
+struct CheckMethodPointer : CheckFunctionTraits<decltype(&Foo::func)> {};
+
+/// Test lambda references.
+auto lambdaFunc = [](const int &v) -> bool { return true; };
+struct CheckLambda : CheckFunctionTraits<decltype(lambdaFunc)> {};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// is_detected
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct HasFooMethod {
+  void foo() {}
+};
+struct NoFooMethod {};
+
+template <class T> using has_foo_method_t = decltype(std::declval<T &>().foo());
+
+static_assert(is_detected<has_foo_method_t, HasFooMethod>::value,
+              "expected foo method to be detected");
+static_assert(!is_detected<has_foo_method_t, NoFooMethod>::value,
+              "expected no foo method to be detected");
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// is_invocable
+//===----------------------------------------------------------------------===//
+
+static void invocable_fn(int) {}
+
+static_assert(is_invocable<decltype(invocable_fn), int>::value,
+              "expected function to be invocable");
+static_assert(!is_invocable<decltype(invocable_fn), void *>::value,
+              "expected function not to be invocable");
+static_assert(!is_invocable<decltype(invocable_fn), int, int>::value,
+              "expected function not to be invocable");
index 38fc3e9..d4798c5 100644 (file)
@@ -46,7 +46,7 @@ public:
   /// Note: This inference rules for this overload are very simple: strip
   ///       pointers and references.
   template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
-    using Traits = FunctionTraits<std::decay_t<CallableT>>;
+    using Traits = llvm::function_traits<std::decay_t<CallableT>>;
     using CaseT = std::remove_cv_t<std::remove_pointer_t<
         std::remove_reference_t<typename Traits::template arg_t<0>>>>;
 
@@ -64,20 +64,22 @@ protected:
   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
   /// selected if `value` already has a suitable dyn_cast method.
   template <typename CastT, typename ValueT>
-  static auto castValue(
-      ValueT value,
-      typename std::enable_if_t<
-          is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
+  static auto
+  castValue(ValueT value,
+            typename std::enable_if_t<
+                llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
+                nullptr) {
     return value.template dyn_cast<CastT>();
   }
 
   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
   /// selected if llvm::dyn_cast should be used.
   template <typename CastT, typename ValueT>
-  static auto castValue(
-      ValueT value,
-      typename std::enable_if_t<
-          !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
+  static auto
+  castValue(ValueT value,
+            typename std::enable_if_t<
+                !llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
+                nullptr) {
     return dyn_cast<CastT>(value);
   }
 
index 12da468..7866206 100644 (file)
@@ -140,18 +140,20 @@ using has_operation_or_value_matcher_t =
 
 /// Statically switch to a Value matcher.
 template <typename MatcherClass>
-typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
-                                      MatcherClass, Value>::value,
-                          bool>
+typename std::enable_if_t<
+    llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
+                      Value>::value,
+    bool>
 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
   return matcher.match(op->getOperand(idx));
 }
 
 /// Statically switch to an Operation matcher.
 template <typename MatcherClass>
-typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
-                                      MatcherClass, Operation *>::value,
-                          bool>
+typename std::enable_if_t<
+    llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
+                      Operation *>::value,
+    bool>
 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
   if (auto defOp = op->getOperand(idx).getDefiningOp())
     return matcher.match(defOp);
index e8944c7..a503c1e 100644 (file)
@@ -1298,16 +1298,16 @@ private:
     /// If 'T' is the same interface as 'interfaceID' return the concept
     /// instance.
     template <typename T>
-    static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
-                                   void *>::type
+    static typename std::enable_if<
+        llvm::is_detected<has_get_interface_id, T>::value, void *>::type
     lookup(TypeID interfaceID) {
       return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
     }
 
     /// 'T' is known to not be an interface, return nullptr.
     template <typename T>
-    static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
-                                   void *>::type
+    static typename std::enable_if<
+        !llvm::is_detected<has_get_interface_id, T>::value, void *>::type
     lookup(TypeID) {
       return nullptr;
     }
index 2b948c2..4e9b3f2 100644 (file)
@@ -71,13 +71,13 @@ using has_is_invalidated = decltype(std::declval<T &>().isInvalidated(
 
 /// Implementation of 'isInvalidated' if the analysis provides a definition.
 template <typename AnalysisT>
-std::enable_if_t<is_detected<has_is_invalidated, AnalysisT>::value, bool>
+std::enable_if_t<llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
 isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
   return analysis.isInvalidated(pa);
 }
 /// Default implementation of 'isInvalidated'.
 template <typename AnalysisT>
-std::enable_if_t<!is_detected<has_is_invalidated, AnalysisT>::value, bool>
+std::enable_if_t<!llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
 isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
   return !pa.isPreserved<AnalysisT>();
 }
index ada69a9..2279a35 100644 (file)
@@ -88,37 +88,6 @@ inline void interleaveComma(const Container &c, raw_ostream &os) {
   interleaveComma(c, os, [&](const T &a) { os << a; });
 }
 
-/// Utilities for detecting if a given trait holds for some set of arguments
-/// 'Args'. For example, the given trait could be used to detect if a given type
-/// has a copy assignment operator:
-///   template<class T>
-///   using has_copy_assign_t = decltype(std::declval<T&>()
-///                                                 = std::declval<const T&>());
-///   bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
-namespace detail {
-template <typename...> using void_t = void;
-template <class, template <class...> class Op, class... Args> struct detector {
-  using value_t = std::false_type;
-};
-template <template <class...> class Op, class... Args>
-struct detector<void_t<Op<Args...>>, Op, Args...> {
-  using value_t = std::true_type;
-};
-} // end namespace detail
-
-template <template <class...> class Op, class... Args>
-using is_detected = typename detail::detector<void, Op, Args...>::value_t;
-
-/// Check if a Callable type can be invoked with the given set of arg types.
-namespace detail {
-template <typename Callable, typename... Args>
-using is_invocable =
-    decltype(std::declval<Callable &>()(std::declval<Args>()...));
-} // namespace detail
-
-template <typename Callable, typename... Args>
-using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
-
 //===----------------------------------------------------------------------===//
 //     Extra additions to <iterator>
 //===----------------------------------------------------------------------===//
@@ -356,47 +325,6 @@ template <typename ContainerTy> bool has_single_element(ContainerTy &&c) {
   return it != e && std::next(it) == e;
 }
 
-//===----------------------------------------------------------------------===//
-//     Extra additions to <type_traits>
-//===----------------------------------------------------------------------===//
-
-/// This class provides various trait information about a callable object.
-///   * To access the number of arguments: Traits::num_args
-///   * To access the type of an argument: Traits::arg_t<i>
-///   * To access the type of the result: Traits::result_t<i>
-template <typename T, bool isClass = std::is_class<T>::value>
-struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
-
-/// Overload for class function types.
-template <typename ClassType, typename ReturnType, typename... Args>
-struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
-  /// The number of arguments to this function.
-  enum { num_args = sizeof...(Args) };
-
-  /// The result type of this function.
-  using result_t = ReturnType;
-
-  /// The type of an argument to this function.
-  template <size_t i>
-  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
-};
-/// Overload for non-class function types.
-template <typename ReturnType, typename... Args>
-struct FunctionTraits<ReturnType (*)(Args...), false> {
-  /// The number of arguments to this function.
-  enum { num_args = sizeof...(Args) };
-
-  /// The result type of this function.
-  using result_t = ReturnType;
-
-  /// The type of an argument to this function.
-  template <size_t i>
-  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
-};
-/// Overload for non-class function type references.
-template <typename ReturnType, typename... Args>
-struct FunctionTraits<ReturnType (&)(Args...), false>
-    : public FunctionTraits<ReturnType (*)(Args...)> {};
 } // end namespace mlir
 
 #endif // MLIR_SUPPORT_STLEXTRAS_H
index 7b8c490..637c2ce 100644 (file)
@@ -215,7 +215,7 @@ private:
   /// 'ImplTy::getKey' function for the provided arguments.
   template <typename ImplTy, typename... Args>
   static typename std::enable_if<
-      is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
+      llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
       typename ImplTy::KeyTy>::type
   getKey(Args &&... args) {
     return ImplTy::getKey(args...);
@@ -224,7 +224,7 @@ private:
   /// the 'ImplTy::KeyTy' with the provided arguments.
   template <typename ImplTy, typename... Args>
   static typename std::enable_if<
-      !is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
+      !llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
       typename ImplTy::KeyTy>::type
   getKey(Args &&... args) {
     return typename ImplTy::KeyTy(args...);
@@ -238,7 +238,7 @@ private:
   /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
   template <typename ImplTy, typename DerivedKey>
   static typename std::enable_if<
-      is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
+      llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
       ::llvm::hash_code>::type
   getHash(unsigned kind, const DerivedKey &derivedKey) {
     return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
@@ -246,9 +246,9 @@ private:
   /// If there is no 'ImplTy::hashKey' default to using the
   /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
   template <typename ImplTy, typename DerivedKey>
-  static typename std::enable_if<
-      !is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
-      ::llvm::hash_code>::type
+  static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
+                                                    ImplTy, DerivedKey>::value,
+                                 ::llvm::hash_code>::type
   getHash(unsigned kind, const DerivedKey &derivedKey) {
     return llvm::hash_combine(
         kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
index 1c2fc74..2298b3b 100644 (file)
@@ -108,7 +108,7 @@ public:
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT,
-            typename T = typename FunctionTraits<FnT>::template arg_t<0>>
+            typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
   void addConversion(FnT &&callback) {
     registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
   }
@@ -172,7 +172,7 @@ private:
   /// different callback forms, that all compose into a single version.
   /// With callback of form: `Optional<Type>(T)`
   template <typename T, typename FnT>
-  std::enable_if_t<is_invocable<FnT, T>::value, ConversionCallbackFn>
+  std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
   wrapCallback(FnT &&callback) {
     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
                                T type, SmallVectorImpl<Type> &results) {
@@ -187,7 +187,7 @@ private:
   }
   /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
   template <typename T, typename FnT>
-  std::enable_if_t<!is_invocable<FnT, T>::value, ConversionCallbackFn>
+  std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
   wrapCallback(FnT &&callback) {
     return [callback = std::forward<FnT>(callback)](
                Type type,
@@ -482,7 +482,8 @@ public:
     addDynamicallyLegalOp<OpT2, OpTs...>(callback);
   }
   template <typename OpT, class Callable>
-  typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
+  typename std::enable_if<
+      !llvm::is_invocable<Callable, Operation *>::value>::type
   addDynamicallyLegalOp(Callable &&callback) {
     addDynamicallyLegalOp<OpT>(
         [=](Operation *op) { return callback(cast<OpT>(op)); });
@@ -514,7 +515,8 @@ public:
     markOpRecursivelyLegal<OpT2, OpTs...>(callback);
   }
   template <typename OpT, class Callable>
-  typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
+  typename std::enable_if<
+      !llvm::is_invocable<Callable, Operation *>::value>::type
   markOpRecursivelyLegal(Callable &&callback) {
     markOpRecursivelyLegal<OpT>(
         [=](Operation *op) { return callback(cast<OpT>(op)); });
index ba1c8c3..f6fe0ce 100644 (file)
@@ -477,8 +477,8 @@ struct SymbolScope {
   /// 'walkSymbolUses'.
   template <typename CallbackT,
             typename std::enable_if_t<!std::is_same<
-                typename FunctionTraits<CallbackT>::result_t, void>::value> * =
-                nullptr>
+                typename llvm::function_traits<CallbackT>::result_t,
+                void>::value> * = nullptr>
   Optional<WalkResult> walk(CallbackT cback) {
     if (Region *region = limit.dyn_cast<Region *>())
       return walkSymbolUses(*region, cback);
@@ -488,8 +488,8 @@ struct SymbolScope {
   /// void(SymbolTable::SymbolUse use)
   template <typename CallbackT,
             typename std::enable_if_t<std::is_same<
-                typename FunctionTraits<CallbackT>::result_t, void>::value> * =
-                nullptr>
+                typename llvm::function_traits<CallbackT>::result_t,
+                void>::value> * = nullptr>
   Optional<WalkResult> walk(CallbackT cback) {
     return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
       return cback(use), WalkResult::advance();