From 12d6f79ecdb6b4ec17b1eef0e726e887636d516e Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 18 Apr 2019 02:00:51 -0700 Subject: [PATCH] Optional inputs and outputs (#19289) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19289 Allow optional inputs and outputs in native c10 operators Reviewed By: dzhulgakov Differential Revision: D14931927 fbshipit-source-id: 48f8bec009c6374345b34d933f148c08bb4f7118 --- .../kernel_function_legacy_test.cpp | 98 +++++++++++++++++++ .../core/op_registration/kernel_function_test.cpp | 98 +++++++++++++++++++ .../src/ATen/core/op_registration/kernel_functor.h | 29 +++++- .../core/op_registration/kernel_functor_test.cpp | 104 ++++++++++++++++++++ .../op_registration/kernel_lambda_legacy_test.cpp | 108 +++++++++++++++++++++ .../core/op_registration/kernel_lambda_test.cpp | 100 +++++++++++++++++++ 6 files changed, 535 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp index afd67e5..17cc5b4 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp @@ -456,6 +456,104 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenFallbackKernelWith EXPECT_EQ(4, outputs[0].toInt()); } +c10::optional called_arg2; +c10::optional called_arg3; +c10::optional called_arg4; + +void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; +} + +TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", &kernelWithOptInputWithoutOutput); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +c10::optional kernelWithOptInputWithOutput(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + return arg2; +} + +TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", &kernelWithOptInputWithOutput); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +std::tuple, c10::optional, c10::optional> +kernelWithOptInputWithMultipleOutputs(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + return std::make_tuple(arg2, arg3, arg4); +} + +TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", &kernelWithOptInputWithMultipleOutputs); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(3, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + EXPECT_TRUE(outputs[1].isNone()); + EXPECT_EQ("text", outputs[2].toString()->string()); + + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(3, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + EXPECT_EQ(4, outputs[1].toInt()); + EXPECT_TRUE(outputs[2].isNone()); +} + std::tuple kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef arg3) { return {}; } diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp index deef45c..7dad29b 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_test.cpp @@ -458,6 +458,104 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenFallbackKernelWithoutTen EXPECT_EQ(4, outputs[0].toInt()); } +c10::optional called_arg2; +c10::optional called_arg3; +c10::optional called_arg4; + +void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", kernel(), dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +c10::optional kernelWithOptInputWithOutput(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + return arg2; +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", kernel(), dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +std::tuple, c10::optional, c10::optional> +kernelWithOptInputWithMultipleOutputs(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + return std::make_tuple(arg2, arg3, arg4); +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", kernel(), dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(3, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + EXPECT_TRUE(outputs[1].isNone()); + EXPECT_EQ("text", outputs[2].toString()->string()); + + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(3, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + EXPECT_EQ(4, outputs[1].toInt()); + EXPECT_TRUE(outputs[2].isNone()); +} + std::tuple kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef arg3) { return {}; } diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h index 1582568..13a4389 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor.h +++ b/aten/src/ATen/core/op_registration/kernel_functor.h @@ -29,6 +29,7 @@ namespace detail { // cast it to the type that should be passed to the kernel function. // Examples: If the IValue contains a plain type like an int, return that. // If the IValue contains an IntList, return it as ArrayRef. + // TODO Should we move the IValue so we can avoid bumping the Tensor refcount? template struct ivalue_to_arg_type { static T call(const IValue& v) { @@ -41,10 +42,34 @@ namespace detail { return v.to>>()->elements(); } }; + template + struct ivalue_to_arg_type> { + static optional call(const IValue& v) { + if (v.isNone()) { + return nullopt; + } + return v.to(); + } + }; template - IValue return_type_to_ivalue(T&& t) { - return IValue(std::forward(t)); + struct return_type_to_ivalue_ { + static IValue call(T&& v) { + return IValue(std::move(v)); + } + }; + template + struct return_type_to_ivalue_> { + static IValue call(optional&& v) { + if (!v.has_value()) { + return IValue(); + } + return IValue(std::move(*v)); + } + }; + template + IValue return_type_to_ivalue(T&& v) { + return return_type_to_ivalue_>::call(std::move(v)); } template diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp index ea2aa34..b302961 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp @@ -601,6 +601,110 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenFallbackKernelWithoutTens EXPECT_EQ(4, outputs[0].toInt()); } +c10::optional called_arg2; +c10::optional called_arg3; +c10::optional called_arg4; + +struct KernelWithOptInputWithoutOutput final : OperatorKernel { + void operator()(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + } +}; + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", kernel(), dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +struct KernelWithOptInputWithOutput final : OperatorKernel { + c10::optional operator()(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + return arg2; + } +}; + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", kernel(), dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +struct KernelWithOptInputWithMultipleOutputs final : OperatorKernel { + std::tuple, c10::optional, c10::optional> + operator()(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + return std::make_tuple(arg2, arg3, arg4); + } +}; + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", kernel(), dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(3, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + EXPECT_TRUE(outputs[1].isNone()); + EXPECT_EQ("text", outputs[2].toString()->string()); + + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(3, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + EXPECT_EQ(4, outputs[1].toInt()); + EXPECT_TRUE(outputs[2].isNone()); +} + struct KernelForSchemaInference final : OperatorKernel { std::tuple operator()(Tensor arg1, int64_t arg2, ArrayRef arg3) { return {}; diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp index b8c65b5..121029c 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp @@ -407,6 +407,114 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenFallbackKernelWithou EXPECT_EQ(4, outputs[0].toInt()); } +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { + bool called; + c10::optional called_arg2; + c10::optional called_arg3; + c10::optional called_arg4; + + auto registrar = RegisterOperators().op( + "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", + [&] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + }); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { + bool called; + c10::optional called_arg2; + c10::optional called_arg3; + c10::optional called_arg4; + + auto registrar = RegisterOperators().op( + "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", + [&] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + return arg2; + }); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { + bool called; + c10::optional called_arg2; + c10::optional called_arg3; + c10::optional called_arg4; + + auto registrar = RegisterOperators().op( + "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", + [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + return std::make_tuple(arg2, arg3, arg4); + }); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(3, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + EXPECT_TRUE(outputs[1].isNone()); + EXPECT_EQ("text", outputs[2].toString()->string()); + + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(3, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + EXPECT_EQ(4, outputs[1].toInt()); + EXPECT_TRUE(outputs[2].isNone()); +} + TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { auto registrar = RegisterOperators() .op("_test::no_schema_specified", [] (Tensor arg1, int64_t arg2, ArrayRef arg3) -> std::tuple {return {};}); diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp index 673855b..fc576e7 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp @@ -420,6 +420,106 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenFallbackKernelWithoutTenso EXPECT_EQ(4, outputs[0].toInt()); } +c10::optional called_arg2; +c10::optional called_arg3; +c10::optional called_arg4; + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op( + "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", + kernel([] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + }), + dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(0, outputs.size()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op( + "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", + kernel([] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + called = true; + called_arg2 = arg2; + called_arg3 = arg3; + called_arg4 = arg4; + return arg2; + }), + dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + called = false; + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + + EXPECT_TRUE(called); + EXPECT_TRUE(called_arg2.has_value()); + EXPECT_EQ(called_arg2->type_id(), TensorType2()); + EXPECT_FALSE(called_arg3.has_value()); + EXPECT_TRUE(called_arg4.has_value()); + EXPECT_EQ(*called_arg4, "text"); + + called = false; + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(1, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + + EXPECT_TRUE(called); + EXPECT_FALSE(called_arg2.has_value()); + EXPECT_TRUE(called_arg3.has_value()); + EXPECT_EQ(*called_arg3, 4); + EXPECT_FALSE(called_arg4.has_value()); +} + +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { + auto registrar = RegisterOperators().op( + "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", + kernel([] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + return std::make_tuple(arg2, arg3, arg4); + }), + dispatchKey(TensorType1())); + auto op = c10::Dispatcher::singleton().findSchema("_test::opt_input", ""); + ASSERT_TRUE(op.has_value()); + + auto outputs = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), c10::IValue(), std::string("text")); + EXPECT_EQ(3, outputs.size()); + EXPECT_EQ(TensorType2(), outputs[0].toTensor().type_id()); + EXPECT_TRUE(outputs[1].isNone()); + EXPECT_EQ("text", outputs[2].toString()->string()); + + outputs = callOp(*op, dummyTensor(TensorType1()), c10::IValue(), 4, c10::IValue()); + EXPECT_EQ(3, outputs.size()); + EXPECT_TRUE(outputs[0].isNone()); + EXPECT_EQ(4, outputs[1].toInt()); + EXPECT_TRUE(outputs[2].isNone()); +} + TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { auto registrar = RegisterOperators() .op("_test::no_schema_specified", kernel([] (Tensor arg1, int64_t arg2, ArrayRef arg3) -> std::tuple {return {};})); -- 2.7.4