TensorTypeId dispatch_key;
KernelFunction* kernel_func = nullptr;
KernelCacheCreatorFunction cache_creator_func = nullptr;
+ std::unique_ptr<FunctionSchema> inferred_function_schema = nullptr;
};
// is_registration_config_parameter is a concept that returns true_type iff its argument is
--- /dev/null
+#include "infer_schema.h"
+#include <sstream>
+
+namespace c10 {
+
+namespace {
+ std::string serialize_schema(const FunctionSchema& schema) {
+ std::ostringstream str;
+ str << schema;
+ return str.str();
+ }
+}
+
+C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified) {
+ if (inferred.arguments().size() != specified.arguments().size()) {
+ AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+ "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+ "The number of arguments is different. Specified ", specified.arguments().size(),
+ " but inferred ", inferred.arguments().size());
+ }
+ if (inferred.returns().size() != specified.returns().size()) {
+ AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+ "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+ "The number of returns is different.Specified ", specified.returns().size(),
+ " but inferred ", inferred.returns().size());
+ }
+
+ for (size_t i = 0; i < inferred.arguments().size(); ++i) {
+ if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) {
+ AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+ "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+ "Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(),
+ " but inferred ", inferred.arguments()[i].type()->str());
+ }
+ }
+
+ for (size_t i = 0; i < inferred.returns().size(); ++i) {
+ if (*inferred.returns()[i].type() != *specified.returns()[i].type()) {
+ AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
+ "doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
+ "Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(),
+ " but inferred ", inferred.returns()[i].type()->str());
+ }
+ }
+}
+
+}
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(
- !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
- "INVALID TYPE: Only int64_t is supported as an integral argument type");
+ !std::is_integral<T>::value || std::is_same<T, int64_t>::value || std::is_same<T, bool>::value,
+ "INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
static_assert(
!std::is_same<T, float>::value,
"INVALID TYPE: float is not supported as an argument type, use double instead");
}
-template <typename First, typename Second, typename... Rest>
-void checkStaticTypes() {
- checkStaticTypes<First>();
- checkStaticTypes<Second, Rest...>();
-}
-
template <typename... Ts, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
- checkStaticTypes<guts::decay_t<Ts>...>();
+ // Check types for common errors
+ (void)std::initializer_list<int>{(
+ checkStaticTypes<Ts>()
+ , 0)...};
+
// Arguments are named "_<index>"
- return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
+ return {Argument("_" + c10::guts::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
}
-template <typename... Ts, size_t... Is>
-::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
- return createArgumentVectorFromTypes<Ts..., Is...>();
-}
+/// Creates a vector of `Argument` from a list of C++ types that are specified
+/// as template arguments.
+template<class ParameterTypes> struct createArguments final {};
+template<class... ParameterTypes>
+struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
+ static std::vector<Argument> call() {
+ return createArgumentVectorFromTypes<ParameterTypes...>(
+ guts::make_index_sequence<sizeof...(ParameterTypes)>()
+ );
+ }
+};
-/// Unpack a tuple return type into a vector of return types, one per tuple
-/// element.
-template <typename... Ts>
-::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
- return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
-}
+/// Creates a vector of `Argument` from a list of C++ types that are specified
+/// as a tuple (i.e. in the way c10 kernels return values).
+/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
+/// It can be an empty tuple<>, or void for kernels that don't return anything.
+/// It can be a single type A (i.e. no tuple) for the case where a kernel just
+/// returns one value.
+template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
-/// Create a single-element `vector` for simple (non-tuple) return types.
-template <typename ReturnType>
-::std::vector<Argument> createReturns(ReturnType*) {
- checkStaticTypes<guts::decay_t<ReturnType>>();
- return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
-}
+template<class... ReturnTypes>
+struct createReturns<std::tuple<ReturnTypes...>, void> final {
+ static std::vector<Argument> call() {
+ return createArgumentVectorFromTypes<ReturnTypes...>(
+ guts::make_index_sequence<sizeof...(ReturnTypes)>()
+ );
+ }
+};
-/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
-/// into the argument list.
-template <typename FunctionTraits, size_t... Is>
-::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
- using ArgumentTypes = typename FunctionTraits::parameter_types;
- return createArgumentVectorFromTypes<
- c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
-}
+template<class ReturnType>
+struct createReturns<ReturnType, guts::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
+ static std::vector<Argument> call() {
+ return createReturns<std::tuple<ReturnType>>::call();
+ }
+};
+
+template<>
+struct createReturns<void, void> final {
+ static std::vector<Argument> call() {
+ return createReturns<std::tuple<>>::call();
+ }
+};
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function.
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
using ReturnType = typename FunctionTraits::return_type;
+ using ParameterTypes = typename FunctionTraits::parameter_types;
- auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
- guts::make_index_sequence<FunctionTraits::number_of_parameters>());
- auto returns = createReturns(static_cast<ReturnType*>(nullptr));
+ auto arguments = createArguments<ParameterTypes>::call();
+ auto returns = createReturns<ReturnType>::call();
return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
}
return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
}
+C10_API void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified);
+
}
using c10::FunctionSchema;
using c10::Argument;
using c10::IntType;
+using c10::FloatType;
using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);
-void errorKernel(const Tensor&) {
+int64_t errorKernel(const Tensor& tensor, int64_t input) {
EXPECT_TRUE(false); // this kernel should never be called
+ return 0;
}
FunctionSchema errorOpSchema(
"_test::error",
"",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
+ (std::vector<Argument>{Argument("dummy"),
+ Argument("input", IntType::get())}),
+ (std::vector<Argument>{Argument("output", IntType::get())}));
-int incrementKernel(const Tensor& tensor, int input) {
+int64_t incrementKernel(const Tensor& tensor, int64_t input) {
return input + 1;
}
-int decrementKernel(const Tensor& tensor, int input) {
+int64_t decrementKernel(const Tensor& tensor, int64_t input) {
return input - 1;
}
EXPECT_EQ(0, result.size());
}
-int kernelWithIntOutput(Tensor, int a, int b) {
+int64_t kernelWithIntOutput(Tensor, int64_t a, int64_t b) {
return a + b;
}
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
}
-std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int input1, int input2, int input3) {
+std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
return {input1, input2, input3};
}
EXPECT_EQ(TensorType2(), captured_input.type_id());
}
-int captured_int_input = 0;
+int64_t captured_int_input = 0;
-void kernelWithIntInputWithoutOutput(Tensor, int input1) {
+void kernelWithIntInputWithoutOutput(Tensor, int64_t input1) {
captured_int_input = input1;
}
EXPECT_EQ(3, captured_int_input);
}
-int kernelWithIntInputWithOutput(Tensor, int input1) {
+int64_t kernelWithIntInputWithOutput(Tensor, int64_t input1) {
return input1 + 1;
}
EXPECT_EQ(4, outputs[0].toInt());
}
-int captured_input_list_size = 0;
+int64_t captured_input_list_size = 0;
void kernelWithIntListInputWithoutOutput(Tensor, ArrayRef<int64_t> input1) {
captured_input_list_size = input1.size();
EXPECT_EQ(3, captured_input_list_size);
}
-int kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
+int64_t kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
return input1.size();
}
EXPECT_EQ(2, captured_input_list_size);
}
-int kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
+int64_t kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
return input1.size();
}
EXPECT_EQ(2, outputs[0].toInt());
}
+template<class Return, class... Args> struct kernel_func final {
+ static Return func(Args...) { return {}; }
+};
+template<class... Args> struct kernel_func<void, Args...> final {
+ static void func(Args...) {}
+};
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1", IntType::get()),
+ Argument("ret2", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret")})
+ ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret"), Argument("ret2")})
+ ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1")})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret")})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", FloatType::get())})
+ ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret")})
+ ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", FloatType::get())})
+ ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
+ ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
}
#pragma once
#include <ATen/core/op_registration/kernel_stackbased.h>
+#include <ATen/core/op_registration/infer_schema.h>
namespace c10 {
/**
private:
std::tuple<Args...> constructor_parameters_;
};
+
+ template<class KernelFunctor>
+ class FunctionSchemaInferer final {
+ public:
+ std::unique_ptr<FunctionSchema> operator()() const {
+ return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
+ }
+ };
}
/**
* > c10::dispatchKey(CPUTensorId()));
*/
template<class KernelFunctor, class... ConstructorParameters>
-inline constexpr auto kernel(ConstructorParameters&&... constructorParameters)
// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
--> guts::enable_if_t<
-guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
-decltype(kernel(
+inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
+kernel(ConstructorParameters&&... constructorParameters) {
+ return {
&detail::wrap_kernel_functor<KernelFunctor>::call,
- detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
-))> {
- static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "KernelFunctor cannot be constructed with the given arguments");
-
- return kernel(
- &detail::wrap_kernel_functor<KernelFunctor>::call,
- detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
- );
+ detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
+ detail::FunctionSchemaInferer<KernelFunctor>()
+ };
}
}
using c10::OperatorKernel;
using c10::Argument;
using c10::IntType;
+using c10::FloatType;
using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
C10_DEFINE_TENSOR_TYPE(TensorType2);
struct ErrorKernel final : public OperatorKernel {
- void operator()(const Tensor&) {
+ int64_t operator()(const Tensor&, int64_t) {
EXPECT_TRUE(false); // this kernel should never be called
+ return 0;
}
};
FunctionSchema errorOpSchema(
"_test::error",
"",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
+ (std::vector<Argument>{Argument("dummy"),
+ Argument("input", IntType::get())}),
+ (std::vector<Argument>{Argument("output", IntType::get())}));
struct IncrementKernel final : OperatorKernel {
- int operator()(const Tensor& tensor, int input) {
+ int64_t operator()(const Tensor& tensor, int64_t input) {
return input + 1;
}
};
struct DecrementKernel final : OperatorKernel {
- int operator()(const Tensor& tensor, int input) {
+ int64_t operator()(const Tensor& tensor, int64_t input) {
return input - 1;
}
};
}
struct KernelWithIntOutput final : OperatorKernel {
- int operator()(Tensor, int a, int b) {
+ int64_t operator()(Tensor, int64_t a, int64_t b) {
return a + b;
}
};
}
struct KernelWithIntListOutput final : OperatorKernel {
- std::vector<int64_t> operator()(const Tensor&, int input1, int input2, int input3) {
+ std::vector<int64_t> operator()(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
return {input1, input2, input3};
}
};
EXPECT_EQ(TensorType2(), captured_input.type_id());
}
-int captured_int_input = 0;
+int64_t captured_int_input = 0;
struct KernelWithIntInputWithoutOutput final : OperatorKernel {
- void operator()(Tensor, int input1) {
+ void operator()(Tensor, int64_t input1) {
captured_int_input = input1;
}
};
}
struct KernelWithIntInputWithOutput final : OperatorKernel {
- int operator()(Tensor, int input1) {
+ int64_t operator()(Tensor, int64_t input1) {
return input1 + 1;
}
};
EXPECT_EQ(4, outputs[0].toInt());
}
-int captured_input_list_size = 0;
+int64_t captured_input_list_size = 0;
struct KernelWithIntListInputWithoutOutput final : OperatorKernel {
void operator()(Tensor, ArrayRef<int64_t> input1) {
}
struct KernelWithIntListInputWithOutput final : OperatorKernel {
- int operator()(Tensor, ArrayRef<int64_t> input1) {
+ int64_t operator()(Tensor, ArrayRef<int64_t> input1) {
return input1.size();
}
};
}
struct KernelWithTensorListInputWithOutput final : OperatorKernel {
- int operator()(ArrayRef<Tensor> input1) {
+ int64_t operator()(ArrayRef<Tensor> input1) {
return input1.size();
}
};
EXPECT_EQ(13, outputs[0].toInt());
}
+template<class Return, class... Args> struct KernelFunc final : OperatorKernel{
+ Return operator()(Args...) { return {}; }
+};
+template<class... Args> struct KernelFunc<void, Args...> final : OperatorKernel {
+ void operator()(Args...) {}
+};
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1", IntType::get()),
+ Argument("ret2", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret")})
+ ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret"), Argument("ret2")})
+ ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{})
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1")})
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
+ ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", IntType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret")})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", FloatType::get())})
+ ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret")})
+ ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret", FloatType::get())})
+ ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // assert this does not fail because it matches
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
+ ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+
+ // and now a set of mismatching schemas
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
+ ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ EXPECT_THROW(
+ RegisterOperators()
+ .op(FunctionSchema(
+ "_test::mismatch",
+ "",
+ (std::vector<Argument>{Argument("arg")}),
+ (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
+ ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
+ c10::Error
+ );
+}
}
namespace detail {
- template<class KernelCacheCreatorFunction_>
+ struct NoFunctionSchemaInference final {
+ std::unique_ptr<FunctionSchema> operator()() const {
+ return nullptr;
+ }
+ };
+
+ template<class KernelCacheCreatorFunction_, class InferFunctionSchemaFunction>
struct KernelRegistrationConfigParameter final {
template<class KernelCacheCreatorFunction__>
- constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func)
- : kernel_func_(kernel_func), cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func)) {
+ constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func, InferFunctionSchemaFunction&& infer_function_schema_func)
+ : kernel_func_(kernel_func)
+ , cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func))
+ , infer_function_schema_func_(std::forward<InferFunctionSchemaFunction>(infer_function_schema_func)) {
}
void apply(KernelRegistrationConfig* registration) const & {
registration->kernel_func = kernel_func_;
registration->cache_creator_func = cache_creator_func_;
+ registration->inferred_function_schema = infer_function_schema_func_();
}
void apply(KernelRegistrationConfig* registration) && {
registration->kernel_func = kernel_func_;
registration->cache_creator_func = std::move(cache_creator_func_);
+ registration->inferred_function_schema = std::move(infer_function_schema_func_)();
}
private:
KernelFunction* kernel_func_;
KernelCacheCreatorFunction_ cache_creator_func_;
+ InferFunctionSchemaFunction infer_function_schema_func_;
};
- static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+ static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction, NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
}
/**
* > c10::dispatchKey(CPUTensorId()));
*/
template<class KernelCacheCreatorFunction_>
-inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
- static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
+ static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
- return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator)};
+ return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator), detail::NoFunctionSchemaInference()};
}
}
#include <ATen/core/op_registration/kernel_stackbased.h>
#include <ATen/core/op_registration/kernel_functor.h>
#include <ATen/core/op_registration/kernel_function.h>
+#include <ATen/core/op_registration/infer_schema.h>
namespace c10 {
guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
+
+ if (config.inferred_function_schema.get() != nullptr) {
+ assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
+ }
+
registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
return std::move(*this);
}
namespace std {
// We use SFINAE to detect if std::to_string exists for a type, but that only works
// if the function name is defined. So let's define a std::to_string for a dummy type.
+// If you're getting an error here saying that this overload doesn't match your
+// std::to_string() call, then you're calling std::to_string() but should be calling
+// c10::guts::to_string().
inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; }
}
namespace c10 { namespace guts { namespace detail {
template<template<class> class C>
struct is_type_condition<C, guts::enable_if_t<std::is_same<bool, guts::remove_cv_t<decltype(C<int>::value)>>::value>> : std::true_type {};
+
}
}
(std::vector<c10::Argument>{
c10::Argument("inputs", ListType::ofTensors()),
c10::Argument("output"),
- c10::Argument("split_info", FloatType::get()),
+ c10::Argument("split_info"),
c10::Argument("add", IntType::get()),
c10::Argument("add_axis", IntType::get())}),
(std::vector<c10::Argument>{})),