From 58f5954252ba8970aae965d4c5b010f28ad15c06 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Tue, 2 Apr 2019 12:23:13 -0700 Subject: [PATCH] Allow registering an operator schema without a kernel (#18551) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18551 This is helpful for defining a set of operators as an interface but not adding concrete kernels just yet. The registration logic will ensure that any other libraries that add kernels for these schemas exactly match the schema defined here. Reviewed By: dzhulgakov Differential Revision: D14660208 fbshipit-source-id: 7adb5a4876cff5a0ad21d92d8c450cb889f00cc3 --- .../ATen/core/op_registration/op_registration.cpp | 38 +++++++++------ .../core/op_registration/op_registration_test.cpp | 55 ++++++++++++++++++---- 2 files changed, 68 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index f8d4182..d63d2b7 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -12,16 +12,21 @@ RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) = default; class RegisterOperators::OperatorRegistrar final { public: explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator) - : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { - if (dispatch_key_.has_value()) { - Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator)); - } else { - Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator)); + : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), has_kernel_(kernel != nullptr), owns_registration_(true) { + // either both, kernel and cache_creator, or none must be set. + AT_ASSERT((kernel != nullptr) == static_cast(cache_creator)); + + if (has_kernel_) { + if (dispatch_key_.has_value()) { + Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator)); + } else { + Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator)); + } } } OperatorRegistrar(OperatorRegistrar&& rhs) noexcept - : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) { + : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), has_kernel_(rhs.has_kernel_), owns_registration_(rhs.owns_registration_) { rhs.owns_registration_ = false; } @@ -32,10 +37,12 @@ public: ~OperatorRegistrar() { if (owns_registration_) { - if (dispatch_key_.has_value()) { - Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_); - } else { - Dispatcher::singleton().deregisterFallbackKernel(op_); + if (has_kernel_) { + if (dispatch_key_.has_value()) { + Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_); + } else { + Dispatcher::singleton().deregisterFallbackKernel(op_); + } } Dispatcher::singleton().deregisterSchema(op_); } @@ -44,16 +51,17 @@ public: private: const OperatorHandle op_; const c10::optional dispatch_key_; + bool has_kernel_; bool owns_registration_; }; void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) { - // TODO Should we allow this and only register a schema without a kernel? - AT_CHECK(config.kernel_func != nullptr, - "Tried to register an operator with function schema ", toString(schema), - ", but didn't specify a kernel. Please add a c10::kernel<...>(...) parameter to the registration call."); + AT_CHECK(!config.dispatch_key.has_value() || config.kernel_func != nullptr, + "Tried to register an operator with a dispatch key but without a kernel. " + "Please either specify a kernel or omit the dispatch key to only register the schema."); + // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else. - AT_ASSERT(static_cast(config.cache_creator_func)); + AT_ASSERT((config.kernel_func != nullptr) == static_cast(config.cache_creator_func)); if (config.inferred_function_schema.get() != nullptr) { assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index f81d5ee..31ef98b 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -46,16 +46,6 @@ FunctionSchema dummySchema( (std::vector{Argument("dummy")}), (std::vector{})); -TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) { - // make sure it crashes when kernel is absent - expectThrows([&] { - c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())); - }, "but didn't specify a kernel"); - - // but make sure it doesn't crash when kernel is present - c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); -} - TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) { auto registrar = c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); @@ -172,4 +162,49 @@ TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCa EXPECT_TRUE(called_fallback); } +TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegistering_thenOnlyRegistersSchema) { + auto registrar = c10::RegisterOperators().op(dummySchema); + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); // assert schema is registered + expectThrows([&] { + callOp(*op, dummyTensor(TensorType1())); + }, "Didn't find kernel to dispatch to for operator '_test::dummy'"); +} + +TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenSchemaIsGone) { + { + auto registrar = c10::RegisterOperators().op(dummySchema); + } + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + EXPECT_FALSE(op.has_value()); +} + +TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwards_thenCanBeCalled) { + auto registrar1 = c10::RegisterOperators().op(dummySchema); + + bool called_kernel = false; + auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel(&called_kernel), dispatchKey(TensorType1())); + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); // assert schema is registered + callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(called_kernel); +} + +TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwardsAndRunsOutOfScope_thenSchemaIsStillThereButCannotBeCalledAnymore) { + auto registrar1 = c10::RegisterOperators().op(dummySchema); + + { + auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); + } + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); // assert schema is registered + expectThrows([&] { + callOp(*op, dummyTensor(TensorType1())); + }, "Didn't find kernel to dispatch to for operator '_test::dummy'"); +} + } -- 2.7.4