*/
template<class FuncType, FuncType* kernel_func>
inline constexpr auto kernel() ->
-// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
-guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value,
+// enable_if: only enable it if FuncType is actually a function
+guts::enable_if_t<guts::is_function_type<FuncType>::value,
decltype(kernel<typename detail::WrapKernelFunction<FuncType, kernel_func>::type>())> {
+ static_assert(!std::is_same<FuncType, KernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
+
return kernel<typename detail::WrapKernelFunction<FuncType, kernel_func>::type>();
}
// SFINAE version for kernels that return an output
template<class KernelFunctor>
struct wrap_kernel_functor<KernelFunctor, guts::enable_if_t<!std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
- static void call(Stack* stack, KernelCache* cache) {
- static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel");
+ static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+ static void call(Stack* stack, KernelCache* cache) {
constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
KernelFunctor* functor = static_cast<KernelFunctor*>(cache);
auto output = call_functor_with_ivalue_args<KernelFunctor>(functor, torch::jit::last(*stack, num_inputs));
// SFINAE version for kernels that don't return an output
template<class KernelFunctor>
struct wrap_kernel_functor<KernelFunctor, guts::enable_if_t<std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
- static void call(Stack* stack, KernelCache* cache) {
- static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel.");
+ static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+ static void call(Stack* stack, KernelCache* cache) {
constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
KernelFunctor* functor = static_cast<KernelFunctor*>(cache);
call_functor_with_ivalue_args<KernelFunctor>(functor, torch::jit::last(*stack, num_inputs));
template<class KernelFunctor, class... Args>
class KernelFactory final {
- public:
- static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel.");
static_assert(std::is_constructible<KernelFunctor, Args...>::value, "Wrong argument types for constructor of kernel functor.");
+ public:
explicit constexpr KernelFactory(Args... args)
: constructor_parameters_(std::move(args)...) {}
};
template<class KernelFunctor, class... ConstructorParameters>
- // enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
- inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
- detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
+ detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>
kernelFunctor(ConstructorParameters&&... constructorParameters) {
return {
&detail::wrap_kernel_functor<KernelFunctor>::call,
* > c10::dispatchKey(CPUTensorId()));
*/
template<class KernelFunctor, class... ConstructorParameters>
-// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
-inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+// enable_if: only enable it if KernelFunctor is actually a functor
+inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value,
detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
kernel(ConstructorParameters&&... constructorParameters) {
+ static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+ static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
+
return detail::kernelFunctor<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...);
}
*/
template<class Lambda>
inline constexpr auto kernel(Lambda&& functor) ->
-guts::enable_if_t<guts::is_stateless_lambda<guts::decay_t<Lambda>>::value,
+// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
+guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value,
decltype(detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor)))> {
+ static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
+
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
// behavior would be nonobvious. A functor kernel with cache gets a new instance of
// its cache each time the kernel is looked up from the dispatch table.
// A lambda with a capture would be global and share its capture between all kernel lookups.
// So, instead of making users having to think about it (including the thread-safety
// issues this causes), let's just forbid stateful lambdas alltogether.
+ static_assert(guts::is_stateless_lambda<guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
+
return detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor));
}
expectCallsIncrement(TensorType1());
}
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenOutOfLineKernel_whenRegistered_thenCanBeCalled) {
+ auto my_kernel = [] (Tensor, int64_t i) {return i+1;};
+ auto registrar = RegisterOperators().op(opSchema, kernel(my_kernel), dispatchKey(TensorType1()));
+ expectCallsIncrement(TensorType1());
+}
+
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
.op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
* > c10::dispatchKey(CPUTensorId()));
*/
template<class... ConfigParameters>
- guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
- op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
+ RegisterOperators op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
+ static_assert(guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value,
+ "Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
+ " and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators.");
+
registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
return std::move(*this);
}
* > static auto registry = c10::RegisterOperators()
* > .op("my_op", c10::kernel<my_kernel_cpu>());
*/
- template<class FuncType>
+ template<class FuncType, class... OtherArgs>
C10_DEPRECATED_MESSAGE("Registering kernels via passing function pointers to op() directly is deprecated. " \
"Please use the new c10::kernel() based API instead.")
// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, RegisterOperators>
- op(FunctionSchema schema, FuncType* func) && {
- // We intentionally don't extend this deprecated API to support dispatch keys
- // and the like to push people towards using the new API.
- return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType*>>(func));
+ op(FunctionSchema schema, FuncType* func, OtherArgs...) && {
+ // We intentionally don't extend this deprecated API to support dispatch keys
+ // and the like to push people towards using the new API.
+ static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
+
+ return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType*>>(func));
}
/**
* > static auto registry = c10::RegisterOperators()
* > .op("my_op", c10::kernel<my_kernel_cpu>());
*/
- template<class FuncType>
+ template<class FuncType, class... OtherArgs>
C10_DEPRECATED_MESSAGE("Registering kernels via passing lambdas to op() directly is deprecated. " \
"Please use the new c10::kernel() based API instead.")
- // enable_if: only enable it if FuncType is actually a functor, but doesn't inherit from OperatorKernel.
- guts::enable_if_t<guts::is_functor<FuncType>::value && !std::is_base_of<OperatorKernel, FuncType>::value, RegisterOperators>
- op(FunctionSchema schema, FuncType&& func) && {
- // We intentionally don't extend this deprecated API to support dispatch keys
- // and the like to push people towards using the new API.
- return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType>>(std::forward<FuncType>(func)));
+ // enable_if: only enable it if FuncType is actually a functor
+ guts::enable_if_t<guts::is_functor<FuncType>::value, RegisterOperators>
+ op(FunctionSchema schema, FuncType&& func, OtherArgs...) && {
+ // We intentionally don't extend this deprecated API to support dispatch keys
+ // and the like to push people towards using the new API.
+ static_assert(sizeof...(OtherArgs) == 0, "The deprecated lambda based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
+
+ static_assert(!std::is_base_of<OperatorKernel, FuncType>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead.");
+
+ return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType>>(std::forward<FuncType>(func)));
}
// TODO allow input schema to be just the operator name + overload name, in that case use schema generated from kernel function