From ae1d13a06fc43688fcf79c8f04c437b8ac6b3f61 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Tue, 2 Apr 2019 12:23:13 -0700 Subject: [PATCH] Improve and test error messages for signature mismatches (#18547) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18547 - Argument indices in the error messages are 1-indexed not 0-indexed. - Add test cases that a mismatching signature actually shows the correct error messages Reviewed By: dzhulgakov Differential Revision: D14656695 fbshipit-source-id: 55e45634baa3117e18b8687ea6b2a2f83715bdf6 --- .../src/ATen/core/op_registration/infer_schema.cpp | 6 +- .../kernel_function_legacy_test.cpp | 108 ++++++++++----------- .../core/op_registration/kernel_function_test.cpp | 108 ++++++++++----------- .../core/op_registration/kernel_functor_test.cpp | 108 ++++++++++----------- .../op_registration/kernel_lambda_legacy_test.cpp | 108 ++++++++++----------- .../core/op_registration/kernel_lambda_test.cpp | 108 ++++++++++----------- .../core/op_registration/op_registration_test.cpp | 21 ++-- aten/src/ATen/core/op_registration/test_helpers.h | 13 +++ 8 files changed, 295 insertions(+), 285 deletions(-) diff --git a/aten/src/ATen/core/op_registration/infer_schema.cpp b/aten/src/ATen/core/op_registration/infer_schema.cpp index bcfeb87..8122cbd 100644 --- a/aten/src/ATen/core/op_registration/infer_schema.cpp +++ b/aten/src/ATen/core/op_registration/infer_schema.cpp @@ -21,7 +21,7 @@ C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, c if (inferred.returns().size() != specified.returns().size()) { AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", - "The number of returns is different.Specified ", specified.returns().size(), + "The number of returns is different. Specified ", specified.returns().size(), " but inferred ", inferred.returns().size()); } @@ -29,7 +29,7 @@ C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, c if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) { AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", - "Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(), + "Type mismatch in argument ", (i+1) , ": specified ", specified.arguments()[i].type()->str(), " but inferred ", inferred.arguments()[i].type()->str()); } } @@ -38,7 +38,7 @@ C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, c if (*inferred.returns()[i].type() != *specified.returns()[i].type()) { AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ", "doesn't match inferred function schema [", serialize_schema(inferred), "]. ", - "Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(), + "Type mismatch in return ", (i+1), ": specified ", specified.returns()[i].type()->str(), " but inferred ", inferred.returns()[i].type()->str()); } } diff --git a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp index b122aea..206c9bd 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp @@ -547,15 +547,15 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2")}), (std::vector{Argument("ret", IntType::get())}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of arguments is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -568,37 +568,37 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{}), (std::vector{}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of arguments is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of arguments is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), (std::vector{}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of arguments is different. Specified 3 but inferred 2" ); } @@ -613,26 +613,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "Type mismatch in argument 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "Type mismatch in argument 1: specified int but inferred Tensor" ); } @@ -647,18 +647,18 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of returns is different. Specified 0 but inferred 1" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", @@ -666,8 +666,8 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of returns is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -680,26 +680,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of returns is different. Specified 1 but inferred 0" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret"), Argument("ret2")}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "The number of returns is different. Specified 2 but inferred 0" ); // assert this does not fail because it matches @@ -712,37 +712,37 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func, Tensor>::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), &kernel_func, Tensor>::func), - c10::Error + ), &kernel_func, Tensor>::func); + }, "The number of returns is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1")}) - ), &kernel_func, Tensor>::func), - c10::Error + ), &kernel_func, Tensor>::func); + }, "The number of returns is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) - ), &kernel_func, Tensor>::func), - c10::Error + ), &kernel_func, Tensor>::func); + }, "The number of returns is different. Specified 3 but inferred 2" ); } @@ -757,26 +757,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "Type mismatch in return 1: specified Tensor but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "Type mismatch in return 1: specified float but inferred int" ); // assert this does not fail because it matches @@ -789,15 +789,15 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), &kernel_func::func), - c10::Error + ), &kernel_func::func); + }, "Type mismatch in return 1: specified float but inferred Tensor" ); // assert this does not fail because it matches @@ -810,26 +810,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w ), &kernel_func, Tensor>::func); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) - ), &kernel_func, Tensor>::func), - c10::Error + ), &kernel_func, Tensor>::func); + }, "Type mismatch in return 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), &kernel_func, Tensor>::func), - c10::Error + ), &kernel_func, Tensor>::func); + }, "Type mismatch in return 1: specified int but inferred Tensor" ); } diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp index 4d86d28..2c8a590 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_test.cpp @@ -557,15 +557,15 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2")}), (std::vector{Argument("ret", IntType::get())}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -578,37 +578,37 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{}), (std::vector{}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), (std::vector{}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 3 but inferred 2" ); } @@ -623,26 +623,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in argument 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in argument 1: specified int but inferred Tensor" ); } @@ -657,18 +657,18 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 0 but inferred 1" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", @@ -676,8 +676,8 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -690,26 +690,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 1 but inferred 0" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret"), Argument("ret2")}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 2 but inferred 0" ); // assert this does not fail because it matches @@ -722,37 +722,37 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1")}) - ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) - ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 3 but inferred 2" ); } @@ -767,26 +767,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified Tensor but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified float but inferred int" ); // assert this does not fail because it matches @@ -799,15 +799,15 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel::func), &kernel_func::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified float but inferred Tensor" ); // assert this does not fail because it matches @@ -820,26 +820,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) - ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>::func), &kernel_func, Tensor>::func>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified int but inferred Tensor" ); } diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp index 49e39bb..90f18ed 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp @@ -710,15 +710,15 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2")}), (std::vector{Argument("ret", IntType::get())}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -731,37 +731,37 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{}), (std::vector{}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), (std::vector{}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 3 but inferred 2" ); } @@ -776,26 +776,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "Type mismatch in argument 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "Type mismatch in argument 1: specified int but inferred Tensor" ); } @@ -810,18 +810,18 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 0 but inferred 1" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", @@ -829,8 +829,8 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -843,26 +843,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 1 but inferred 0" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret"), Argument("ret2")}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 2 but inferred 0" ); // assert this does not fail because it matches @@ -875,37 +875,37 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel, Tensor>>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel, Tensor>>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1")}) - ), kernel, Tensor>>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) - ), kernel, Tensor>>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 3 but inferred 2" ); } @@ -920,26 +920,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified Tensor but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified float but inferred int" ); // assert this does not fail because it matches @@ -952,15 +952,15 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), kernel>(), dispatchKey(TensorType1())), - c10::Error + ), kernel>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified float but inferred Tensor" ); // assert this does not fail because it matches @@ -973,26 +973,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff ), kernel, Tensor>>(), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) - ), kernel, Tensor>>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), kernel, Tensor>>(), dispatchKey(TensorType1())), - c10::Error + ), kernel, Tensor>>(), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified int but inferred Tensor" ); } diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp index 8f45d34..884cd2e 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp @@ -497,15 +497,15 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> int64_t {return 0;}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2")}), (std::vector{Argument("ret", IntType::get())}) - ), [] (Tensor) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor) -> int64_t {return 0;}); + }, "The number of arguments is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -518,37 +518,37 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor, Tensor) -> void {}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{}), (std::vector{}) - ), [] (Tensor, Tensor) -> void {}), - c10::Error + ), [] (Tensor, Tensor) -> void {}); + }, "The number of arguments is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), [] (Tensor, Tensor) -> void {}), - c10::Error + ), [] (Tensor, Tensor) -> void {}); + }, "The number of arguments is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), (std::vector{}) - ), [] (Tensor, Tensor) -> void {}), - c10::Error + ), [] (Tensor, Tensor) -> void {}); + }, "The number of arguments is different. Specified 3 but inferred 2" ); } @@ -563,26 +563,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor, int64_t) -> int64_t {return 0;}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), [] (Tensor, int64_t) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor, int64_t) -> int64_t {return 0;}); + }, "Type mismatch in argument 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), [] (Tensor, int64_t) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor, int64_t) -> int64_t {return 0;}); + }, "Type mismatch in argument 1: specified int but inferred Tensor" ); } @@ -597,18 +597,18 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> int64_t {return 0;}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), [] (Tensor) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor) -> int64_t {return 0;}); + }, "The number of returns is different. Specified 0 but inferred 1" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", @@ -616,8 +616,8 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), [] (Tensor) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor) -> int64_t {return 0;}); + }, "The number of returns is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -630,26 +630,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> void {}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), [] (Tensor) -> void {}), - c10::Error + ), [] (Tensor) -> void {}); + }, "The number of returns is different. Specified 1 but inferred 0" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret"), Argument("ret2")}) - ), [] (Tensor) -> void {}), - c10::Error + ), [] (Tensor) -> void {}); + }, "The number of returns is different. Specified 2 but inferred 0" ); // assert this does not fail because it matches @@ -662,37 +662,37 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> std::tuple {return {};}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), [] (Tensor) -> std::tuple {return {};}), - c10::Error + ), [] (Tensor) -> std::tuple {return {};}); + }, "The number of returns is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1")}) - ), [] (Tensor) -> std::tuple {return {};}), - c10::Error + ), [] (Tensor) -> std::tuple {return {};}); + }, "The number of returns is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) - ), [] (Tensor) -> std::tuple {return {};}), - c10::Error + ), [] (Tensor) -> std::tuple {return {};}); + }, "The number of returns is different. Specified 3 but inferred 2" ); } @@ -707,26 +707,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> int64_t {return 0;}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), [] (Tensor) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor) -> int64_t {return 0;}); + }, "Type mismatch in return 1: specified Tensor but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), [] (Tensor) -> int64_t {return 0;}), - c10::Error + ), [] (Tensor) -> int64_t {return 0;}); + }, "Type mismatch in return 1: specified float but inferred int" ); // assert this does not fail because it matches @@ -739,15 +739,15 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> Tensor {return {};}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), [] (Tensor) -> Tensor {return {};}), - c10::Error + ), [] (Tensor) -> Tensor {return {};}); + }, "Type mismatch in return 1: specified float but inferred Tensor" ); // assert this does not fail because it matches @@ -760,26 +760,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit ), [] (Tensor) -> std::tuple {return {};}); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) - ), [] (Tensor) -> std::tuple {return {};}), - c10::Error + ), [] (Tensor) -> std::tuple {return {};}); + }, "Type mismatch in return 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), [] (Tensor) -> std::tuple {return {};}), - c10::Error + ), [] (Tensor) -> std::tuple {return {};}); + }, "Type mismatch in return 1: specified int but inferred Tensor" ); } diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp index d9defec..c156e48 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp @@ -514,15 +514,15 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2")}), (std::vector{Argument("ret", IntType::get())}) - ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -535,37 +535,37 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{}), (std::vector{}) - ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg"), Argument("arg2"), Argument("arg3")}), (std::vector{}) - ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())); + }, "The number of arguments is different. Specified 3 but inferred 2" ); } @@ -580,26 +580,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1"), Argument("arg2", FloatType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in argument 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}), (std::vector{Argument("ret", IntType::get())}) - ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in argument 1: specified int but inferred Tensor" ); } @@ -614,18 +614,18 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 0 but inferred 1" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", @@ -633,8 +633,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 2 but inferred 1" ); // assert this does not fail because it matches @@ -647,26 +647,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 1 but inferred 0" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret"), Argument("ret2")}) - ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 2 but inferred 0" ); // assert this does not fail because it matches @@ -679,37 +679,37 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{}) - ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 0 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1")}) - ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 1 but inferred 2" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2"), Argument("ret3")}) - ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + }, "The number of returns is different. Specified 3 but inferred 2" ); } @@ -724,26 +724,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret")}) - ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified Tensor but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified float but inferred int" ); // assert this does not fail because it matches @@ -756,15 +756,15 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret", FloatType::get())}) - ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified float but inferred Tensor" ); // assert this does not fail because it matches @@ -777,26 +777,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); // and now a set of mismatching schemas - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1"), Argument("ret2", FloatType::get())}) - ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in return 2: specified float but inferred int" ); - EXPECT_THROW( + expectThrows([] { RegisterOperators() .op(FunctionSchema( "_test::mismatch", "", (std::vector{Argument("arg")}), (std::vector{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())}) - ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())), - c10::Error + ), kernel([] (Tensor) -> std::tuple {return {};}), dispatchKey(TensorType1())); + }, "Type mismatch in return 1: specified int but inferred Tensor" ); } diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 224bab4..f81d5ee 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -48,10 +48,9 @@ FunctionSchema dummySchema( TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) { // make sure it crashes when kernel is absent - EXPECT_THROW( - c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())), - c10::Error - ); + expectThrows([&] { + c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())); + }, "but didn't specify a kernel"); // but make sure it doesn't crash when kernel is present c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); @@ -62,10 +61,9 @@ TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWro auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); ASSERT_TRUE(op.has_value()); - EXPECT_THROW( - callOp(*op, dummyTensor(TensorType2())), - c10::Error - ); + expectThrows([&] { + callOp(*op, dummyTensor(TensorType2())); + }, "Didn't find kernel to dispatch to for operator '_test::dummy'"); } TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) { @@ -77,10 +75,9 @@ TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOp auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); ASSERT_TRUE(op.has_value()); - EXPECT_THROW( - callOp(*op, dummyTensor(TensorType2())), - c10::Error - ); + expectThrows([&] { + callOp(*op, dummyTensor(TensorType2())); + }, "Didn't find kernel to dispatch to for operator '_test::dummy'"); } TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) { diff --git a/aten/src/ATen/core/op_registration/test_helpers.h b/aten/src/ATen/core/op_registration/test_helpers.h index 595912c..23b5b2b 100644 --- a/aten/src/ATen/core/op_registration/test_helpers.h +++ b/aten/src/ATen/core/op_registration/test_helpers.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -44,3 +45,15 @@ inline void expectDoesntFindOperator(const char* op_name) { auto op = c10::Dispatcher::singleton().findSchema(op_name, ""); EXPECT_FALSE(op.has_value()); } + +template +inline void expectThrows(Functor&& functor, const char* expectMessageContains) { + try { + std::forward(functor)(); + } catch (const Exception& e) { + EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains)); + return; + } + ADD_FAILURE() << "Expected to throw exception containing \"" + << expectMessageContains << "\" but didn't throw"; +} -- 2.7.4