Improve and test error messages for signature mismatches (#18547)
authorSebastian Messmer <messmer@fb.com>
Tue, 2 Apr 2019 19:23:13 +0000 (12:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 19:33:24 +0000 (12:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18547

- Argument indices in the error messages are 1-indexed not 0-indexed.
- Add test cases that a mismatching signature actually shows the correct error messages

Reviewed By: dzhulgakov

Differential Revision: D14656695

fbshipit-source-id: 55e45634baa3117e18b8687ea6b2a2f83715bdf6

aten/src/ATen/core/op_registration/infer_schema.cpp
aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp
aten/src/ATen/core/op_registration/kernel_function_test.cpp
aten/src/ATen/core/op_registration/kernel_functor_test.cpp
aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
aten/src/ATen/core/op_registration/op_registration_test.cpp
aten/src/ATen/core/op_registration/test_helpers.h

index bcfeb87..8122cbd 100644 (file)
@@ -21,7 +21,7 @@ C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, c
   if (inferred.returns().size() != specified.returns().size()) {
     AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
              "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
-             "The number of returns is different.Specified ", specified.returns().size(),
+             "The number of returns is different. Specified ", specified.returns().size(),
              " but inferred ", inferred.returns().size());
   }
 
@@ -29,7 +29,7 @@ C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, c
     if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) {
       AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
                "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
-               "Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(),
+               "Type mismatch in argument ", (i+1) , ": specified ", specified.arguments()[i].type()->str(),
                " but inferred ", inferred.arguments()[i].type()->str());
     }
   }
@@ -38,7 +38,7 @@ C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, c
     if (*inferred.returns()[i].type() != *specified.returns()[i].type()) {
       AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
                "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
-               "Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(),
+               "Type mismatch in return ", (i+1), ": specified ", specified.returns()[i].type()->str(),
                " but inferred ", inferred.returns()[i].type()->str());
     }
   }
index b122aea..206c9bd 100644 (file)
@@ -547,15 +547,15 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<int64_t, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), &kernel_func<int64_t, Tensor>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor>::func);
+    }, "The number of arguments is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -568,37 +568,37 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<void, Tensor, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{}),
             (std::vector<Argument>{})
-        ), &kernel_func<void, Tensor, Tensor>::func),
-    c10::Error
+        ), &kernel_func<void, Tensor, Tensor>::func);
+    }, "The number of arguments is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), &kernel_func<void, Tensor, Tensor>::func),
-    c10::Error
+        ), &kernel_func<void, Tensor, Tensor>::func);
+    }, "The number of arguments is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
             (std::vector<Argument>{})
-        ), &kernel_func<void, Tensor, Tensor>::func),
-    c10::Error
+        ), &kernel_func<void, Tensor, Tensor>::func);
+    }, "The number of arguments is different. Specified 3 but inferred 2"
   );
 }
 
@@ -613,26 +613,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<int64_t, Tensor, int64_t>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), &kernel_func<int64_t, Tensor, int64_t>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor, int64_t>::func);
+    }, "Type mismatch in argument 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), &kernel_func<int64_t, Tensor, int64_t>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor, int64_t>::func);
+    }, "Type mismatch in argument 1: specified int but inferred Tensor"
   );
 }
 
@@ -647,18 +647,18 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<int64_t, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), &kernel_func<int64_t, Tensor>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor>::func);
+    }, "The number of returns is different. Specified 0 but inferred 1"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
@@ -666,8 +666,8 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()),
                                    Argument("ret2", IntType::get())})
-        ), &kernel_func<int64_t, Tensor>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor>::func);
+    }, "The number of returns is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -680,26 +680,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<void, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), &kernel_func<void, Tensor>::func),
-    c10::Error
+        ), &kernel_func<void, Tensor>::func);
+    }, "The number of returns is different. Specified 1 but inferred 0"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret"), Argument("ret2")})
-        ), &kernel_func<void, Tensor>::func),
-    c10::Error
+        ), &kernel_func<void, Tensor>::func);
+    }, "The number of returns is different. Specified 2 but inferred 0"
   );
 
   // assert this does not fail because it matches
@@ -712,37 +712,37 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func),
-    c10::Error
+        ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+    }, "The number of returns is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1")})
-        ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func),
-    c10::Error
+        ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+    }, "The number of returns is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
-        ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func),
-    c10::Error
+        ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+    }, "The number of returns is different. Specified 3 but inferred 2"
   );
 }
 
@@ -757,26 +757,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<int64_t, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), &kernel_func<int64_t, Tensor>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor>::func);
+    }, "Type mismatch in return 1: specified Tensor but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), &kernel_func<int64_t, Tensor>::func),
-    c10::Error
+        ), &kernel_func<int64_t, Tensor>::func);
+    }, "Type mismatch in return 1: specified float but inferred int"
   );
 
   // assert this does not fail because it matches
@@ -789,15 +789,15 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<Tensor, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), &kernel_func<Tensor, Tensor>::func),
-    c10::Error
+        ), &kernel_func<Tensor, Tensor>::func);
+    }, "Type mismatch in return 1: specified float but inferred Tensor"
   );
 
   // assert this does not fail because it matches
@@ -810,26 +810,26 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_w
       ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
-        ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func),
-    c10::Error
+        ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+    }, "Type mismatch in return 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
-        ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func),
-    c10::Error
+        ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+    }, "Type mismatch in return 1: specified int but inferred Tensor"
   );
 }
 
index 4d86d28..2c8a590 100644 (file)
@@ -557,15 +557,15 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -578,37 +578,37 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{}),
             (std::vector<Argument>{})
-        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
             (std::vector<Argument>{})
-        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 3 but inferred 2"
   );
 }
 
@@ -623,26 +623,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in argument 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in argument 1: specified int but inferred Tensor"
   );
 }
 
@@ -657,18 +657,18 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 0 but inferred 1"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
@@ -676,8 +676,8 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()),
                                    Argument("ret2", IntType::get())})
-        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -690,26 +690,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 1 but inferred 0"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret"), Argument("ret2")})
-        ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 2 but inferred 0"
   );
 
   // assert this does not fail because it matches
@@ -722,37 +722,37 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1")})
-        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
-        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 3 but inferred 2"
   );
 }
 
@@ -767,26 +767,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified Tensor but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified float but inferred int"
   );
 
   // assert this does not fail because it matches
@@ -799,15 +799,15 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified float but inferred Tensor"
   );
 
   // assert this does not fail because it matches
@@ -820,26 +820,26 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif
       ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
-        ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
-        ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified int but inferred Tensor"
   );
 }
 
index 49e39bb..90f18ed 100644 (file)
@@ -710,15 +710,15 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -731,37 +731,37 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{}),
             (std::vector<Argument>{})
-        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
             (std::vector<Argument>{})
-        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 3 but inferred 2"
   );
 }
 
@@ -776,26 +776,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in argument 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in argument 1: specified int but inferred Tensor"
   );
 }
 
@@ -810,18 +810,18 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 0 but inferred 1"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
@@ -829,8 +829,8 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()),
                                    Argument("ret2", IntType::get())})
-        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -843,26 +843,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 1 but inferred 0"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret"), Argument("ret2")})
-        ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 2 but inferred 0"
   );
 
   // assert this does not fail because it matches
@@ -875,37 +875,37 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1")})
-        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
-        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 3 but inferred 2"
   );
 }
 
@@ -920,26 +920,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified Tensor but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified float but inferred int"
   );
 
   // assert this does not fail because it matches
@@ -952,15 +952,15 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified float but inferred Tensor"
   );
 
   // assert this does not fail because it matches
@@ -973,26 +973,26 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff
       ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
-        ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
-        ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified int but inferred Tensor"
   );
 }
 
index 8f45d34..884cd2e 100644 (file)
@@ -497,15 +497,15 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> int64_t {return 0;});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), [] (Tensor) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor) -> int64_t {return 0;});
+    }, "The number of arguments is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -518,37 +518,37 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor, Tensor) -> void {});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{}),
             (std::vector<Argument>{})
-        ), [] (Tensor, Tensor) -> void {}),
-    c10::Error
+        ), [] (Tensor, Tensor) -> void {});
+    }, "The number of arguments is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), [] (Tensor, Tensor) -> void {}),
-    c10::Error
+        ), [] (Tensor, Tensor) -> void {});
+    }, "The number of arguments is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
             (std::vector<Argument>{})
-        ), [] (Tensor, Tensor) -> void {}),
-    c10::Error
+        ), [] (Tensor, Tensor) -> void {});
+    }, "The number of arguments is different. Specified 3 but inferred 2"
   );
 }
 
@@ -563,26 +563,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor, int64_t) -> int64_t {return 0;});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), [] (Tensor, int64_t) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor, int64_t) -> int64_t {return 0;});
+    }, "Type mismatch in argument 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), [] (Tensor, int64_t) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor, int64_t) -> int64_t {return 0;});
+    }, "Type mismatch in argument 1: specified int but inferred Tensor"
   );
 }
 
@@ -597,18 +597,18 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> int64_t {return 0;});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), [] (Tensor) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor) -> int64_t {return 0;});
+    }, "The number of returns is different. Specified 0 but inferred 1"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
@@ -616,8 +616,8 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()),
                                    Argument("ret2", IntType::get())})
-        ), [] (Tensor) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor) -> int64_t {return 0;});
+    }, "The number of returns is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -630,26 +630,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> void {});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), [] (Tensor) -> void {}),
-    c10::Error
+        ), [] (Tensor) -> void {});
+    }, "The number of returns is different. Specified 1 but inferred 0"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret"), Argument("ret2")})
-        ), [] (Tensor) -> void {}),
-    c10::Error
+        ), [] (Tensor) -> void {});
+    }, "The number of returns is different. Specified 2 but inferred 0"
   );
 
   // assert this does not fail because it matches
@@ -662,37 +662,37 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}),
-    c10::Error
+        ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+    }, "The number of returns is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1")})
-        ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}),
-    c10::Error
+        ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+    }, "The number of returns is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
-        ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}),
-    c10::Error
+        ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+    }, "The number of returns is different. Specified 3 but inferred 2"
   );
 }
 
@@ -707,26 +707,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> int64_t {return 0;});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), [] (Tensor) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor) -> int64_t {return 0;});
+    }, "Type mismatch in return 1: specified Tensor but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), [] (Tensor) -> int64_t {return 0;}),
-    c10::Error
+        ), [] (Tensor) -> int64_t {return 0;});
+    }, "Type mismatch in return 1: specified float but inferred int"
   );
 
   // assert this does not fail because it matches
@@ -739,15 +739,15 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> Tensor {return {};});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), [] (Tensor) -> Tensor {return {};}),
-    c10::Error
+        ), [] (Tensor) -> Tensor {return {};});
+    }, "Type mismatch in return 1: specified float but inferred Tensor"
   );
 
   // assert this does not fail because it matches
@@ -760,26 +760,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_wit
       ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
-        ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}),
-    c10::Error
+        ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+    }, "Type mismatch in return 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
-        ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}),
-    c10::Error
+        ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+    }, "Type mismatch in return 1: specified int but inferred Tensor"
   );
 }
 
index d9defec..c156e48 100644 (file)
@@ -514,15 +514,15 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -535,37 +535,37 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{}),
             (std::vector<Argument>{})
-        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
             (std::vector<Argument>{})
-        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+    }, "The number of arguments is different. Specified 3 but inferred 2"
   );
 }
 
@@ -580,26 +580,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in argument 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
             (std::vector<Argument>{Argument("ret", IntType::get())})
-        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in argument 1: specified int but inferred Tensor"
   );
 }
 
@@ -614,18 +614,18 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 0 but inferred 1"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
@@ -633,8 +633,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()),
                                    Argument("ret2", IntType::get())})
-        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 2 but inferred 1"
   );
 
   // assert this does not fail because it matches
@@ -647,26 +647,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 1 but inferred 0"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret"), Argument("ret2")})
-        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 2 but inferred 0"
   );
 
   // assert this does not fail because it matches
@@ -679,37 +679,37 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{})
-        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 0 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1")})
-        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 1 but inferred 2"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
-        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+    }, "The number of returns is different. Specified 3 but inferred 2"
   );
 }
 
@@ -724,26 +724,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret")})
-        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified Tensor but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified float but inferred int"
   );
 
   // assert this does not fail because it matches
@@ -756,15 +756,15 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret", FloatType::get())})
-        ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified float but inferred Tensor"
   );
 
   // assert this does not fail because it matches
@@ -777,26 +777,26 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe
       ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
 
   // and now a set of mismatching schemas
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
-        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 2: specified float but inferred int"
   );
 
-  EXPECT_THROW(
+  expectThrows<c10::Error>([] {
     RegisterOperators()
         .op(FunctionSchema(
             "_test::mismatch",
             "",
             (std::vector<Argument>{Argument("arg")}),
             (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
-        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
-    c10::Error
+        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+    }, "Type mismatch in return 1: specified int but inferred Tensor"
   );
 }
 
index 224bab4..f81d5ee 100644 (file)
@@ -48,10 +48,9 @@ FunctionSchema dummySchema(
 
 TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) {
   // make sure it crashes when kernel is absent
-  EXPECT_THROW(
-    c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())),
-    c10::Error
-  );
+  expectThrows<c10::Error>([&] {
+    c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1()));
+  }, "but didn't specify a kernel");
 
   // but make sure it doesn't crash when kernel is present
   c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
@@ -62,10 +61,9 @@ TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWro
 
   auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
   ASSERT_TRUE(op.has_value());
-  EXPECT_THROW(
-    callOp(*op, dummyTensor(TensorType2())),
-    c10::Error
-  );
+  expectThrows<c10::Error>([&] {
+    callOp(*op, dummyTensor(TensorType2()));
+  }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
 }
 
 TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) {
@@ -77,10 +75,9 @@ TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOp
 
   auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
   ASSERT_TRUE(op.has_value());
-  EXPECT_THROW(
-    callOp(*op, dummyTensor(TensorType2())),
-    c10::Error
-  );
+  expectThrows<c10::Error>([&] {
+    callOp(*op, dummyTensor(TensorType2()));
+  }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
 }
 
 TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) {
index 595912c..23b5b2b 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <gtest/gtest.h>
+#include <gmock/gmock.h>
 
 #include <ATen/core/Tensor.h>
 #include <ATen/core/dispatch/Dispatcher.h>
@@ -44,3 +45,15 @@ inline void expectDoesntFindOperator(const char* op_name) {
   auto op = c10::Dispatcher::singleton().findSchema(op_name, "");
   EXPECT_FALSE(op.has_value());
 }
+
+template<class Exception, class Functor>
+inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
+  try {
+    std::forward<Functor>(functor)();
+  } catch (const Exception& e) {
+    EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
+    return;
+  }
+  ADD_FAILURE() << "Expected to throw exception containing \""
+    << expectMessageContains << "\" but didn't throw";
+}