Register operators by passing arguments to RegisterOperators constructor (#18577)
authorSebastian Messmer <messmer@fb.com>
Tue, 2 Apr 2019 19:23:15 +0000 (12:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 19:33:33 +0000 (12:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18577

This is also part of the legacy API and we need to support it if we want to replace it.

Reviewed By: dzhulgakov

Differential Revision: D14671432

fbshipit-source-id: 007abf4ab816647a509fc08e35d79b6c1aa55b03

aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp
aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
aten/src/ATen/core/op_registration/op_registration.h

index 206c9bd..a11c5a5 100644 (file)
@@ -88,6 +88,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegiste
   expectCallsIncrement(TensorType1());
 }
 
+TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegisteredInConstructor_thenCanBeCalled) {
+  auto registrar = RegisterOperators(opSchema, &incrementKernel);
+  expectCallsIncrement(TensorType1());
+}
+
 TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
   auto registrar = RegisterOperators()
       .op(opSchema, &incrementKernel)
index 884cd2e..7b62a86 100644 (file)
@@ -67,6 +67,13 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistere
   expectCallsIncrement(TensorType1());
 }
 
+TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredInConstructor_thenCanBeCalled) {
+  auto registrar = RegisterOperators(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+      return input + 1;
+    });
+  expectCallsIncrement(TensorType1());
+}
+
 TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
   auto registrar = RegisterOperators()
       .op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
index 741277f..58bc99e 100644 (file)
@@ -44,7 +44,6 @@ public:
   RegisterOperators(RegisterOperators&&) noexcept;
   RegisterOperators& operator=(RegisterOperators&&);
 
-
   /**
    * Register an operator based on a function schema and a set of configuration
    * parameters (i.e. kernel function, dispatch key, ...).
@@ -69,10 +68,28 @@ public:
       "Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
       " and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators.");
 
-    registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+    op_(std::move(schema), std::forward<ConfigParameters>(configParameters)...);
     return std::move(*this);
   }
 
+  template<class FuncType>
+  C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
+                         "Please use RegisterOperators().op(...) instead.")
+  // enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
+  explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, FunctionSchema> schema, FuncType* func)
+  : RegisterOperators() {
+    legacyAPIOp_(std::move(schema), func);
+  }
+
+  template<class FuncType>
+  C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
+                         "Please use RegisterOperators().op(...) instead.")
+  // enable_if: only enable it if FuncType is actually a functor
+  explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, FunctionSchema> schema, FuncType&& func)
+  : RegisterOperators() {
+    legacyAPIOp_(std::move(schema), std::forward<FuncType>(func));
+  }
+
   /**
    * Deprecated. For backwards compatibility only.
    * Don't use this, it introduces a performance overhead on each kernel call
@@ -114,7 +131,8 @@ public:
      // and the like to push people towards using the new API.
      static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
 
-     return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType*>>(func));
+     legacyAPIOp_(std::move(schema), func);
+     return std::move(*this);
    }
 
    /**
@@ -154,12 +172,23 @@ public:
 
       static_assert(!std::is_base_of<OperatorKernel, FuncType>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead.");
 
-      return std::move(*this).op(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<FuncType>>(std::forward<FuncType>(func)));
+      legacyAPIOp_(std::move(schema), std::forward<FuncType>(func));
+      return std::move(*this);
     }
 
    // TODO allow input schema to be just the operator name + overload name, in that case use schema generated from kernel function
 
 private:
+  template<class... ConfigParameters>
+  void op_(FunctionSchema&& schema, ConfigParameters&&... configParameters) {
+    registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+  }
+
+  template<class FuncType>
+  void legacyAPIOp_(FunctionSchema&& schema, FuncType&& func) {
+    op_(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
+  }
+
   void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
 
   class OperatorRegistrar;