Check kernel against function schema in c10 op registration (#18256)
authorSebastian Messmer <messmer@fb.com>
Sat, 30 Mar 2019 07:03:44 +0000 (00:03 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 07:07:22 +0000 (00:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18256

This diff infers the function schema from the kernel function/functor and checks that it matches the specified function schema.

This diff does not allow (yet) to omit specifying the function schema in the registration API. That will come in a future diff.

Reviewed By: dzhulgakov

Differential Revision: D14552738

fbshipit-source-id: 00202b489ede19f26ae686c97416b38c72c11532

aten/src/ATen/core/op_registration/base.h
aten/src/ATen/core/op_registration/infer_schema.cpp [new file with mode: 0644]
aten/src/ATen/core/op_registration/infer_schema.h
aten/src/ATen/core/op_registration/kernel_function_test.cpp
aten/src/ATen/core/op_registration/kernel_functor.h
aten/src/ATen/core/op_registration/kernel_functor_test.cpp
aten/src/ATen/core/op_registration/kernel_stackbased.h
aten/src/ATen/core/op_registration/op_registration.h
c10/util/C++17.h
c10/util/TypeTraits.h
caffe2/operators/experimental/c10/cpu/concat_cpu.cc

index 3c4ea05..60250a5 100644 (file)
@@ -56,6 +56,7 @@ namespace detail {
     TensorTypeId dispatch_key;
     KernelFunction* kernel_func = nullptr;
     KernelCacheCreatorFunction cache_creator_func = nullptr;
+    std::unique_ptr<FunctionSchema> inferred_function_schema = nullptr;
   };
 
   // is_registration_config_parameter is a concept that returns true_type iff its argument is
diff --git a/aten/src/ATen/core/op_registration/infer_schema.cpp b/aten/src/ATen/core/op_registration/infer_schema.cpp
new file mode 100644 (file)
index 0000000..bcfeb87
--- /dev/null
@@ -0,0 +1,47 @@
+#include "infer_schema.h"
+#include <sstream>
+
+namespace c10 {
+
+namespace {
+  std::string serialize_schema(const FunctionSchema& schema) {
+    std::ostringstream str;
+    str << schema;
+    return str.str();
+  }
+}
+
+C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified) {
+  if (inferred.arguments().size() != specified.arguments().size()) {
+    AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+             "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+             "The number of arguments is different. Specified ", specified.arguments().size(),
+             " but inferred ", inferred.arguments().size());
+  }
+  if (inferred.returns().size() != specified.returns().size()) {
+    AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+             "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+             "The number of returns is different.Specified ", specified.returns().size(),
+             " but inferred ", inferred.returns().size());
+  }
+
+  for (size_t i = 0; i < inferred.arguments().size(); ++i) {
+    if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) {
+      AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+               "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+               "Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(),
+               " but inferred ", inferred.arguments()[i].type()->str());
+    }
+  }
+
+  for (size_t i = 0; i < inferred.returns().size(); ++i) {
+    if (*inferred.returns()[i].type() != *specified.returns()[i].type()) {
+      AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+               "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+               "Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(),
+               " but inferred ", inferred.returns()[i].type()->str());
+    }
+  }
+}
+
+}
index e040503..36f681e 100644 (file)
@@ -18,63 +18,76 @@ void checkStaticTypes() {
  // Give nice error messages for some of the common error cases.
  // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
  static_assert(
-     !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
-     "INVALID TYPE: Only int64_t is supported as an integral argument type");
+     !std::is_integral<T>::value || std::is_same<T, int64_t>::value || std::is_same<T, bool>::value,
+     "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
  static_assert(
      !std::is_same<T, float>::value,
      "INVALID TYPE: float is not supported as an argument type, use double instead");
 }
 
-template <typename First, typename Second, typename... Rest>
-void checkStaticTypes() {
- checkStaticTypes<First>();
- checkStaticTypes<Second, Rest...>();
-}
-
 template <typename... Ts, size_t... Is>
 ::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
-  checkStaticTypes<guts::decay_t<Ts>...>();
+  // Check types for common errors
+  (void)std::initializer_list<int>{(
+    checkStaticTypes<Ts>()
+  , 0)...};
+
   // Arguments are named "_<index>"
-  return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
+  return {Argument("_" + c10::guts::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
 }
 
-template <typename... Ts, size_t... Is>
-::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
-  return createArgumentVectorFromTypes<Ts..., Is...>();
-}
+/// Creates a vector of `Argument` from a list of C++ types that are specified
+/// as template arguments.
+template<class ParameterTypes> struct createArguments final {};
+template<class... ParameterTypes>
+struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
+  static std::vector<Argument> call() {
+    return createArgumentVectorFromTypes<ParameterTypes...>(
+        guts::make_index_sequence<sizeof...(ParameterTypes)>()
+    );
+  }
+};
 
-/// Unpack a tuple return type into a vector of return types, one per tuple
-/// element.
-template <typename... Ts>
-::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
-  return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
-}
+/// Creates a vector of `Argument` from a list of C++ types that are specified
+/// as a tuple (i.e. in the way c10 kernels return values).
+/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
+/// It can be an empty tuple<>, or void for kernels that don't return anything.
+/// It can be a single type A (i.e. no tuple) for the case where a kernel just
+/// returns one value.
+template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
 
-/// Create a single-element `vector` for simple (non-tuple) return types.
-template <typename ReturnType>
-::std::vector<Argument> createReturns(ReturnType*) {
-  checkStaticTypes<guts::decay_t<ReturnType>>();
-  return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
-}
+template<class... ReturnTypes>
+struct createReturns<std::tuple<ReturnTypes...>, void> final {
+  static std::vector<Argument> call() {
+    return createArgumentVectorFromTypes<ReturnTypes...>(
+        guts::make_index_sequence<sizeof...(ReturnTypes)>()
+    );
+  }
+};
 
-/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
-/// into the argument list.
-template <typename FunctionTraits, size_t... Is>
-::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
- using ArgumentTypes = typename FunctionTraits::parameter_types;
- return createArgumentVectorFromTypes<
-     c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
-}
+template<class ReturnType>
+struct createReturns<ReturnType, guts::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
+  static std::vector<Argument> call() {
+    return createReturns<std::tuple<ReturnType>>::call();
+  }
+};
+
+template<>
+struct createReturns<void, void> final {
+  static std::vector<Argument> call() {
+    return createReturns<std::tuple<>>::call();
+  }
+};
 
 /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
 /// function.
 template <typename FunctionTraits>
 FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
  using ReturnType = typename FunctionTraits::return_type;
+ using ParameterTypes = typename FunctionTraits::parameter_types;
 
- auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
-     guts::make_index_sequence<FunctionTraits::number_of_parameters>());
- auto returns = createReturns(static_cast<ReturnType*>(nullptr));
+ auto arguments = createArguments<ParameterTypes>::call();
+ auto returns = createReturns<ReturnType>::call();
 
  return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
 }
@@ -85,4 +98,6 @@ FunctionSchema inferFunctionSchema(std::string name, std::string overload_name)
   return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
 }
 
+C10_API void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified);
+
 }
index 02231d3..4d86d28 100644 (file)
@@ -8,6 +8,7 @@ using c10::RegisterOperators;
 using c10::FunctionSchema;
 using c10::Argument;
 using c10::IntType;
+using c10::FloatType;
 using c10::ListType;
 using c10::kernel;
 using c10::dispatchKey;
@@ -29,21 +30,23 @@ C10_DEFINE_TENSOR_TYPE(TensorType1);
 C10_DECLARE_TENSOR_TYPE(TensorType2);
 C10_DEFINE_TENSOR_TYPE(TensorType2);
 
-void errorKernel(const Tensor&) {
+int64_t errorKernel(const Tensor& tensor, int64_t input) {
   EXPECT_TRUE(false); // this kernel should never be called
+  return 0;
 }
 
 FunctionSchema errorOpSchema(
     "_test::error",
     "",
-    (std::vector<Argument>{Argument("dummy")}),
-    (std::vector<Argument>{}));
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
 
-int incrementKernel(const Tensor& tensor, int input) {
+int64_t incrementKernel(const Tensor& tensor, int64_t input) {
   return input + 1;
 }
 
-int decrementKernel(const Tensor& tensor, int input) {
+int64_t decrementKernel(const Tensor& tensor, int64_t input) {
   return input - 1;
 }
 
@@ -159,7 +162,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_wh
   EXPECT_EQ(0, result.size());
 }
 
-int kernelWithIntOutput(Tensor, int a, int b) {
+int64_t kernelWithIntOutput(Tensor, int64_t a, int64_t b) {
   return a + b;
 }
 
@@ -237,7 +240,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutp
   EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
 }
 
-std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int input1, int input2, int input3) {
+std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
   return {input1, input2, input3};
 }
 
@@ -393,9 +396,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByV
   EXPECT_EQ(TensorType2(), captured_input.type_id());
 }
 
-int captured_int_input = 0;
+int64_t captured_int_input = 0;
 
-void kernelWithIntInputWithoutOutput(Tensor, int input1) {
+void kernelWithIntInputWithoutOutput(Tensor, int64_t input1) {
   captured_int_input = input1;
 }
 
@@ -419,7 +422,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_witho
   EXPECT_EQ(3, captured_int_input);
 }
 
-int kernelWithIntInputWithOutput(Tensor, int input1) {
+int64_t kernelWithIntInputWithOutput(Tensor, int64_t input1) {
   return input1 + 1;
 }
 
@@ -442,7 +445,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withO
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
-int captured_input_list_size = 0;
+int64_t captured_input_list_size = 0;
 
 void kernelWithIntListInputWithoutOutput(Tensor, ArrayRef<int64_t> input1) {
   captured_input_list_size = input1.size();
@@ -468,7 +471,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_w
   EXPECT_EQ(3, captured_input_list_size);
 }
 
-int kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
+int64_t kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
   return input1.size();
 }
 
@@ -514,7 +517,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
   EXPECT_EQ(2, captured_input_list_size);
 }
 
-int kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
+int64_t kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
   return input1.size();
 }
 
@@ -536,4 +539,308 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
   EXPECT_EQ(2, outputs[0].toInt());
 }
 
+template<class Return, class... Args> struct kernel_func final {
+  static Return func(Args...) { return {}; }
+};
+template<class... Args> struct kernel_func<void, Args...> final {
+  static void func(Args...) {}
+};
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+          (std::vector<Argument>{})
+      ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{}),
+            (std::vector<Argument>{})
+        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
+            (std::vector<Argument>{})
+        ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1", IntType::get()),
+                                   Argument("ret2", IntType::get())})
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{})
+      ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret")})
+        ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret"), Argument("ret2")})
+        ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
+      ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1")})
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret")})
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret", FloatType::get())})
+        ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret")})
+      ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret", FloatType::get())})
+        ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
+      ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
+        ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
 }
index 19014a9..86285a1 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/core/op_registration/kernel_stackbased.h>
+#include <ATen/core/op_registration/infer_schema.h>
 
 namespace c10 {
 /**
@@ -129,6 +130,14 @@ namespace detail {
   private:
     std::tuple<Args...> constructor_parameters_;
   };
+
+  template<class KernelFunctor>
+  class FunctionSchemaInferer final {
+  public:
+    std::unique_ptr<FunctionSchema> operator()() const {
+      return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
+    }
+  };
 }
 
 /**
@@ -168,20 +177,15 @@ namespace detail {
  * >         c10::dispatchKey(CPUTensorId()));
  */
 template<class KernelFunctor, class... ConstructorParameters>
-inline constexpr auto kernel(ConstructorParameters&&... constructorParameters)
 // enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
--> guts::enable_if_t<
-guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
-decltype(kernel(
+inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
+kernel(ConstructorParameters&&... constructorParameters) {
+  return {
     &detail::wrap_kernel_functor<KernelFunctor>::call,
-    detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
-))> {
-  static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "KernelFunctor cannot be constructed with the given arguments");
-
-  return kernel(
-      &detail::wrap_kernel_functor<KernelFunctor>::call,
-      detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
-  );
+    detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
+    detail::FunctionSchemaInferer<KernelFunctor>()
+  };
 }
 
 }
index 3e638ff..49e39bb 100644 (file)
@@ -9,6 +9,7 @@ using c10::FunctionSchema;
 using c10::OperatorKernel;
 using c10::Argument;
 using c10::IntType;
+using c10::FloatType;
 using c10::ListType;
 using c10::kernel;
 using c10::dispatchKey;
@@ -31,25 +32,27 @@ C10_DECLARE_TENSOR_TYPE(TensorType2);
 C10_DEFINE_TENSOR_TYPE(TensorType2);
 
 struct ErrorKernel final : public OperatorKernel {
-  void operator()(const Tensor&) {
+  int64_t operator()(const Tensor&, int64_t) {
     EXPECT_TRUE(false); // this kernel should never be called
+    return 0;
   }
 };
 
 FunctionSchema errorOpSchema(
     "_test::error",
     "",
-    (std::vector<Argument>{Argument("dummy")}),
-    (std::vector<Argument>{}));
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
 
 struct IncrementKernel final : OperatorKernel {
-  int operator()(const Tensor& tensor, int input) {
+  int64_t operator()(const Tensor& tensor, int64_t input) {
     return input + 1;
   }
 };
 
 struct DecrementKernel final : OperatorKernel {
-  int operator()(const Tensor& tensor, int input) {
+  int64_t operator()(const Tensor& tensor, int64_t input) {
     return input - 1;
   }
 };
@@ -171,7 +174,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whe
 }
 
 struct KernelWithIntOutput final : OperatorKernel {
-  int operator()(Tensor, int a, int b) {
+  int64_t operator()(Tensor, int64_t a, int64_t b) {
     return a + b;
   }
 };
@@ -256,7 +259,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutpu
 }
 
 struct KernelWithIntListOutput final : OperatorKernel {
-  std::vector<int64_t> operator()(const Tensor&, int input1, int input2, int input3) {
+  std::vector<int64_t> operator()(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
     return {input1, input2, input3};
   }
 };
@@ -423,10 +426,10 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByVa
   EXPECT_EQ(TensorType2(), captured_input.type_id());
 }
 
-int captured_int_input = 0;
+int64_t captured_int_input = 0;
 
 struct KernelWithIntInputWithoutOutput final : OperatorKernel {
-  void operator()(Tensor, int input1) {
+  void operator()(Tensor, int64_t input1) {
     captured_int_input = input1;
   }
 };
@@ -452,7 +455,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withou
 }
 
 struct KernelWithIntInputWithOutput final : OperatorKernel {
-  int operator()(Tensor, int input1) {
+  int64_t operator()(Tensor, int64_t input1) {
     return input1 + 1;
   }
 };
@@ -476,7 +479,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOu
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
-int captured_input_list_size = 0;
+int64_t captured_input_list_size = 0;
 
 struct KernelWithIntListInputWithoutOutput final : OperatorKernel {
   void operator()(Tensor, ArrayRef<int64_t> input1) {
@@ -505,7 +508,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_wi
 }
 
 struct KernelWithIntListInputWithOutput final : OperatorKernel {
-  int operator()(Tensor, ArrayRef<int64_t> input1) {
+  int64_t operator()(Tensor, ArrayRef<int64_t> input1) {
     return input1.size();
   }
 };
@@ -555,7 +558,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput
 }
 
 struct KernelWithTensorListInputWithOutput final : OperatorKernel {
-  int operator()(ArrayRef<Tensor> input1) {
+  int64_t operator()(ArrayRef<Tensor> input1) {
     return input1.size();
   }
 };
@@ -689,5 +692,308 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstru
   EXPECT_EQ(13, outputs[0].toInt());
 }
 
+template<class Return, class... Args> struct KernelFunc final : OperatorKernel{
+  Return operator()(Args...) { return {}; }
+};
+template<class... Args> struct KernelFunc<void, Args...> final : OperatorKernel {
+  void operator()(Args...) {}
+};
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+          (std::vector<Argument>{})
+      ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{}),
+            (std::vector<Argument>{})
+        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
+            (std::vector<Argument>{})
+        ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1", IntType::get()),
+                                   Argument("ret2", IntType::get())})
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{})
+      ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret")})
+        ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret"), Argument("ret2")})
+        ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
+      ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1")})
+        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
+        ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret")})
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret", FloatType::get())})
+        ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret")})
+      ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret", FloatType::get())})
+        ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
+      ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
+        ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
+        ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
 
 }
index 412a5dd..8ec9819 100644 (file)
@@ -17,29 +17,40 @@ namespace c10 {
 
 namespace detail {
 
-  template<class KernelCacheCreatorFunction_>
+  struct NoFunctionSchemaInference final {
+    std::unique_ptr<FunctionSchema> operator()() const {
+      return nullptr;
+    }
+  };
+
+  template<class KernelCacheCreatorFunction_, class InferFunctionSchemaFunction>
   struct KernelRegistrationConfigParameter final {
     template<class KernelCacheCreatorFunction__>
-    constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func)
-    : kernel_func_(kernel_func), cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func)) {
+    constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func, InferFunctionSchemaFunction&& infer_function_schema_func)
+    : kernel_func_(kernel_func)
+    , cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func))
+    , infer_function_schema_func_(std::forward<InferFunctionSchemaFunction>(infer_function_schema_func)) {
     }
 
     void apply(KernelRegistrationConfig* registration) const & {
       registration->kernel_func = kernel_func_;
       registration->cache_creator_func = cache_creator_func_;
+      registration->inferred_function_schema = infer_function_schema_func_();
     }
 
     void apply(KernelRegistrationConfig* registration) && {
       registration->kernel_func = kernel_func_;
       registration->cache_creator_func = std::move(cache_creator_func_);
+      registration->inferred_function_schema = std::move(infer_function_schema_func_)();
     }
 
   private:
     KernelFunction* kernel_func_;
     KernelCacheCreatorFunction_ cache_creator_func_;
+    InferFunctionSchemaFunction infer_function_schema_func_;
   };
 
-  static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+  static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction, NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
 }
 
 /**
@@ -61,10 +72,10 @@ namespace detail {
  * >         c10::dispatchKey(CPUTensorId()));
  */
 template<class KernelCacheCreatorFunction_>
-inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
-  static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
+  static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
 
-  return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator)};
+  return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator), detail::NoFunctionSchemaInference()};
 }
 
 }
index 96e12ce..ff0ec8c 100644 (file)
@@ -10,6 +10,7 @@
 #include <ATen/core/op_registration/kernel_stackbased.h>
 #include <ATen/core/op_registration/kernel_functor.h>
 #include <ATen/core/op_registration/kernel_function.h>
+#include <ATen/core/op_registration/infer_schema.h>
 
 namespace c10 {
 
@@ -65,6 +66,11 @@ public:
   guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
   op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
     detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
+
+    if (config.inferred_function_schema.get() != nullptr) {
+      assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
+    }
+
     registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
     return std::move(*this);
   }
index 2d7345b..8f71660 100644 (file)
@@ -229,6 +229,9 @@ class DummyClassForToString final {};
 namespace std {
 // We use SFINAE to detect if std::to_string exists for a type, but that only works
 // if the function name is defined. So let's define a std::to_string for a dummy type.
+// If you're getting an error here saying that this overload doesn't match your
+// std::to_string() call, then you're calling std::to_string() but should be calling
+// c10::guts::to_string().
 inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; }
 }
 namespace c10 { namespace guts { namespace detail {
index ec6437a..b4e04ea 100644 (file)
@@ -61,5 +61,6 @@ struct is_type_condition : std::false_type {};
 template<template<class> class C>
 struct is_type_condition<C, guts::enable_if_t<std::is_same<bool, guts::remove_cv_t<decltype(C<int>::value)>>::value>> : std::true_type {};
 
+
 }
 }
index c23aff3..c84dd3f 100644 (file)
@@ -111,7 +111,7 @@ static auto registry = c10::RegisterOperators().op(
         (std::vector<c10::Argument>{
             c10::Argument("inputs", ListType::ofTensors()),
             c10::Argument("output"),
-            c10::Argument("split_info", FloatType::get()),
+            c10::Argument("split_info"),
             c10::Argument("add", IntType::get()),
             c10::Argument("add_axis", IntType::get())}),
         (std::vector<c10::Argument>{})),