class RegisterOperators::OperatorRegistrar final {
public:
explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
- : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
- if (dispatch_key_.has_value()) {
- Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator));
- } else {
- Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator));
+ : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), has_kernel_(kernel != nullptr), owns_registration_(true) {
+ // either both, kernel and cache_creator, or none must be set.
+ AT_ASSERT((kernel != nullptr) == static_cast<bool>(cache_creator));
+
+ if (has_kernel_) {
+ if (dispatch_key_.has_value()) {
+ Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator));
+ } else {
+ Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator));
+ }
}
}
OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
- : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) {
+ : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), has_kernel_(rhs.has_kernel_), owns_registration_(rhs.owns_registration_) {
rhs.owns_registration_ = false;
}
~OperatorRegistrar() {
if (owns_registration_) {
- if (dispatch_key_.has_value()) {
- Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_);
- } else {
- Dispatcher::singleton().deregisterFallbackKernel(op_);
+ if (has_kernel_) {
+ if (dispatch_key_.has_value()) {
+ Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_);
+ } else {
+ Dispatcher::singleton().deregisterFallbackKernel(op_);
+ }
}
Dispatcher::singleton().deregisterSchema(op_);
}
private:
const OperatorHandle op_;
const c10::optional<TensorTypeId> dispatch_key_;
+ bool has_kernel_;
bool owns_registration_;
};
void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
- // TODO Should we allow this and only register a schema without a kernel?
- AT_CHECK(config.kernel_func != nullptr,
- "Tried to register an operator with function schema ", toString(schema),
- ", but didn't specify a kernel. Please add a c10::kernel<...>(...) parameter to the registration call.");
+ AT_CHECK(!config.dispatch_key.has_value() || config.kernel_func != nullptr,
+ "Tried to register an operator with a dispatch key but without a kernel. "
+ "Please either specify a kernel or omit the dispatch key to only register the schema.");
+
// if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else.
- AT_ASSERT(static_cast<bool>(config.cache_creator_func));
+ AT_ASSERT((config.kernel_func != nullptr) == static_cast<bool>(config.cache_creator_func));
if (config.inferred_function_schema.get() != nullptr) {
assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
(std::vector<Argument>{Argument("dummy")}),
(std::vector<Argument>{}));
-TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) {
- // make sure it crashes when kernel is absent
- expectThrows<c10::Error>([&] {
- c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1()));
- }, "but didn't specify a kernel");
-
- // but make sure it doesn't crash when kernel is present
- c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
-}
-
TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) {
auto registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
EXPECT_TRUE(called_fallback);
}
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegistering_thenOnlyRegistersSchema) {
+ auto registrar = c10::RegisterOperators().op(dummySchema);
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value()); // assert schema is registered
+ expectThrows<c10::Error>([&] {
+ callOp(*op, dummyTensor(TensorType1()));
+ }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
+}
+
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenSchemaIsGone) {
+ {
+ auto registrar = c10::RegisterOperators().op(dummySchema);
+ }
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ EXPECT_FALSE(op.has_value());
+}
+
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwards_thenCanBeCalled) {
+ auto registrar1 = c10::RegisterOperators().op(dummySchema);
+
+ bool called_kernel = false;
+ auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value()); // assert schema is registered
+ callOp(*op, dummyTensor(TensorType1()));
+ EXPECT_TRUE(called_kernel);
+}
+
+TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwardsAndRunsOutOfScope_thenSchemaIsStillThereButCannotBeCalledAnymore) {
+ auto registrar1 = c10::RegisterOperators().op(dummySchema);
+
+ {
+ auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+ }
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value()); // assert schema is registered
+ expectThrows<c10::Error>([&] {
+ callOp(*op, dummyTensor(TensorType1()));
+ }, "Didn't find kernel to dispatch to for operator '_test::dummy'");
+}
+
}