Allow registering an operator schema without a kernel (#18551)
authorSebastian Messmer <messmer@fb.com>
Tue, 2 Apr 2019 19:23:13 +0000 (12:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 19:33:30 +0000 (12:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18551

This is helpful for defining a set of operators as an interface but not adding concrete kernels just yet.
The registration logic will ensure that any other libraries that add kernels for these schemas exactly match the schema defined here.

Reviewed By: dzhulgakov

Differential Revision: D14660208

fbshipit-source-id: 7adb5a4876cff5a0ad21d92d8c450cb889f00cc3

aten/src/ATen/core/op_registration/op_registration.cpp
aten/src/ATen/core/op_registration/op_registration_test.cpp

index f8d4182..d63d2b7 100644 (file)
@@ -12,16 +12,21 @@ RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) = default;
 class RegisterOperators::OperatorRegistrar final {
 public:
   explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
-  : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
-    if (dispatch_key_.has_value()) {
-      Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator));
-    } else {
-      Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator));
+  : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), has_kernel_(kernel != nullptr), owns_registration_(true) {
+    // either both, kernel and cache_creator, or none must be set.
+    AT_ASSERT((kernel != nullptr) == static_cast<bool>(cache_creator));
+
+    if (has_kernel_) {
+      if (dispatch_key_.has_value()) {
+        Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator));
+      } else {
+        Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator));
+      }
     }
   }
 
   OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
-  :  op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) {
+  :  op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), has_kernel_(rhs.has_kernel_), owns_registration_(rhs.owns_registration_) {
     rhs.owns_registration_ = false;
   }
 
@@ -32,10 +37,12 @@ public:
 
   ~OperatorRegistrar() {
     if (owns_registration_) {
-      if (dispatch_key_.has_value()) {
-        Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_);
-      } else {
-        Dispatcher::singleton().deregisterFallbackKernel(op_);
+      if (has_kernel_) {
+        if (dispatch_key_.has_value()) {
+          Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_);
+        } else {
+          Dispatcher::singleton().deregisterFallbackKernel(op_);
+        }
       }
       Dispatcher::singleton().deregisterSchema(op_);
     }
@@ -44,16 +51,17 @@ public:
 private:
   const OperatorHandle op_;
   const c10::optional<TensorTypeId> dispatch_key_;
+  bool has_kernel_;
   bool owns_registration_;
 };
 
 void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
-  // TODO Should we allow this and only register a schema without a kernel?
-  AT_CHECK(config.kernel_func != nullptr,
-      "Tried to register an operator with function schema ", toString(schema),
-      ", but didn't specify a kernel. Please add a c10::kernel<...>(...) parameter to the registration call.");
+  AT_CHECK(!config.dispatch_key.has_value() || config.kernel_func != nullptr,
+    "Tried to register an operator with a dispatch key but without a kernel. "
+    "Please either specify a kernel or omit the dispatch key to only register the schema.");
+
   // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else.
-  AT_ASSERT(static_cast<bool>(config.cache_creator_func));
+  AT_ASSERT((config.kernel_func != nullptr) == static_cast<bool>(config.cache_creator_func));
 
   if (config.inferred_function_schema.get() != nullptr) {
     assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
index f81d5ee..31ef98b 100644 (file)
@@ -46,16 +46,6 @@ FunctionSchema dummySchema(
     (std::vector<Argument>{Argument("dummy")}),
     (std::vector<Argument>{}));
 
-TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) {
-  // make sure it crashes when kernel is absent
-  expectThrows<c10::Error>([&] {
-    c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1()));
-  }, "but didn't specify a kernel");
-
-  // but make sure it doesn't crash when kernel is present
-  c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
-}
-
 TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) {
   auto registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
 
@@ -172,4 +162,49 @@ TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCa
   EXPECT_TRUE(called_fallback);
 }
 
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegistering_thenOnlyRegistersSchema) {
+  auto registrar = c10::RegisterOperators().op(dummySchema);
+
+  auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+  ASSERT_TRUE(op.has_value()); // assert schema is registered
+  expectThrows<c10::Error>([&] {
+    callOp(*op, dummyTensor(TensorType1()));
+  }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
+}
+
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenSchemaIsGone) {
+  {
+    auto registrar = c10::RegisterOperators().op(dummySchema);
+  }
+
+  auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+  EXPECT_FALSE(op.has_value());
+}
+
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwards_thenCanBeCalled) {
+  auto registrar1 = c10::RegisterOperators().op(dummySchema);
+
+  bool called_kernel = false;
+  auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+
+  auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+  ASSERT_TRUE(op.has_value()); // assert schema is registered
+  callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(called_kernel);
+}
+
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwardsAndRunsOutOfScope_thenSchemaIsStillThereButCannotBeCalledAnymore) {
+  auto registrar1 = c10::RegisterOperators().op(dummySchema);
+
+  {
+    auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+  }
+
+  auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+  ASSERT_TRUE(op.has_value()); // assert schema is registered
+  expectThrows<c10::Error>([&] {
+    callOp(*op, dummyTensor(TensorType1()));
+  }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
+}
+
 }