Improve compiler error messages of the op registration API (#18550)
authorSebastian Messmer <messmer@fb.com>
Tue, 2 Apr 2019 19:23:13 +0000 (12:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 19:33:27 +0000 (12:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18550

When the operator registration API is used wrongly, in most cases we should now get a nice compiler error
instead of weird template error messages.

This is done by making the enable_if conditions more broad so they also match error cases,
but then having static_asserts against these error cases inside the function.
Before that, since the function didn't match, the error message said something like "no function found to match your call",
now it will show the error message specified in the static_asserts.

Reviewed By: dzhulgakov

Differential Revision: D14659178

fbshipit-source-id: 7ca4fb72d9051eadf0a7e2717b962bf1213a52b2

aten/src/ATen/core/op_registration/kernel_function.h
aten/src/ATen/core/op_registration/kernel_functor.h
aten/src/ATen/core/op_registration/kernel_lambda.h
aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
aten/src/ATen/core/op_registration/op_registration.h

index ba14d5a..e5d7bb3 100644 (file)
@@ -40,9 +40,11 @@ namespace detail {
  */
 template<class FuncType, FuncType* kernel_func>
 inline constexpr auto kernel() ->
-// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
-guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value,
+// enable_if: only enable it if FuncType is actually a function
+guts::enable_if_t<guts::is_function_type<FuncType>::value,
 decltype(kernel<typename detail::WrapKernelFunction<FuncType, kernel_func>::type>())> {
+  static_assert(!std::is_same<FuncType, KernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
+
   return kernel<typename detail::WrapKernelFunction<FuncType, kernel_func>::type>();
 }
 
index 427de09..1520d46 100644 (file)
@@ -88,9 +88,9 @@ namespace detail {
   // SFINAE version for kernels that return an output
   template<class KernelFunctor>
   struct wrap_kernel_functor<KernelFunctor, guts::enable_if_t<!std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
-    static void call(Stack* stack, KernelCache* cache) {
-      static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel");
+    static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
 
+    static void call(Stack* stack, KernelCache* cache) {
       constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
       KernelFunctor* functor = static_cast<KernelFunctor*>(cache);
       auto output = call_functor_with_ivalue_args<KernelFunctor>(functor, torch::jit::last(*stack, num_inputs));
@@ -102,9 +102,9 @@ namespace detail {
   // SFINAE version for kernels that don't return an output
   template<class KernelFunctor>
   struct wrap_kernel_functor<KernelFunctor, guts::enable_if_t<std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
-    static void call(Stack* stack, KernelCache* cache) {
-      static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel.");
+    static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
 
+    static void call(Stack* stack, KernelCache* cache) {
       constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
       KernelFunctor* functor = static_cast<KernelFunctor*>(cache);
       call_functor_with_ivalue_args<KernelFunctor>(functor, torch::jit::last(*stack, num_inputs));
@@ -114,10 +114,9 @@ namespace detail {
 
   template<class KernelFunctor, class... Args>
   class KernelFactory final {
-  public:
-    static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel.");
     static_assert(std::is_constructible<KernelFunctor, Args...>::value, "Wrong argument types for constructor of kernel functor.");
 
+  public:
     explicit constexpr KernelFactory(Args... args)
     : constructor_parameters_(std::move(args)...) {}
 
@@ -140,9 +139,7 @@ namespace detail {
   };
 
   template<class KernelFunctor, class... ConstructorParameters>
-  // enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
-  inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
-  detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
+  detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>
   kernelFunctor(ConstructorParameters&&... constructorParameters) {
     return {
       &detail::wrap_kernel_functor<KernelFunctor>::call,
@@ -189,10 +186,13 @@ namespace detail {
  * >         c10::dispatchKey(CPUTensorId()));
  */
 template<class KernelFunctor, class... ConstructorParameters>
-// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
-inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+// enable_if: only enable it if KernelFunctor is actually a functor
+inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value,
 detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
 kernel(ConstructorParameters&&... constructorParameters) {
+  static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+  static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");
+
   return detail::kernelFunctor<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...);
 }
 
index a5b8244..6ae57da 100644 (file)
@@ -46,14 +46,19 @@ namespace detail {
  */
 template<class Lambda>
 inline constexpr auto kernel(Lambda&& functor) ->
-guts::enable_if_t<guts::is_stateless_lambda<guts::decay_t<Lambda>>::value,
+// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
+guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value,
 decltype(detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor)))> {
+  static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
+
   // We don't support stateful lambdas (i.e. lambdas with a capture), because their
   // behavior would be nonobvious. A functor kernel with cache gets a new instance of
   // its cache each time the kernel is looked up from the dispatch table.
   // A lambda with a capture would be global and share its capture between all kernel lookups.
   // So, instead of making users having to think about it (including the thread-safety
   // issues this causes), let's just forbid stateful lambdas alltogether.
+  static_assert(guts::is_stateless_lambda<guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");
+
   return detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor));
 }
 
index c156e48..7c12a78 100644 (file)
@@ -67,6 +67,12 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_then
   expectCallsIncrement(TensorType1());
 }
 
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenOutOfLineKernel_whenRegistered_thenCanBeCalled) {
+  auto my_kernel = [] (Tensor, int64_t i) {return i+1;};
+  auto registrar = RegisterOperators().op(opSchema, kernel(my_kernel), dispatchKey(TensorType1()));
+  expectCallsIncrement(TensorType1());
+}
+
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
   auto registrar = RegisterOperators()
       .op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
index 645c265..741277f 100644 (file)
@@ -64,8 +64,11 @@ public:
    * >         c10::dispatchKey(CPUTensorId()));
    */
   template<class... ConfigParameters>
-  guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
-  op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
+  RegisterOperators op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
+    static_assert(guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value,
+      "Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
+      " and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators.");
+
     registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
     return std::move(*this);
   }
@@ -101,15 +104,17 @@ public:
    * > static auto registry = c10::RegisterOperators()
    * >     .op("my_op", c10::kernel<my_kernel_cpu>());
    */
-   template<class FuncType>
+   template<class FuncType, class...  OtherArgs>
    C10_DEPRECATED_MESSAGE("Registering kernels via passing function pointers to op() directly is deprecated. " \
                           "Please use the new c10::kernel() based API instead.")
    // enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
    guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, RegisterOperators>
-   op(FunctionSchema schema, FuncType* func) && {
-    // We intentionally don't extend this deprecated API to support dispatch keys
-    // and the like to push people towards using the new API.
-    return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType*>>(func));
+   op(FunctionSchema schema, FuncType* func, OtherArgs...) && {
+     // We intentionally don't extend this deprecated API to support dispatch keys
+     // and the like to push people towards using the new API.
+     static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
+
+     return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType*>>(func));
    }
 
    /**
@@ -137,15 +142,19 @@ public:
     * > static auto registry = c10::RegisterOperators()
     * >     .op("my_op", c10::kernel<my_kernel_cpu>());
     */
-    template<class FuncType>
+    template<class FuncType, class...  OtherArgs>
     C10_DEPRECATED_MESSAGE("Registering kernels via passing lambdas to op() directly is deprecated. " \
                            "Please use the new c10::kernel() based API instead.")
-    // enable_if: only enable it if FuncType is actually a functor, but doesn't inherit from OperatorKernel.
-    guts::enable_if_t<guts::is_functor<FuncType>::value && !std::is_base_of<OperatorKernel, FuncType>::value, RegisterOperators>
-    op(FunctionSchema schema, FuncType&& func) && {
-     // We intentionally don't extend this deprecated API to support dispatch keys
-     // and the like to push people towards using the new API.
-     return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType>>(std::forward<FuncType>(func)));
+    // enable_if: only enable it if FuncType is actually a functor
+    guts::enable_if_t<guts::is_functor<FuncType>::value, RegisterOperators>
+    op(FunctionSchema schema, FuncType&& func, OtherArgs...) && {
+      // We intentionally don't extend this deprecated API to support dispatch keys
+      // and the like to push people towards using the new API.
+      static_assert(sizeof...(OtherArgs) == 0, "The deprecated lambda based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
+
+      static_assert(!std::is_base_of<OperatorKernel, FuncType>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead.");
+
+      return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType>>(std::forward<FuncType>(func)));
     }
 
    // TODO allow input schema to be just the operator name + overload name, in that case use schema generated from kernel function