From 14c28fabd23c92d120544e8d5434b38efda2e7dc Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Sat, 30 Mar 2019 00:03:44 -0700 Subject: [PATCH] Check kernel against function schema in c10 op registration (#18256) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18256 This diff infers the function schema from the kernel function/functor and checks that it matches the specified function schema. This diff does not allow (yet) to omit specifying the function schema in the registration API. That will come in a future diff. Reviewed By: dzhulgakov Differential Revision: D14552738 fbshipit-source-id: 00202b489ede19f26ae686c97416b38c72c11532 --- aten/src/ATen/core/op_registration/base.h | 1 + .../src/ATen/core/op_registration/infer_schema.cpp | 47 +++ aten/src/ATen/core/op_registration/infer_schema.h | 89 +++--- .../core/op_registration/kernel_function_test.cpp | 333 ++++++++++++++++++++- .../src/ATen/core/op_registration/kernel_functor.h | 28 +- .../core/op_registration/kernel_functor_test.cpp | 332 +++++++++++++++++++- .../ATen/core/op_registration/kernel_stackbased.h | 25 +- .../ATen/core/op_registration/op_registration.h | 6 + c10/util/C++17.h | 3 + c10/util/TypeTraits.h | 1 + .../operators/experimental/c10/cpu/concat_cpu.cc | 2 +- 11 files changed, 784 insertions(+), 83 deletions(-) create mode 100644 aten/src/ATen/core/op_registration/infer_schema.cpp diff --git a/aten/src/ATen/core/op_registration/base.h b/aten/src/ATen/core/op_registration/base.h index 3c4ea05..60250a5 100644 --- a/aten/src/ATen/core/op_registration/base.h +++ b/aten/src/ATen/core/op_registration/base.h @@ -56,6 +56,7 @@ namespace detail { TensorTypeId dispatch_key; KernelFunction* kernel_func = nullptr; KernelCacheCreatorFunction cache_creator_func = nullptr; + std::unique_ptr inferred_function_schema = nullptr; }; // is_registration_config_parameter is a concept that returns true_type iff its argument is diff --git a/aten/src/ATen/core/op_registration/infer_schema.cpp b/aten/src/ATen/core/op_registration/infer_schema.cpp new file mode 100644 index 0000000..bcfeb87 --- /dev/null +++ b/aten/src/ATen/core/op_registration/infer_schema.cpp @@ -0,0 +1,47 @@ +#include "infer_schema.h" +#include + +namespace c10 { + +namespace { + std::string serialize_schema(const FunctionSchema& schema) { + std::ostringstream str; + str << schema; + return str.str(); + } +} + +C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified) { + if (inferred.arguments().size() != specified.arguments().size()) { + AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", + "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", + "The number of arguments is different. Specified ", specified.arguments().size(), + " but inferred ", inferred.arguments().size()); + } + if (inferred.returns().size() != specified.returns().size()) { + AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", + "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", + "The number of returns is different.Specified ", specified.returns().size(), + " but inferred ", inferred.returns().size()); + } + + for (size_t i = 0; i < inferred.arguments().size(); ++i) { + if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) { + AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", + "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", + "Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(), + " but inferred ", inferred.arguments()[i].type()->str()); + } + } + + for (size_t i = 0; i < inferred.returns().size(); ++i) { + if (*inferred.returns()[i].type() != *specified.returns()[i].type()) { + AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", + "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", + "Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(), + " but inferred ", inferred.returns()[i].type()->str()); + } + } +} + +} diff --git a/aten/src/ATen/core/op_registration/infer_schema.h b/aten/src/ATen/core/op_registration/infer_schema.h index e040503..36f681e 100644 --- a/aten/src/ATen/core/op_registration/infer_schema.h +++ b/aten/src/ATen/core/op_registration/infer_schema.h @@ -18,63 +18,76 @@ void checkStaticTypes() { // Give nice error messages for some of the common error cases. // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT static_assert( - !std::is_integral::value || std::is_same::value, - "INVALID TYPE: Only int64_t is supported as an integral argument type"); + !std::is_integral::value || std::is_same::value || std::is_same::value, + "INVALID TYPE: Only int64_t and bool are supported as an integral argument type"); static_assert( !std::is_same::value, "INVALID TYPE: float is not supported as an argument type, use double instead"); } -template -void checkStaticTypes() { - checkStaticTypes(); - checkStaticTypes(); -} - template ::std::vector createArgumentVectorFromTypes(guts::index_sequence) { - checkStaticTypes...>(); + // Check types for common errors + (void)std::initializer_list{( + checkStaticTypes() + , 0)...}; + // Arguments are named "_" - return {Argument("_" + std::to_string(Is), getTypePtr>())...}; + return {Argument("_" + c10::guts::to_string(Is), getTypePtr>())...}; } -template -::std::vector createReturns(guts::index_sequence) { - return createArgumentVectorFromTypes(); -} +/// Creates a vector of `Argument` from a list of C++ types that are specified +/// as template arguments. +template struct createArguments final {}; +template +struct createArguments> final { + static std::vector call() { + return createArgumentVectorFromTypes( + guts::make_index_sequence() + ); + } +}; -/// Unpack a tuple return type into a vector of return types, one per tuple -/// element. -template -::std::vector createReturns(std::tuple* tuple) { - return createReturns(guts::make_index_sequence()); -} +/// Creates a vector of `Argument` from a list of C++ types that are specified +/// as a tuple (i.e. in the way c10 kernels return values). +/// It can be a tuple if there's three output arguments with types A, B, C. +/// It can be an empty tuple<>, or void for kernels that don't return anything. +/// It can be a single type A (i.e. no tuple) for the case where a kernel just +/// returns one value. +template struct createReturns final {}; -/// Create a single-element `vector` for simple (non-tuple) return types. -template -::std::vector createReturns(ReturnType*) { - checkStaticTypes>(); - return {Argument("_1", getTypePtr>())}; -} +template +struct createReturns, void> final { + static std::vector call() { + return createArgumentVectorFromTypes( + guts::make_index_sequence() + ); + } +}; -/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices -/// into the argument list. -template -::std::vector createArgumentVectorFromTraits(guts::index_sequence indices) { - using ArgumentTypes = typename FunctionTraits::parameter_types; - return createArgumentVectorFromTypes< - c10::guts::typelist::element_t...>(indices); -} +template +struct createReturns::value && !guts::is_instantiation_of::value>> final { + static std::vector call() { + return createReturns>::call(); + } +}; + +template<> +struct createReturns final { + static std::vector call() { + return createReturns>::call(); + } +}; /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a /// function. template FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) { using ReturnType = typename FunctionTraits::return_type; + using ParameterTypes = typename FunctionTraits::parameter_types; - auto arguments = createArgumentVectorFromTraits( - guts::make_index_sequence()); - auto returns = createReturns(static_cast(nullptr)); + auto arguments = createArguments::call(); + auto returns = createReturns::call(); return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)}; } @@ -85,4 +98,6 @@ FunctionSchema inferFunctionSchema(std::string name, std::string overload_name) return detail::createFunctionSchemaFromTraits>(std::move(name), std::move(overload_name)); } +C10_API void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified); + } diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp index 02231d3..4d86d28 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_test.cpp @@ -8,6 +8,7 @@ using c10::RegisterOperators; using c10::FunctionSchema; using c10::Argument; using c10::IntType; +using c10::FloatType; using c10::ListType; using c10::kernel; using c10::dispatchKey; @@ -29,21 +30,23 @@ C10_DEFINE_TENSOR_TYPE(TensorType1); C10_DECLARE_TENSOR_TYPE(TensorType2); C10_DEFINE_TENSOR_TYPE(TensorType2); -void errorKernel(const Tensor&) { +int64_t errorKernel(const Tensor& tensor, int64_t input) { EXPECT_TRUE(false); // this kernel should never be called + return 0; } FunctionSchema errorOpSchema( "_test::error", "", - (std::vector{Argument("dummy")}), - (std::vector{})); + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); -int incrementKernel(const Tensor& tensor, int input) { +int64_t incrementKernel(const Tensor& tensor, int64_t input) { return input + 1; } -int decrementKernel(const Tensor& tensor, int input) { +int64_t decrementKernel(const Tensor& tensor, int64_t input) { return input - 1; } @@ -159,7 +162,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_wh EXPECT_EQ(0, result.size()); } -int kernelWithIntOutput(Tensor, int a, int b) { +int64_t kernelWithIntOutput(Tensor, int64_t a, int64_t b) { return a + b; } @@ -237,7 +240,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutp EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id()); } -std::vector kernelWithIntListOutput(const Tensor&, int input1, int input2, int input3) { +std::vector kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) { return {input1, input2, input3}; } @@ -393,9 +396,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByV EXPECT_EQ(TensorType2(), captured_input.type_id()); } -int captured_int_input = 0; +int64_t captured_int_input = 0; -void kernelWithIntInputWithoutOutput(Tensor, int input1) { +void kernelWithIntInputWithoutOutput(Tensor, int64_t input1) { captured_int_input = input1; } @@ -419,7 +422,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_witho EXPECT_EQ(3, captured_int_input); } -int kernelWithIntInputWithOutput(Tensor, int input1) { +int64_t kernelWithIntInputWithOutput(Tensor, int64_t input1) { return input1 + 1; } @@ -442,7 +445,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withO EXPECT_EQ(4, outputs[0].toInt()); } -int captured_input_list_size = 0; +int64_t captured_input_list_size = 0; void kernelWithIntListInputWithoutOutput(Tensor, ArrayRef input1) { captured_input_list_size = input1.size(); @@ -468,7 +471,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_w EXPECT_EQ(3, captured_input_list_size); } -int kernelWithIntListInputWithOutput(Tensor, ArrayRef input1) { +int64_t kernelWithIntListInputWithOutput(Tensor, ArrayRef input1) { return input1.size(); } @@ -514,7 +517,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu EXPECT_EQ(2, captured_input_list_size); } -int kernelWithTensorListInputWithOutput(ArrayRef input1) { +int64_t kernelWithTensorListInputWithOutput(ArrayRef input1) { return input1.size(); } @@ -536,4 +539,308 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu EXPECT_EQ(2, outputs[0].toInt()); } +template struct kernel_func final { + static Return func(Args...) { return {}; } +}; +template struct kernel_func final { + static void func(Args...) {} +}; + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{}), + (std::vector{}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), + (std::vector{}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), + Argument("ret2", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret"), Argument("ret2")}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2")}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1")}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", IntType::get())}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), + c10::Error + ); +} + } diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h index 19014a9..86285a1 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor.h +++ b/aten/src/ATen/core/op_registration/kernel_functor.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace c10 { /** @@ -129,6 +130,14 @@ namespace detail { private: std::tuple constructor_parameters_; }; + + template + class FunctionSchemaInferer final { + public: + std::unique_ptr operator()() const { + return guts::make_unique(inferFunctionSchema("", "")); + } + }; } /** @@ -168,20 +177,15 @@ namespace detail { * > c10::dispatchKey(CPUTensorId())); */ template -inline constexpr auto kernel(ConstructorParameters&&... constructorParameters) // enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel --> guts::enable_if_t< -guts::is_functor::value && std::is_base_of::value, -decltype(kernel( +inline constexpr guts::enable_if_t::value && std::is_base_of::value, +detail::KernelRegistrationConfigParameter...>, detail::FunctionSchemaInferer>> +kernel(ConstructorParameters&&... constructorParameters) { + return { &detail::wrap_kernel_functor::call, - detail::KernelFactory...>(std::forward(constructorParameters)...) -))> { - static_assert(std::is_constructible::value, "KernelFunctor cannot be constructed with the given arguments"); - - return kernel( - &detail::wrap_kernel_functor::call, - detail::KernelFactory...>(std::forward(constructorParameters)...) - ); + detail::KernelFactory...>(std::forward(constructorParameters)...), + detail::FunctionSchemaInferer() + }; } } diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp index 3e638ff..49e39bb 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp @@ -9,6 +9,7 @@ using c10::FunctionSchema; using c10::OperatorKernel; using c10::Argument; using c10::IntType; +using c10::FloatType; using c10::ListType; using c10::kernel; using c10::dispatchKey; @@ -31,25 +32,27 @@ C10_DECLARE_TENSOR_TYPE(TensorType2); C10_DEFINE_TENSOR_TYPE(TensorType2); struct ErrorKernel final : public OperatorKernel { - void operator()(const Tensor&) { + int64_t operator()(const Tensor&, int64_t) { EXPECT_TRUE(false); // this kernel should never be called + return 0; } }; FunctionSchema errorOpSchema( "_test::error", "", - (std::vector{Argument("dummy")}), - (std::vector{})); + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); struct IncrementKernel final : OperatorKernel { - int operator()(const Tensor& tensor, int input) { + int64_t operator()(const Tensor& tensor, int64_t input) { return input + 1; } }; struct DecrementKernel final : OperatorKernel { - int operator()(const Tensor& tensor, int input) { + int64_t operator()(const Tensor& tensor, int64_t input) { return input - 1; } }; @@ -171,7 +174,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whe } struct KernelWithIntOutput final : OperatorKernel { - int operator()(Tensor, int a, int b) { + int64_t operator()(Tensor, int64_t a, int64_t b) { return a + b; } }; @@ -256,7 +259,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutpu } struct KernelWithIntListOutput final : OperatorKernel { - std::vector operator()(const Tensor&, int input1, int input2, int input3) { + std::vector operator()(const Tensor&, int64_t input1, int64_t input2, int64_t input3) { return {input1, input2, input3}; } }; @@ -423,10 +426,10 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByVa EXPECT_EQ(TensorType2(), captured_input.type_id()); } -int captured_int_input = 0; +int64_t captured_int_input = 0; struct KernelWithIntInputWithoutOutput final : OperatorKernel { - void operator()(Tensor, int input1) { + void operator()(Tensor, int64_t input1) { captured_int_input = input1; } }; @@ -452,7 +455,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withou } struct KernelWithIntInputWithOutput final : OperatorKernel { - int operator()(Tensor, int input1) { + int64_t operator()(Tensor, int64_t input1) { return input1 + 1; } }; @@ -476,7 +479,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOu EXPECT_EQ(4, outputs[0].toInt()); } -int captured_input_list_size = 0; +int64_t captured_input_list_size = 0; struct KernelWithIntListInputWithoutOutput final : OperatorKernel { void operator()(Tensor, ArrayRef input1) { @@ -505,7 +508,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_wi } struct KernelWithIntListInputWithOutput final : OperatorKernel { - int operator()(Tensor, ArrayRef input1) { + int64_t operator()(Tensor, ArrayRef input1) { return input1.size(); } }; @@ -555,7 +558,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput } struct KernelWithTensorListInputWithOutput final : OperatorKernel { - int operator()(ArrayRef input1) { + int64_t operator()(ArrayRef input1) { return input1.size(); } }; @@ -689,5 +692,308 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstru EXPECT_EQ(13, outputs[0].toInt()); } +template struct KernelFunc final : OperatorKernel{ + Return operator()(Args...) { return {}; } +}; +template struct KernelFunc final : OperatorKernel { + void operator()(Args...) {} +}; + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{}), + (std::vector{}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), + (std::vector{}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), + Argument("ret2", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret"), Argument("ret2")}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2")}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1")}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), kernel>(), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", IntType::get())}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) + ), kernel, Tensor>>(), dispatchKey(TensorType1())), + c10::Error + ); +} } diff --git a/aten/src/ATen/core/op_registration/kernel_stackbased.h b/aten/src/ATen/core/op_registration/kernel_stackbased.h index 412a5dd..8ec9819 100644 --- a/aten/src/ATen/core/op_registration/kernel_stackbased.h +++ b/aten/src/ATen/core/op_registration/kernel_stackbased.h @@ -17,29 +17,40 @@ namespace c10 { namespace detail { - template + struct NoFunctionSchemaInference final { + std::unique_ptr operator()() const { + return nullptr; + } + }; + + template struct KernelRegistrationConfigParameter final { template - constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func) - : kernel_func_(kernel_func), cache_creator_func_(std::forward(cache_creator_func)) { + constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func, InferFunctionSchemaFunction&& infer_function_schema_func) + : kernel_func_(kernel_func) + , cache_creator_func_(std::forward(cache_creator_func)) + , infer_function_schema_func_(std::forward(infer_function_schema_func)) { } void apply(KernelRegistrationConfig* registration) const & { registration->kernel_func = kernel_func_; registration->cache_creator_func = cache_creator_func_; + registration->inferred_function_schema = infer_function_schema_func_(); } void apply(KernelRegistrationConfig* registration) && { registration->kernel_func = kernel_func_; registration->cache_creator_func = std::move(cache_creator_func_); + registration->inferred_function_schema = std::move(infer_function_schema_func_)(); } private: KernelFunction* kernel_func_; KernelCacheCreatorFunction_ cache_creator_func_; + InferFunctionSchemaFunction infer_function_schema_func_; }; - static_assert(is_registration_config_parameter>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept"); + static_assert(is_registration_config_parameter>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept"); } /** @@ -61,10 +72,10 @@ namespace detail { * > c10::dispatchKey(CPUTensorId())); */ template -inline constexpr detail::KernelRegistrationConfigParameter> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) { - static_assert(detail::is_registration_config_parameter>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept"); +inline constexpr detail::KernelRegistrationConfigParameter, detail::NoFunctionSchemaInference> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) { + static_assert(detail::is_registration_config_parameter, detail::NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept"); - return {kernel_func, std::forward(cache_creator)}; + return {kernel_func, std::forward(cache_creator), detail::NoFunctionSchemaInference()}; } } diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 96e12ce..ff0ec8c 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace c10 { @@ -65,6 +66,11 @@ public: guts::enable_if_t>...>::value, RegisterOperators> op(FunctionSchema schema, ConfigParameters&&... configParameters) && { detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward(configParameters)...); + + if (config.inferred_function_schema.get() != nullptr) { + assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); + } + registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func)); return std::move(*this); } diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 2d7345b..8f71660 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -229,6 +229,9 @@ class DummyClassForToString final {}; namespace std { // We use SFINAE to detect if std::to_string exists for a type, but that only works // if the function name is defined. So let's define a std::to_string for a dummy type. +// If you're getting an error here saying that this overload doesn't match your +// std::to_string() call, then you're calling std::to_string() but should be calling +// c10::guts::to_string(). inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; } } namespace c10 { namespace guts { namespace detail { diff --git a/c10/util/TypeTraits.h b/c10/util/TypeTraits.h index ec6437a..b4e04ea 100644 --- a/c10/util/TypeTraits.h +++ b/c10/util/TypeTraits.h @@ -61,5 +61,6 @@ struct is_type_condition : std::false_type {}; template class C> struct is_type_condition::value)>>::value>> : std::true_type {}; + } } diff --git a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc index c23aff3..c84dd3f 100644 --- a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc @@ -111,7 +111,7 @@ static auto registry = c10::RegisterOperators().op( (std::vector{ c10::Argument("inputs", ListType::ofTensors()), c10::Argument("output"), - c10::Argument("split_info", FloatType::get()), + c10::Argument("split_info"), c10::Argument("add", IntType::get()), c10::Argument("add_axis", IntType::get())}), (std::vector{})), -- 2.7.4