From f4e87e193a06309923ed1b5df153d4a1922459c8 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Sat, 30 Mar 2019 00:03:46 -0700 Subject: [PATCH] Introduce lambda-based kernel API (#18541) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18541 Allow registering lambdas as c10 kernels. Reviewed By: dzhulgakov Differential Revision: D14653005 fbshipit-source-id: f867cc776b1339e83b7a2e1935f5cf924cfba44a --- .../src/ATen/core/op_registration/kernel_functor.h | 18 +- aten/src/ATen/core/op_registration/kernel_lambda.h | 60 ++ .../core/op_registration/kernel_lambda_test.cpp | 803 +++++++++++++++++++++ .../ATen/core/op_registration/op_registration.h | 1 + c10/test/util/TypeTraits_test.cpp | 41 ++ c10/util/Metaprogramming.h | 28 - c10/util/TypeTraits.h | 58 ++ 7 files changed, 976 insertions(+), 33 deletions(-) create mode 100644 aten/src/ATen/core/op_registration/kernel_lambda.h create mode 100644 aten/src/ATen/core/op_registration/kernel_lambda_test.cpp diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h index 86285a1..427de09 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor.h +++ b/aten/src/ATen/core/op_registration/kernel_functor.h @@ -138,6 +138,18 @@ namespace detail { return guts::make_unique(inferFunctionSchema("", "")); } }; + + template + // enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel + inline constexpr guts::enable_if_t::value && std::is_base_of::value, + detail::KernelRegistrationConfigParameter...>, detail::FunctionSchemaInferer>> + kernelFunctor(ConstructorParameters&&... constructorParameters) { + return { + &detail::wrap_kernel_functor::call, + detail::KernelFactory...>(std::forward(constructorParameters)...), + detail::FunctionSchemaInferer() + }; + } } /** @@ -181,11 +193,7 @@ template inline constexpr guts::enable_if_t::value && std::is_base_of::value, detail::KernelRegistrationConfigParameter...>, detail::FunctionSchemaInferer>> kernel(ConstructorParameters&&... constructorParameters) { - return { - &detail::wrap_kernel_functor::call, - detail::KernelFactory...>(std::forward(constructorParameters)...), - detail::FunctionSchemaInferer() - }; + return detail::kernelFunctor(std::forward(constructorParameters)...); } } diff --git a/aten/src/ATen/core/op_registration/kernel_lambda.h b/aten/src/ATen/core/op_registration/kernel_lambda.h new file mode 100644 index 0000000..a5b8244 --- /dev/null +++ b/aten/src/ATen/core/op_registration/kernel_lambda.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +namespace c10 { + +namespace detail { + // WrapRuntimeKernelFunctor: Wraps any runtime functor into a functor that + // inherits from c10::OperatorKernel, so it can be used as a c10 kernel. + // This can, for example, be used for lamdas, functors or even function pointers. + // In the case of function pointers, since it is a runtime function pointer, + // there is an overhead for calling it whenever the kernel is invoked. + template class WrapRuntimeKernelFunctor_ {}; + template + class WrapRuntimeKernelFunctor_> final : public c10::OperatorKernel { + public: + template + explicit WrapRuntimeKernelFunctor_(FuncType_&& kernel_func) + : kernel_func_(std::forward(kernel_func)) {} + + auto operator()(Parameters&&... args) -> decltype(std::declval()(std::forward(args)...)) { + return kernel_func_(std::forward(args)...); + } + + private: + FuncType kernel_func_; + }; + template + using WrapRuntimeKernelFunctor = WrapRuntimeKernelFunctor_< + FuncType, + typename guts::infer_function_traits_t::return_type, + typename guts::infer_function_traits_t::parameter_types + >; +} + +/** + * Use this to register an operator whose kernel is implemented as a stateless lambda. + * + * Example: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", + * > c10::kernel([] (Tensor a) -> Tensor{...}), + * > c10::dispatchKey(CPUTensorId())); + */ +template +inline constexpr auto kernel(Lambda&& functor) -> +guts::enable_if_t>::value, +decltype(detail::kernelFunctor>>(std::forward(functor)))> { + // We don't support stateful lambdas (i.e. lambdas with a capture), because their + // behavior would be nonobvious. A functor kernel with cache gets a new instance of + // its cache each time the kernel is looked up from the dispatch table. + // A lambda with a capture would be global and share its capture between all kernel lookups. + // So, instead of making users having to think about it (including the thread-safety + // issues this causes), let's just forbid stateful lambdas alltogether. + return detail::kernelFunctor>>(std::forward(functor)); +} + +} diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp new file mode 100644 index 0000000..d9defec --- /dev/null +++ b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp @@ -0,0 +1,803 @@ +#include +#include + +#include +#include + +using c10::RegisterOperators; +using c10::FunctionSchema; +using c10::Argument; +using c10::IntType; +using c10::FloatType; +using c10::ListType; +using c10::kernel; +using c10::dispatchKey; +using c10::TensorTypeId; +using c10::KernelCache; +using c10::Stack; +using c10::guts::make_unique; +using c10::ivalue::TensorList; +using c10::ivalue::IntList; +using c10::intrusive_ptr; +using c10::ArrayRef; +using std::unique_ptr; +using at::Tensor; + +namespace { + +C10_DECLARE_TENSOR_TYPE(TensorType1); +C10_DEFINE_TENSOR_TYPE(TensorType1); +C10_DECLARE_TENSOR_TYPE(TensorType2); +C10_DEFINE_TENSOR_TYPE(TensorType2); + +FunctionSchema errorOpSchema( + "_test::error", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); + +FunctionSchema opSchema( + "_test::my_op", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); + +void expectCallsIncrement(TensorTypeId type_id) { + // assert that schema and cpu kernel are present + auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", ""); + ASSERT_TRUE(op.has_value()); + auto result = callOp(*op, dummyTensor(type_id), 5); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(6, result[0].toInt()); +} + +void expectCallsDecrement(TensorTypeId type_id) { + // assert that schema and cpu kernel are present + auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", ""); + ASSERT_TRUE(op.has_value()); + auto result = callOp(*op, dummyTensor(type_id), 5); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(4, result[0].toInt()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1())); + expectCallsIncrement(TensorType1()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { + auto registrar = RegisterOperators() + .op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1())) + .op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2())) + .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1())) + .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2())); + expectCallsIncrement(TensorType1()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { + auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1())); + auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2())); + auto registrar3 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1())); + auto registrar4 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2())); + expectCallsIncrement(TensorType1()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { + { + auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1())); + { + auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i-1;}), dispatchKey(TensorType2())); + + // assert that schema and cpu kernel are present + expectCallsIncrement(TensorType1()); + expectCallsDecrement(TensorType2()); + } + + // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not + expectCallsIncrement(TensorType1()); + expectDoesntFindKernel("_test::my_op", TensorType2()); + } + + // now both registrars are destructed. Assert that the whole schema is gone + expectDoesntFindOperator("_test::my_op"); +} + +bool was_called = false; + +FunctionSchema opWithoutOutputSchema( + "_test::no_return", + "", + (std::vector{Argument("dummy")}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op(opWithoutOutputSchema, + kernel([] (const Tensor&) -> void {was_called = true;}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", ""); + ASSERT_TRUE(op.has_value()); + was_called = false; + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(was_called); + EXPECT_EQ(0, result.size()); +} + +FunctionSchema opWithZeroOutputsSchema( + "_test::zero_outputs", + "", + (std::vector{Argument("dummy")}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, + kernel([] (const Tensor&) -> std::tuple<> {was_called = true; return {};}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", ""); + ASSERT_TRUE(op.has_value()); + was_called = false; + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(was_called); + EXPECT_EQ(0, result.size()); +} + +FunctionSchema opWithIntOutputSchema( + "_test::int_output", + "", + (std::vector{Argument("dummy"), + Argument("a", IntType::get()), + Argument("b", IntType::get())}), + (std::vector{Argument("sum", IntType::get())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntOutputSchema, + kernel([] (Tensor, int64_t a, int64_t b) {return a+b;}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(9, result[0].toInt()); +} + +FunctionSchema opWithTensorOutput( + "_test::returning_tensor", + "", + (std::vector{Argument("input")}), + (std::vector{Argument("output")})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorOutput, + kernel([] (const Tensor& a) {return a;}), + dispatchKey(TensorType1())) + .op(opWithTensorOutput, + kernel([] (const Tensor& a) {return a;}), + dispatchKey(TensorType2())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType1(), result[0].toTensor().type_id()); + + result = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); +} + +FunctionSchema opWithTensorListOutputSchema( + "_test::list_output", + "", + (std::vector{Argument("input1"), + Argument("input2"), + Argument("input3")}), + (std::vector{Argument("output", ListType::ofTensors())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorListOutputSchema, + kernel([] (const Tensor& a, const Tensor& b, const Tensor& c) -> std::vector {return {a, b, c};}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(3, result[0].toTensorListRef().size()); + EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id()); + EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id()); +} + +FunctionSchema opWithIntListOutputSchema( + "_test::list_output", + "", + (std::vector{Argument("dummy"), + Argument("input1", IntType::get()), + Argument("input2", IntType::get()), + Argument("input3", IntType::get())}), + (std::vector{Argument("output", ListType::ofInts())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntListOutputSchema, + kernel([] (const Tensor&, int64_t a, int64_t b, int64_t c) -> std::vector {return {a,b,c};}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(3, result[0].toIntListRef().size()); + EXPECT_EQ(2, result[0].toIntListRef()[0]); + EXPECT_EQ(4, result[0].toIntListRef()[1]); + EXPECT_EQ(6, result[0].toIntListRef()[2]); +} + +FunctionSchema opWithMultipleOutputsSchema( + "_test::multiple_outputs", + "", + (std::vector{Argument("dummy")}), + (std::vector{Argument("output1"), + Argument("output2", IntType::get()), + Argument("output3", ListType::ofTensors())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithMultipleOutputsSchema, + kernel([] (Tensor) -> std::tuple> { + return std::tuple>( + dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())} + ); + }), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(3, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); + EXPECT_EQ(5, result[1].toInt()); + EXPECT_EQ(2, result[2].toTensorListRef().size()); + EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id()); + EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id()); +} + +FunctionSchema opWithTensorInputWithOutput( + "_test::tensor_input", + "", + (std::vector{Argument("input")}), + (std::vector{Argument("output")})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithOutput, + kernel([] (const Tensor& a) {return a;}), + dispatchKey(TensorType1())) + .op(opWithTensorInputWithOutput, + kernel([] (const Tensor& a) {return a;}), + dispatchKey(TensorType2())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType1(), result[0].toTensor().type_id()); + + result = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithOutput, + kernel([] (Tensor a) {return a;}), + dispatchKey(TensorType1())) + .op(opWithTensorInputWithOutput, + kernel([] (Tensor a) {return a;}), + dispatchKey(TensorType2())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType1(), result[0].toTensor().type_id()); + + result = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); +} + +Tensor captured_input; + +FunctionSchema opWithTensorInputWithoutOutput( + "_test::tensor_input", + "", + (std::vector{Argument("input")}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithoutOutput, + kernel([] (const Tensor& a) -> void {captured_input = a;}), + dispatchKey(TensorType1())) + .op(opWithTensorInputWithoutOutput, + kernel([] (const Tensor& a) -> void {captured_input = a;}), + dispatchKey(TensorType2())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType1(), captured_input.type_id()); + + outputs = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType2(), captured_input.type_id()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithoutOutput, + kernel([] (Tensor a) -> void {captured_input = a;}), + dispatchKey(TensorType1())) + .op(opWithTensorInputWithoutOutput, + kernel([] (Tensor a) -> void {captured_input = a;}), + dispatchKey(TensorType2())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType1(), captured_input.type_id()); + + outputs = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType2(), captured_input.type_id()); +} + +int64_t captured_int_input = 0; + +FunctionSchema opWithIntInputWithoutOutput( + "_test::int_input", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntInputWithoutOutput, + kernel([] (Tensor, int64_t a) -> void {captured_int_input = a;}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", ""); + ASSERT_TRUE(op.has_value()); + + captured_int_input = 0; + auto outputs = callOp(*op, dummyTensor(TensorType1()), 3); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(3, captured_int_input); +} + +FunctionSchema opWithIntInputWithOutput( + "_test::int_input", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntInputWithOutput, + kernel([] (Tensor, int64_t a) {return a + 1;}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), 3); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(4, outputs[0].toInt()); +} + +int64_t captured_input_list_size = 0; + +FunctionSchema opWithIntListInputWithoutOutput( + "_test::int_list_input", + "", + (std::vector{Argument("dummy"), + Argument("input", ListType::ofInts())}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntListInputWithoutOutput, + kernel([] (Tensor, ArrayRef a) {captured_input_list_size = a.size();}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", ""); + ASSERT_TRUE(op.has_value()); + + captured_input_list_size = 0; + auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6})); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(3, captured_input_list_size); +} + +FunctionSchema opWithIntListInputWithOutput( + "_test::int_list_input", + "", + (std::vector{Argument("dummy"), + Argument("input", ListType::ofInts())}), + (std::vector{Argument("output", IntType::get())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntListInputWithOutput, + kernel([] (Tensor, ArrayRef a) -> int64_t {return a.size();}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6})); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(3, outputs[0].toInt()); +} + +FunctionSchema opWithTensorListInputWithoutOutput( + "_test::tensor_list_input", + "", + (std::vector{Argument("input", ListType::ofTensors())}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorListInputWithoutOutput, + kernel([] (ArrayRef a) -> void {captured_input_list_size = a.size();}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", ""); + ASSERT_TRUE(op.has_value()); + + captured_input_list_size = 0; + auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())})); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(2, captured_input_list_size); +} + +FunctionSchema opWithTensorListInputWithOutput( + "_test::tensor_list_input", + "", + (std::vector{Argument("input", ListType::ofTensors())}), + (std::vector{Argument("output", IntType::get())})); + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorListInputWithOutput, + kernel([] (ArrayRef a) -> int64_t {return a.size();}), + dispatchKey(TensorType1())); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())})); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(2, outputs[0].toInt()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{}) + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{}), + (std::vector{}) + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), + (std::vector{}) + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), + Argument("ret2", IntType::get())}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret"), Argument("ret2")}) + ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2")}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1")}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", IntType::get())}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), + c10::Error + ); +} + +} diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 7861eee..29b4895 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace c10 { diff --git a/c10/test/util/TypeTraits_test.cpp b/c10/test/util/TypeTraits_test.cpp index e8c2bdb..c93c81e 100644 --- a/c10/test/util/TypeTraits_test.cpp +++ b/c10/test/util/TypeTraits_test.cpp @@ -112,3 +112,44 @@ namespace test_is_type_condition { static_assert(!is_type_condition::value, ""); } } + +namespace test_lambda_is_stateless { + template + struct MyStatelessFunctor final { + Result operator()(Args...) {} + }; + + template + struct MyStatelessConstFunctor final { + Result operator()(Args...) const {} + }; + + void func() { + auto stateless_lambda = [] (int a) {return a;}; + static_assert(is_stateless_lambda::value, ""); + + int b = 4; + auto stateful_lambda_1 = [&] (int a) {return a + b;}; + static_assert(!is_stateless_lambda::value, ""); + + auto stateful_lambda_2 = [=] (int a) {return a + b;}; + static_assert(!is_stateless_lambda::value, ""); + + auto stateful_lambda_3 = [b] (int a) {return a + b;}; + static_assert(!is_stateless_lambda::value, ""); + + static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); + static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); + static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); + static_assert(!is_stateless_lambda>::value, "even if stateless, a functor is not a lambda, so it's false"); + + class Dummy final {}; + static_assert(!is_stateless_lambda::value, "A non-functor type is also not a lambda"); + + static_assert(!is_stateless_lambda::value, "An int is not a lambda"); + + using Func = int(int); + static_assert(!is_stateless_lambda::value, "A function is not a lambda"); + static_assert(!is_stateless_lambda::value, "A function pointer is not a lambda"); + } +} diff --git a/c10/util/Metaprogramming.h b/c10/util/Metaprogramming.h index 60e4553..645b6cf 100644 --- a/c10/util/Metaprogramming.h +++ b/c10/util/Metaprogramming.h @@ -7,24 +7,6 @@ #include namespace c10 { namespace guts { -namespace detail { -/** - * strip_class: helper to remove the class type from pointers to `operator()`. - */ - -template -struct strip_class {}; -template -struct strip_class { - using type = Result(Args...); -}; -template -struct strip_class { - using type = Result(Args...); -}; -template -using strip_class_t = typename strip_class::type; -} // namespace detail /** * Access information about result type or arguments from a function type. @@ -44,16 +26,6 @@ struct function_traits { }; /** - * Evaluates to true_type, iff the given class is a Functor - * (i.e. has a call operator with some set of arguments) - */ - -template -struct is_functor : std::false_type {}; -template -struct is_functor>::value>> : std::true_type {}; - -/** * infer_function_traits: creates a `function_traits` type for a simple * function (pointer) or functor (lambda/struct). Currently does not support * class methods. diff --git a/c10/util/TypeTraits.h b/c10/util/TypeTraits.h index b4e04ea..6eee698 100644 --- a/c10/util/TypeTraits.h +++ b/c10/util/TypeTraits.h @@ -49,6 +49,64 @@ template