Explicitly define supported types (#19516)
authorSebastian Messmer <messmer@fb.com>
Mon, 22 Apr 2019 23:16:30 +0000 (16:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 22 Apr 2019 23:31:28 +0000 (16:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19516

Explicitly define types that are supported in kernel inputs and outputs.
Also, this allows us to show much nicer error messages if a user writes kernels with wrong argument types.

Reviewed By: ezyang

Differential Revision: D15020306

fbshipit-source-id: 55ebec81e075e874777acd59aa29a5578fc19ef7

aten/src/ATen/core/op_registration/kernel_functor.h
aten/src/ATen/core/op_registration/op_registration_test.cpp
c10/test/util/TypeList_test.cpp
c10/util/TypeList.h

index d589362..52cf47b 100644 (file)
@@ -25,13 +25,29 @@ namespace c10 {
 class OperatorKernel : public KernelCache {};
 
 namespace detail {
+  // supported_primitive_arg_types defines which primitive types we allow in
+  // kernel functions as arguments or returns.
+  // Additionally, we support lists, dicts and optionals containing these types.
+  using supported_primitive_arg_types = guts::typelist::typelist<
+    int64_t,
+    double,
+    bool,
+    std::string,
+    at::Tensor,
+    at::Scalar
+  >;
+
   // ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
   // cast it to the type that should be passed to the kernel function.
   // Examples: If the IValue contains a plain type like an int, return that.
   //           If the IValue contains an IntList, return it as ArrayRef<int>.
   // TODO Should we move the IValue so we can avoid bumping the Tensor refcount?
+  template<class T, class Enable = void> struct ivalue_to_arg_type {
+    // This base case is hit whenever a type does not have a specialisation below.
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type.");
+  };
   template<class T>
-  struct ivalue_to_arg_type {
+  struct ivalue_to_arg_type<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
     static T call(const IValue& v) {
       return std::move(v).to<T>();
     }
@@ -39,30 +55,51 @@ namespace detail {
   template<class T>
   struct ivalue_to_arg_type<ArrayRef<T>> {
     static ArrayRef<T> call(const IValue& v) {
+      // TODO Do we want to support ArrayRef<optional<T>> ?
+      static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported argument type: c10::ArrayRef<T> and T is not one of the supported primitive types.");
       return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
     }
   };
   template<class T>
-  struct ivalue_to_arg_type<std::vector<T>> {
-    static ArrayRef<T> call(const IValue& v) {
-      // We don't support std::vector because that would prevent us from doing
-      // internal optimization to how we represent lists (e.g. SmallVector).
-      // Users should use ArrayRef instead.
-      static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead.");
-    }
-  };
-  template<class T>
   struct ivalue_to_arg_type<optional<T>> {
     static optional<T> call(const IValue& v) {
       if (v.isNone()) {
         return nullopt;
       }
-      return v.to<T>();
+      return ivalue_to_arg_type<T>::call(v);
     }
   };
-
+  // The following specialisations of ivalue_to_arg_type are technically not
+  // necessary since we would hit the base case and show an error message
+  // there if they didn't exist, but we can show a better error message
+  // in some common error scenarios.
   template<class T>
+  struct ivalue_to_arg_type<std::vector<T>> {
+    // We don't support std::vector because that would prevent us from doing
+    // internal optimization to how we represent lists (e.g. SmallVector).
+    // Users should use ArrayRef instead.
+    static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead.");
+  };
+  template<class T>
+  struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<float, T>::value>> {
+    // There is no reason to support float when we have double. Keep the API lean.
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: float. Please use double instead.");
+  };
+  template<class T>
+  struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<const char*, T>::value>> {
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: const char*. Please use std::string instead.");
+  };
+  template<class T>
+  struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral argument type. Please use int64_t instead.");
+  };
+
+  template<class T, class Enable = void>
   struct return_type_to_ivalue_ {
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type.");
+  };
+  template<class T>
+  struct return_type_to_ivalue_<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
     static IValue call(T&& v) {
       return IValue(std::move(v));
     }
@@ -73,9 +110,38 @@ namespace detail {
       if (!v.has_value()) {
         return IValue();
       }
-      return IValue(std::move(*v));
+      return return_type_to_ivalue_<T>::call(std::move(*v));
+    }
+  };
+  template<class T>
+  struct return_type_to_ivalue_<std::vector<T>> {
+    static IValue call(std::vector<T>&& v) {
+      // TODO Do we want to support vector<optional<T>> ?
+      static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported return type: vector<T> and T is not one of the supported primitive types.");
+      return IValue(std::move(v));
     }
   };
+  // The following specialisations of return_type_to_ivalue_ are technically not
+  // necessary since we would hit the base case and show an error message
+  // there if they didn't exist, but we can show a better error message
+  // in some common error scenarios.
+  template<class T>
+  struct return_type_to_ivalue_<c10::ArrayRef<T>> {
+    static_assert(guts::false_t<c10::ArrayRef<T>>::value, "You tried to register a kernel with an unsupported return type: c10::ArrayRef<T>. Please use std::vector<T> instead.");
+  };
+  template<class T>
+  struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<float, T>::value>> {
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: float. Please use double instead.");
+  };
+  template<class T>
+  struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<const char*, T>::value>> {
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: const char*. Please use std::string instead.");
+  };
+  template<class T>
+  struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
+    static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral return argument type. Please use int64_t instead.");
+  };
+
   template<class T>
   IValue return_type_to_ivalue(T&& v) {
     return return_type_to_ivalue_<guts::decay_t<T>>::call(std::move(v));
index 615555b..cef5c2a 100644 (file)
@@ -265,6 +265,8 @@ private:
 };
 
 TEST(OperatorRegistrationTest, testAvailableArgTypes) {
+  // TODO Test Scalar
+
   // primitive types
   ArgTypeTestKernel<double>::test(
     1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
@@ -481,8 +483,9 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
     },
     "(Tensor[] a) -> Tensor[]");
 
+  // TODO We support optional of list. Add test cases for it.
 
-  // TODO Do we want to support list of optional / optional of list ?
+  // TODO Do we want to support list of optional ?
 
   // TODO Add tests for dict types
 }
index dd811b5..9350c7a 100644 (file)
@@ -145,3 +145,10 @@ namespace test_find_if {
   static_assert(2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value, "");
   static_assert(3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value, "");
 }
+
+namespace test_contains {
+  static_assert(contains<typelist<double>, double>::value, "");
+  static_assert(contains<typelist<int, double>, double>::value, "");
+  static_assert(!contains<typelist<int, double>, float>::value, "");
+  static_assert(!contains<typelist<>, double>::value, "");
+}
index bb0d6bd..373e669 100644 (file)
@@ -132,6 +132,23 @@ struct count_if final {
 };
 
 
+/**
+ * Checks if a typelist contains a certain type.
+ * Examples:
+ *  contains<typelist<int, string>, string> == true_type
+ *  contains<typelist<int, string>, double> == false_type
+ */
+namespace detail {
+template<class TypeList, class Type, class Enable = void> struct contains {};
+template<class Type> struct contains<typelist<>, Type, void> : std::false_type {};
+template<class Type, class Head, class... Tail>
+struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<std::is_same<Head, Type>::value>> : std::true_type {};
+template<class Type, class Head, class... Tail>
+struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<!std::is_same<Head, Type>::value>> : contains<typelist<Tail...>, Type> {};
+}
+template<class TypeList, class Type>
+using contains = typename detail::contains<TypeList, Type>::type;
+
 
 /**
  * Returns true iff the type trait is true for all types in the type list