From: Sebastian Messmer Date: Mon, 22 Apr 2019 23:16:30 +0000 (-0700) Subject: Explicitly define supported types (#19516) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~68 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=969af4315a30e96205e125a16e67bd6e3c03e218;p=platform%2Fupstream%2Fpytorch.git Explicitly define supported types (#19516) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19516 Explicitly define types that are supported in kernel inputs and outputs. Also, this allows us to show much nicer error messages if a user writes kernels with wrong argument types. Reviewed By: ezyang Differential Revision: D15020306 fbshipit-source-id: 55ebec81e075e874777acd59aa29a5578fc19ef7 --- diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h index d589362..52cf47b 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor.h +++ b/aten/src/ATen/core/op_registration/kernel_functor.h @@ -25,13 +25,29 @@ namespace c10 { class OperatorKernel : public KernelCache {}; namespace detail { + // supported_primitive_arg_types defines which primitive types we allow in + // kernel functions as arguments or returns. + // Additionally, we support lists, dicts and optionals containing these types. + using supported_primitive_arg_types = guts::typelist::typelist< + int64_t, + double, + bool, + std::string, + at::Tensor, + at::Scalar + >; + // ivalue_to_arg_type: Take an IValue that is an argument to a kernel and // cast it to the type that should be passed to the kernel function. // Examples: If the IValue contains a plain type like an int, return that. // If the IValue contains an IntList, return it as ArrayRef. // TODO Should we move the IValue so we can avoid bumping the Tensor refcount? + template struct ivalue_to_arg_type { + // This base case is hit whenever a type does not have a specialisation below. + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported argument type."); + }; template - struct ivalue_to_arg_type { + struct ivalue_to_arg_type::value>> { static T call(const IValue& v) { return std::move(v).to(); } @@ -39,30 +55,51 @@ namespace detail { template struct ivalue_to_arg_type> { static ArrayRef call(const IValue& v) { + // TODO Do we want to support ArrayRef> ? + static_assert(guts::typelist::contains::value, "You tried to register a kernel with an unsupported argument type: c10::ArrayRef and T is not one of the supported primitive types."); return v.to>>()->elements(); } }; template - struct ivalue_to_arg_type> { - static ArrayRef call(const IValue& v) { - // We don't support std::vector because that would prevent us from doing - // internal optimization to how we represent lists (e.g. SmallVector). - // Users should use ArrayRef instead. - static_assert(guts::false_t>::value, "You tried to register a kernel with an unsupported argument type: std::vector. Please use c10::ArrayRef instead."); - } - }; - template struct ivalue_to_arg_type> { static optional call(const IValue& v) { if (v.isNone()) { return nullopt; } - return v.to(); + return ivalue_to_arg_type::call(v); } }; - + // The following specialisations of ivalue_to_arg_type are technically not + // necessary since we would hit the base case and show an error message + // there if they didn't exist, but we can show a better error message + // in some common error scenarios. template + struct ivalue_to_arg_type> { + // We don't support std::vector because that would prevent us from doing + // internal optimization to how we represent lists (e.g. SmallVector). + // Users should use ArrayRef instead. + static_assert(guts::false_t>::value, "You tried to register a kernel with an unsupported argument type: std::vector. Please use c10::ArrayRef instead."); + }; + template + struct ivalue_to_arg_type::value>> { + // There is no reason to support float when we have double. Keep the API lean. + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported argument type: float. Please use double instead."); + }; + template + struct ivalue_to_arg_type::value>> { + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported argument type: const char*. Please use std::string instead."); + }; + template + struct ivalue_to_arg_type::value && !guts::typelist::contains::value>> { + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported integral argument type. Please use int64_t instead."); + }; + + template struct return_type_to_ivalue_ { + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported return type."); + }; + template + struct return_type_to_ivalue_::value>> { static IValue call(T&& v) { return IValue(std::move(v)); } @@ -73,9 +110,38 @@ namespace detail { if (!v.has_value()) { return IValue(); } - return IValue(std::move(*v)); + return return_type_to_ivalue_::call(std::move(*v)); + } + }; + template + struct return_type_to_ivalue_> { + static IValue call(std::vector&& v) { + // TODO Do we want to support vector> ? + static_assert(guts::typelist::contains::value, "You tried to register a kernel with an unsupported return type: vector and T is not one of the supported primitive types."); + return IValue(std::move(v)); } }; + // The following specialisations of return_type_to_ivalue_ are technically not + // necessary since we would hit the base case and show an error message + // there if they didn't exist, but we can show a better error message + // in some common error scenarios. + template + struct return_type_to_ivalue_> { + static_assert(guts::false_t>::value, "You tried to register a kernel with an unsupported return type: c10::ArrayRef. Please use std::vector instead."); + }; + template + struct return_type_to_ivalue_::value>> { + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported return type: float. Please use double instead."); + }; + template + struct return_type_to_ivalue_::value>> { + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported return type: const char*. Please use std::string instead."); + }; + template + struct return_type_to_ivalue_::value && !guts::typelist::contains::value>> { + static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported integral return argument type. Please use int64_t instead."); + }; + template IValue return_type_to_ivalue(T&& v) { return return_type_to_ivalue_>::call(std::move(v)); diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 615555b..cef5c2a 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -265,6 +265,8 @@ private: }; TEST(OperatorRegistrationTest, testAvailableArgTypes) { + // TODO Test Scalar + // primitive types ArgTypeTestKernel::test( 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, @@ -481,8 +483,9 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { }, "(Tensor[] a) -> Tensor[]"); + // TODO We support optional of list. Add test cases for it. - // TODO Do we want to support list of optional / optional of list ? + // TODO Do we want to support list of optional ? // TODO Add tests for dict types } diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp index dd811b5..9350c7a 100644 --- a/c10/test/util/TypeList_test.cpp +++ b/c10/test/util/TypeList_test.cpp @@ -145,3 +145,10 @@ namespace test_find_if { static_assert(2 == find_if, std::is_reference>::value, ""); static_assert(3 == find_if, std::is_reference>::value, ""); } + +namespace test_contains { + static_assert(contains, double>::value, ""); + static_assert(contains, double>::value, ""); + static_assert(!contains, float>::value, ""); + static_assert(!contains, double>::value, ""); +} diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h index bb0d6bd..373e669 100644 --- a/c10/util/TypeList.h +++ b/c10/util/TypeList.h @@ -132,6 +132,23 @@ struct count_if final { }; +/** + * Checks if a typelist contains a certain type. + * Examples: + * contains, string> == true_type + * contains, double> == false_type + */ +namespace detail { +template struct contains {}; +template struct contains, Type, void> : std::false_type {}; +template +struct contains, Type, guts::enable_if_t::value>> : std::true_type {}; +template +struct contains, Type, guts::enable_if_t::value>> : contains, Type> {}; +} +template +using contains = typename detail::contains::type; + /** * Returns true iff the type trait is true for all types in the type list