From: Sebastian Messmer Date: Mon, 1 Apr 2019 21:53:08 +0000 (-0700) Subject: Deprecated lambda based API (#18542) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~500 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4a0f842d42dd5eb36b9aa77b52be36b7f6a18a35;p=platform%2Fupstream%2Fpytorch.git Deprecated lambda based API (#18542) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18542 This adds the deprecated API for defining kernels as lambdas. The new API for defining kernels as lambdas was introduced in D14653005. Reviewed By: dzhulgakov Differential Revision: D14653551 fbshipit-source-id: 99900f1436716c69e52c83b68333b642ec2c8558 --- diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp new file mode 100644 index 0000000..8f45d34 --- /dev/null +++ b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp @@ -0,0 +1,786 @@ +#include +#include + +#include +#include + +/** + * This file tests the legacy lambda-based API for registering kernels: + * + * > auto registry = c10::RegisterOperators() + * > .op("myfunc(Tensor a) -> Tensor", [] (Tensor a) -> Tensor {...}); + */ + +using c10::RegisterOperators; +using c10::FunctionSchema; +using c10::Argument; +using c10::IntType; +using c10::FloatType; +using c10::ListType; +using c10::kernel; +using c10::dispatchKey; +using c10::TensorTypeId; +using c10::KernelCache; +using c10::Stack; +using c10::guts::make_unique; +using c10::ivalue::TensorList; +using c10::ivalue::IntList; +using c10::intrusive_ptr; +using c10::ArrayRef; +using std::unique_ptr; +using at::Tensor; + +namespace { + +C10_DECLARE_TENSOR_TYPE(TensorType1); +C10_DEFINE_TENSOR_TYPE(TensorType1); +C10_DECLARE_TENSOR_TYPE(TensorType2); +C10_DEFINE_TENSOR_TYPE(TensorType2); + +FunctionSchema errorOpSchema( + "_test::error", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); + +FunctionSchema opSchema( + "_test::my_op", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); + +void expectCallsIncrement(TensorTypeId type_id) { + // assert that schema and cpu kernel are present + auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", ""); + ASSERT_TRUE(op.has_value()); + auto result = callOp(*op, dummyTensor(type_id), 5); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(6, result[0].toInt()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t { + return input + 1; + }); + expectCallsIncrement(TensorType1()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { + auto registrar = RegisterOperators() + .op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t { + return input + 1; + }) + .op(errorOpSchema, [] (const Tensor& tensor, int64_t input) -> int64_t { + EXPECT_TRUE(false); // this kernel should never be called + return 0; + }); + expectCallsIncrement(TensorType1()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { + auto registrar1 = RegisterOperators().op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t { + return input + 1; + }); + auto registrar2 = RegisterOperators().op(errorOpSchema, [] (const Tensor& tensor, int64_t input) -> int64_t { + EXPECT_TRUE(false); // this kernel should never be called + return 0; + }); + expectCallsIncrement(TensorType1()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { + { + auto registrar = RegisterOperators().op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t { + return input + 1; + }); + + expectCallsIncrement(TensorType1()); + } + + // now the registrar is destructed. Assert that the schema is gone. + expectDoesntFindOperator("_test::my_op"); +} + +bool was_called = false; + +FunctionSchema opWithoutOutputSchema( + "_test::no_return", + "", + (std::vector{Argument("dummy")}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op(opWithoutOutputSchema, [] (const Tensor&) -> void { + was_called = true; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", ""); + ASSERT_TRUE(op.has_value()); + was_called = false; + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(was_called); + EXPECT_EQ(0, result.size()); +} + +FunctionSchema opWithZeroOutputsSchema( + "_test::zero_outputs", + "", + (std::vector{Argument("dummy")}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, [] (const Tensor&) -> std::tuple<> { + was_called = true; + return std::make_tuple(); + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", ""); + ASSERT_TRUE(op.has_value()); + was_called = false; + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(was_called); + EXPECT_EQ(0, result.size()); +} + +FunctionSchema opWithIntOutputSchema( + "_test::int_output", + "", + (std::vector{Argument("dummy"), + Argument("a", IntType::get()), + Argument("b", IntType::get())}), + (std::vector{Argument("sum", IntType::get())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntOutputSchema, [] (Tensor, int64_t a, int64_t b) -> int64_t { + return a + b; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(9, result[0].toInt()); +} + +FunctionSchema opWithTensorOutput( + "_test::returning_tensor", + "", + (std::vector{Argument("input")}), + (std::vector{Argument("output")})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorOutput, [] (const Tensor& input) -> Tensor { + return input; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType1(), result[0].toTensor().type_id()); + + result = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); +} + +FunctionSchema opWithTensorListOutputSchema( + "_test::list_output", + "", + (std::vector{Argument("input1"), + Argument("input2"), + Argument("input3")}), + (std::vector{Argument("output", ListType::ofTensors())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorListOutputSchema, [] (const Tensor& input1, const Tensor& input2, const Tensor& input3) -> std::vector { + return {input1, input2, input3}; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(3, result[0].toTensorListRef().size()); + EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id()); + EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id()); +} + +FunctionSchema opWithIntListOutputSchema( + "_test::list_output", + "", + (std::vector{Argument("dummy"), + Argument("input1", IntType::get()), + Argument("input2", IntType::get()), + Argument("input3", IntType::get())}), + (std::vector{Argument("output", ListType::ofInts())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntListOutputSchema, [](const Tensor&, int64_t input1, int64_t input2, int64_t input3) -> std::vector { + return {input1, input2, input3}; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(3, result[0].toIntListRef().size()); + EXPECT_EQ(2, result[0].toIntListRef()[0]); + EXPECT_EQ(4, result[0].toIntListRef()[1]); + EXPECT_EQ(6, result[0].toIntListRef()[2]); +} + +FunctionSchema opWithMultipleOutputsSchema( + "_test::multiple_outputs", + "", + (std::vector{Argument("dummy")}), + (std::vector{Argument("output1"), + Argument("output2", IntType::get()), + Argument("output3", ListType::ofTensors())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithMultipleOutputsSchema, [] (Tensor) -> std::tuple> { + return std::tuple>( + dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())} + ); + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(3, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); + EXPECT_EQ(5, result[1].toInt()); + EXPECT_EQ(2, result[2].toTensorListRef().size()); + EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id()); + EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id()); +} + +FunctionSchema opWithTensorInputWithOutput( + "_test::tensor_input", + "", + (std::vector{Argument("input")}), + (std::vector{Argument("output")})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithOutput, [] (const Tensor& input1) -> Tensor { + return input1; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType1(), result[0].toTensor().type_id()); + + result = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithOutput, [](Tensor input1) -> Tensor { + return input1; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto result = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType1(), result[0].toTensor().type_id()); + + result = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(1, result.size()); + EXPECT_EQ(TensorType2(), result[0].toTensor().type_id()); +} + +Tensor captured_input; + +FunctionSchema opWithTensorInputWithoutOutput( + "_test::tensor_input", + "", + (std::vector{Argument("input")}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithoutOutput, [] (const Tensor& input1) -> void { + captured_input = input1; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType1(), captured_input.type_id()); + + outputs = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType2(), captured_input.type_id()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorInputWithoutOutput, [] (Tensor input1) -> void { + captured_input = input1; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType1(), captured_input.type_id()); + + outputs = callOp(*op, dummyTensor(TensorType2())); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(TensorType2(), captured_input.type_id()); +} + +int64_t captured_int_input = 0; + +FunctionSchema opWithIntInputWithoutOutput( + "_test::int_input", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntInputWithoutOutput, [](Tensor, int64_t input1) -> void { + captured_int_input = input1; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", ""); + ASSERT_TRUE(op.has_value()); + + captured_int_input = 0; + auto outputs = callOp(*op, dummyTensor(TensorType1()), 3); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(3, captured_int_input); +} + +FunctionSchema opWithIntInputWithOutput( + "_test::int_input", + "", + (std::vector{Argument("dummy"), + Argument("input", IntType::get())}), + (std::vector{Argument("output", IntType::get())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntInputWithOutput, [] (Tensor, int64_t input1) -> int64_t { + return input1 + 1; + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), 3); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(4, outputs[0].toInt()); +} + +int64_t captured_input_list_size = 0; + +FunctionSchema opWithIntListInputWithoutOutput( + "_test::int_list_input", + "", + (std::vector{Argument("dummy"), + Argument("input", ListType::ofInts())}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntListInputWithoutOutput, [] (Tensor, ArrayRef input1) -> void { + captured_input_list_size = input1.size(); + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", ""); + ASSERT_TRUE(op.has_value()); + + captured_input_list_size = 0; + auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6})); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(3, captured_input_list_size); +} + +FunctionSchema opWithIntListInputWithOutput( + "_test::int_list_input", + "", + (std::vector{Argument("dummy"), + Argument("input", ListType::ofInts())}), + (std::vector{Argument("output", IntType::get())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithIntListInputWithOutput, [](Tensor, ArrayRef input1) -> int64_t { + return input1.size(); + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6})); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(3, outputs[0].toInt()); +} + +FunctionSchema opWithTensorListInputWithoutOutput( + "_test::tensor_list_input", + "", + (std::vector{Argument("input", ListType::ofTensors())}), + (std::vector{})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorListInputWithoutOutput, [] (ArrayRef input1) -> void { + captured_input_list_size = input1.size(); + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", ""); + ASSERT_TRUE(op.has_value()); + + captured_input_list_size = 0; + auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())})); + EXPECT_EQ(0, outputs.size()); + EXPECT_EQ(2, captured_input_list_size); +} + +FunctionSchema opWithTensorListInputWithOutput( + "_test::tensor_list_input", + "", + (std::vector{Argument("input", ListType::ofTensors())}), + (std::vector{Argument("output", IntType::get())})); + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators() + .op(opWithTensorListInputWithOutput, [] (ArrayRef input1) -> int64_t { + return input1.size(); + }); + + auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())})); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(2, outputs[0].toInt()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor) -> int64_t {return 0;}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor) -> int64_t {return 0;}), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2")}), + (std::vector{}) + ), [] (Tensor, Tensor) -> void {}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{}), + (std::vector{}) + ), [] (Tensor, Tensor) -> void {}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), [] (Tensor, Tensor) -> void {}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), + (std::vector{}) + ), [] (Tensor, Tensor) -> void {}), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor, int64_t) -> int64_t {return 0;}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor, int64_t) -> int64_t {return 0;}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor, int64_t) -> int64_t {return 0;}), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor) -> int64_t {return 0;}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), [] (Tensor) -> int64_t {return 0;}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), + Argument("ret2", IntType::get())}) + ), [] (Tensor) -> int64_t {return 0;}), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), [] (Tensor) -> void {}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), [] (Tensor) -> void {}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret"), Argument("ret2")}) + ), [] (Tensor) -> void {}), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2")}) + ), [] (Tensor) -> std::tuple {return {};}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{}) + ), [] (Tensor) -> std::tuple {return {};}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1")}) + ), [] (Tensor) -> std::tuple {return {};}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) + ), [] (Tensor) -> std::tuple {return {};}), + c10::Error + ); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", IntType::get())}) + ), [] (Tensor) -> int64_t {return 0;}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), [] (Tensor) -> int64_t {return 0;}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), [] (Tensor) -> int64_t {return 0;}), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret")}) + ), [] (Tensor) -> Tensor {return {};}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret", FloatType::get())}) + ), [] (Tensor) -> Tensor {return {};}), + c10::Error + ); + + // assert this does not fail because it matches + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", IntType::get())}) + ), [] (Tensor) -> std::tuple {return {};}); + + // and now a set of mismatching schemas + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) + ), [] (Tensor) -> std::tuple {return {};}), + c10::Error + ); + + EXPECT_THROW( + RegisterOperators() + .op(FunctionSchema( + "_test::mismatch", + "", + (std::vector{Argument("arg")}), + (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) + ), [] (Tensor) -> std::tuple {return {};}), + c10::Error + ); +} + +} diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 15d03e3..645c265 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -70,8 +70,6 @@ public: return std::move(*this); } - // TODO allow input schema to be just the operator name + overload name, in that case use schema generated from kernel function - /** * Deprecated. For backwards compatibility only. * Don't use this, it introduces a performance overhead on each kernel call @@ -114,7 +112,43 @@ public: return std::move(*this).op(std::move(schema), kernel>(func)); } - // TODO Add deprecated lambda-based API + /** + * Deprecated. For backwards compatibility only. + * + * This deprecated API looks like: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", [] (Tensor a, Tensor b) {...}); + * + * But you should use the new API instead: + * + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", kernel([] (Tensor a, Tensor b) {...})); + * + * Or, alternatively, write your kernel as a functor: + * + * > namespace { + * > class my_kernel_cpu final : public c10::OperatorKernel { + * > public: + * > Tensor operator()(Tensor a, Tensor b) {...} + * > }; + * > } + * > + * > static auto registry = c10::RegisterOperators() + * > .op("my_op", c10::kernel()); + */ + template + C10_DEPRECATED_MESSAGE("Registering kernels via passing lambdas to op() directly is deprecated. " \ + "Please use the new c10::kernel() based API instead.") + // enable_if: only enable it if FuncType is actually a functor, but doesn't inherit from OperatorKernel. + guts::enable_if_t::value && !std::is_base_of::value, RegisterOperators> + op(FunctionSchema schema, FuncType&& func) && { + // We intentionally don't extend this deprecated API to support dispatch keys + // and the like to push people towards using the new API. + return std::move(*this).op(std::move(schema), kernel>(std::forward(func))); + } + + // TODO allow input schema to be just the operator name + overload name, in that case use schema generated from kernel function private: void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);