From 2a58fd9844d4413c186df6129836d88b658d3513 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Sat, 30 Mar 2019 00:03:46 -0700 Subject: [PATCH] Fallback kernels (#18443) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18443 Allow registering a kernel without a dispatch key. In this case, the kernel becomes a fallback kernel that is called whenever no other kernel matches. This is also useful for the legacy function based API (since that API doesn't know about dispatch keys) or any other custom ops that don't care about dispatch and just want one kernel to be called no matter the dispatch key. Reviewed By: dzhulgakov Differential Revision: D14603258 fbshipit-source-id: 242dc8871dad2989ca25079854d0cc97429e7199 --- aten/src/ATen/core/dispatch/DispatchTable.h | 85 +++++++++----- aten/src/ATen/core/dispatch/Dispatcher.cpp | 12 +- aten/src/ATen/core/dispatch/Dispatcher.h | 23 +++- .../ATen/core/op_registration/op_registration.cpp | 24 ++-- .../core/op_registration/op_registration_test.cpp | 130 ++++++++++++++++++++- 5 files changed, 228 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 19cedf9..a41e198 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -39,6 +39,21 @@ struct DispatchTableEntry final { }; namespace detail { +inline std::string dispatch_key_to_string(TensorTypeId id) { + // TODO Find better way to stringify tensor type ids without relying on backend + std::string name = ""; + try { + name = toString(tensorTypeIdToBackend(id)); + } catch (const std::exception&) { + // This can fail if the tensor type id is not one of the preregistered backends. + // However, dispatch_key_to_string is used to generate error reports, that + // means an error already has happened when entering this function. + // We don't want inner errors during generation of a report for an + // outer error. Just report an empty name instead. + } + return name + "[" + toString(id) + "]"; +} + /// Kernel implementations in a thread-safe hash table. class ThreadsafeOperatorTable_ final { public: @@ -49,7 +64,7 @@ class ThreadsafeOperatorTable_ final { }); if (!res) { AT_ERROR("Tried to register multiple kernels with same dispatch key '", - dispatch_key_to_string(key), "' for operator '", operator_name ,"'."); + detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'."); } } @@ -60,7 +75,7 @@ class ThreadsafeOperatorTable_ final { assert(num_removed <= 1); // This is not a multi-map if (num_removed == 0) { AT_ERROR("Tried to deregister a kernel with dispatch key '", - dispatch_key_to_string(key), "' for operator '", operator_name, + detail::dispatch_key_to_string(key), "' for operator '", operator_name, "' but that kernel isn't registered. Registered dispatch keys are: ", list_all_dispatch_keys(map)); } @@ -73,9 +88,7 @@ class ThreadsafeOperatorTable_ final { if (found != map.end()) { return &found->second; } else { - AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name, - "'. Tried to look up kernel for dispatch key '", dispatch_key_to_string(key), - "'. Registered dispatch keys are: ", list_all_dispatch_keys(map)); + return nullptr; } }); } @@ -86,34 +99,25 @@ class ThreadsafeOperatorTable_ final { }); } + std::string list_all_dispatch_keys() const { + return map_.read([&](const ska::flat_hash_map& map) -> std::string { + return list_all_dispatch_keys(map); + }); + } + private: static std::string list_all_dispatch_keys(const ska::flat_hash_map& map) { if (map.size() == 0) { return ""; } std::ostringstream str; - str << dispatch_key_to_string(map.begin()->first); + str << detail::dispatch_key_to_string(map.begin()->first); for (auto iter = ++map.begin(); iter != map.end(); ++iter) { - str << ", " << dispatch_key_to_string(iter->first); + str << ", " << detail::dispatch_key_to_string(iter->first); } return str.str(); } - static std::string dispatch_key_to_string(TensorTypeId id) { - // TODO Find better way to stringify tensor type ids without relying on backend - std::string name = ""; - try { - name = toString(tensorTypeIdToBackend(id)); - } catch (const std::exception&) { - // This can fail if the tensor type id is not one of the preregistered backends. - // However, dispatch_key_to_string is used to generate error reports, that - // means an error already has happened when entering this function. - // We don't want inner errors during generation of a report for an - // outer error. Just report an empty name instead. - } - return name + "[" + toString(id) + "]"; - } - LeftRight> map_; }; } // namespace detail @@ -132,7 +136,8 @@ class DispatchTable final { explicit DispatchTable(const FunctionSchema& schema) : kernels_() , dispatch_strategy_(get_dispatch_strategy_(schema)) - , operator_name_(schema.name()) {} + , operator_name_(schema.name()) + , fallback_kernel_(c10::nullopt) {} DispatchTable(DispatchTable&&) = delete; DispatchTable& operator=(DispatchTable&&) = delete; @@ -141,8 +146,8 @@ class DispatchTable final { /** * Register a kernel in the table at some dispatch key. - * @param func Concrete kernel function implementation to register * @param dispatch_key Dispatch key to define when this kernel is selected + * @param kernel Concrete kernel function implementation to register */ void registerKernel( TensorTypeId dispatch_key, @@ -163,15 +168,42 @@ class DispatchTable final { } /** + * Register a fallback kernel. This kernel will be returned from lookup + * whenever no other kernel matches the dispatch key. + * @param kernel Concrete kernel function implementation to register + */ + void registerFallbackKernel(const DispatchTableEntry& kernel) { + fallback_kernel_ = kernel; + } + + /** + * Deregister the fallback kernel. + * Without a fallback kernel, lookup of a dispatch key that doesn't match + * a kernel will fail again. + */ + void deregisterFallbackKernel() { + fallback_kernel_.reset(); + } + + /** * Perform a dynamic dispatch on this table and find the kernel to call * for the given arguments. * * @param args Arguments to invoke the function with - * @return Kernel function pointing to the right kernel for the given arguments + * @return Kernel function pointing to the right kernel for the given arguments. */ const DispatchTableEntry& lookup(const Stack* stack) const { TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack); - return *kernels_.lookup(dispatch_key, operator_name_); + auto found = kernels_.lookup(dispatch_key, operator_name_); + if (nullptr != found) { + return *found; + } + if (fallback_kernel_.has_value()) { + return *fallback_kernel_; + } + AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name_, + "'. Tried to look up kernel for dispatch key '", detail::dispatch_key_to_string(dispatch_key), + "'. Registered dispatch keys are: ", kernels_.list_all_dispatch_keys()); } bool isEmpty() const { @@ -225,6 +257,7 @@ private: detail::ThreadsafeOperatorTable_ kernels_; DispatchStrategy dispatch_strategy_; std::string operator_name_; + c10::optional fallback_kernel_; }; } // namespace c10 diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 8626fae..c5715d2 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -110,7 +110,17 @@ void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_ void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) { // note: this doesn't need the mutex because write operations on the list keep iterators intact. - op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key); + op.operatorDefIterator_->dispatchTable.deregisterKernel(std::move(dispatch_key)); +} + +void Dispatcher::registerFallbackKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func) { + // note: this doesn't need the mutex because write operations on the list keep iterators intact. + op.operatorDefIterator_->dispatchTable.registerFallbackKernel(DispatchTableEntry{kernel_func, std::move(cache_creator_func)}); +} + +void Dispatcher::deregisterFallbackKernel(const OperatorHandle& op) { + // note: this doesn't need the mutex because write operations on the list keep iterators intact. + op.operatorDefIterator_->dispatchTable.deregisterFallbackKernel(); } void Dispatcher::addRegistrationListener(std::unique_ptr listener) { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 08645a2..d2624b2 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -70,7 +70,7 @@ class RegistrationListenerList; class CAFFE2_API Dispatcher final { private: struct OperatorDef final { - explicit OperatorDef(FunctionSchema schema_) + explicit OperatorDef(FunctionSchema&& schema_) : dispatchTable(schema_) , schema(std::move(schema_)) , refcount(0) {} @@ -124,16 +124,33 @@ public: c10::optional findSchema(const char* operator_name, const char* overload_name); /** - * Register an operator to the dispatch table for an operator. + * Register a kernel to the dispatch table for an operator. + * If dispatch_key is nullopt, then this registers a fallback kernel. */ void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func); /** - * Remove an operator from the dispatch table for an operator. + * Remove a kernel from the dispatch table for an operator. + * If dispatch_key is none, then this deregisters the fallback kernel. + * See documentation for registerKernel() for details. */ void deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key); /** + * Register a fallback kernel for an operator. + * After this, when trying to lookup a kernel for an unknown dispatch key, + * it will not fail anymore, but return the fallback kernel instead. + */ + void registerFallbackKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func); + + /** + * Remove the fallback kernel for an operator. + * After this, if trying to lookup a kernel for an unknown dispatch key, + * the lookup will fail. + */ + void deregisterFallbackKernel(const OperatorHandle& op); + + /** * Perform a dynamic dispatch and get the kernel for an operator. */ OpKernel lookup(const OperatorHandle& op, const Stack* stack) const; diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index d0236b7..f8d4182 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -11,9 +11,13 @@ RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) = default; // table deregisters it in the destructor. class RegisterOperators::OperatorRegistrar final { public: - explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator) + explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator) : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { - Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator)); + if (dispatch_key_.has_value()) { + Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator)); + } else { + Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator)); + } } OperatorRegistrar(OperatorRegistrar&& rhs) noexcept @@ -28,24 +32,22 @@ public: ~OperatorRegistrar() { if (owns_registration_) { - Dispatcher::singleton().deregisterKernel(op_, dispatch_key_); + if (dispatch_key_.has_value()) { + Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_); + } else { + Dispatcher::singleton().deregisterFallbackKernel(op_); + } Dispatcher::singleton().deregisterSchema(op_); } } private: const OperatorHandle op_; - const TensorTypeId dispatch_key_; + const c10::optional dispatch_key_; bool owns_registration_; }; void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) { - // TODO We need to support registrations without a dispatch key, - // at least for the deprecated APIs. Maybe also for new ones. - AT_CHECK(config.dispatch_key.has_value(), - "Tried to register an operator with function schema ", toString(schema), - ", but didn't specify a dispatch key. Please add a c10::dispatchKey(...) parameter to the registration call."); - // TODO Should we allow this and only register a schema without a kernel? AT_CHECK(config.kernel_func != nullptr, "Tried to register an operator with function schema ", toString(schema), @@ -57,7 +59,7 @@ void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegis assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); } - registrars_.emplace_back(std::move(schema), *config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func)); + registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func)); } } diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index aeab038..224bab4 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -16,17 +16,30 @@ using c10::FunctionSchema; using c10::Argument; using c10::kernel; using c10::dispatchKey; +using c10::Dispatcher; using at::Tensor; namespace { C10_DECLARE_TENSOR_TYPE(TensorType1); C10_DEFINE_TENSOR_TYPE(TensorType1); +C10_DECLARE_TENSOR_TYPE(TensorType2); +C10_DEFINE_TENSOR_TYPE(TensorType2); struct DummyKernel final : OperatorKernel { void operator()(Tensor) {} }; +struct MockKernel final : OperatorKernel { + MockKernel(bool* called): called_(called) {} + + void operator()(Tensor) { + *called_ = true; + } +private: + bool* called_; +}; + FunctionSchema dummySchema( "_test::dummy", "", @@ -44,15 +57,122 @@ TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) { c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); } -TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutDispatchKey_thenFails) { - // make sure it crashes when dispatch key is absent +TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) { + auto registrar = c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); EXPECT_THROW( - c10::RegisterOperators().op(dummySchema, kernel()), + callOp(*op, dummyTensor(TensorType2())), c10::Error ); +} - // but make sure it doesn't crash when dispatch key is present - c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); +TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) { + auto registrar = c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); + { + auto inner_registrar = c10::RegisterOperators().op(dummySchema, kernel()); + // this registered a fallback kernel, but now that registration goes out of scope and deregisters it + } + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_THROW( + callOp(*op, dummyTensor(TensorType2())), + c10::Error + ); +} + +TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) { + bool called = false; + auto registrar = c10::RegisterOperators().op(dummySchema, kernel(&called)); // note: no dispatch key means this is the fallback kernel + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called); + callOp(*op, dummyTensor(TensorType2())); + EXPECT_TRUE(called); +} + +TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernelAndOtherKernelOutOfScope_whenCallingOp_thenCallsFallbackKernel) { + bool called = false; + bool other_called = false; + auto registrar = c10::RegisterOperators().op(dummySchema, kernel(&called)); // note: no dispatch key means this is the fallback kernel + { + auto inner_registrar = c10::RegisterOperators().op(dummySchema, kernel(&other_called), dispatchKey(TensorType2())); + } + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called); + callOp(*op, dummyTensor(TensorType2())); + EXPECT_TRUE(called); + EXPECT_FALSE(other_called); +} + +TEST(OperatorRegistrationTest, givenOpWithFirstFallbackAndThenOtherKernel_whenCallingWithCorrectDispatchKey_thenCallsCorrectKernel) { + bool called_kernel = false; + bool called_fallback = false; + auto registrar = c10::RegisterOperators() + .op(dummySchema, kernel(&called_fallback)) // note: no dispatch key means this is the fallback kernel + .op(dummySchema, kernel(&called_kernel), dispatchKey(TensorType1())); + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called_kernel); + EXPECT_FALSE(called_fallback); + callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(called_kernel); + EXPECT_FALSE(called_fallback); +} + +TEST(OperatorRegistrationTest, givenOpWithFirstFallbackAndThenOtherKernel_whenCallingWithWrongDispatchKey_thenCallsFallbackKernel) { + bool called_kernel = false; + bool called_fallback = false; + auto registrar = c10::RegisterOperators() + .op(dummySchema, kernel(&called_fallback)) // note: no dispatch key means this is the fallback kernel + .op(dummySchema, kernel(&called_kernel), dispatchKey(TensorType1())); + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called_kernel); + EXPECT_FALSE(called_fallback); + callOp(*op, dummyTensor(TensorType2())); + EXPECT_FALSE(called_kernel); + EXPECT_TRUE(called_fallback); +} + + +TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCallingWithCorrectDispatchKey_thenCallsCorrectKernel) { + bool called_kernel = false; + bool called_fallback = false; + auto registrar = c10::RegisterOperators() + .op(dummySchema, kernel(&called_kernel), dispatchKey(TensorType1())) + .op(dummySchema, kernel(&called_fallback)); // note: no dispatch key means this is the fallback kernel + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called_kernel); + EXPECT_FALSE(called_fallback); + callOp(*op, dummyTensor(TensorType1())); + EXPECT_TRUE(called_kernel); + EXPECT_FALSE(called_fallback); +} + +TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCallingWithWrongDispatchKey_thenCallsFallbackKernel) { + bool called_kernel = false; + bool called_fallback = false; + auto registrar = c10::RegisterOperators() + .op(dummySchema, kernel(&called_kernel), dispatchKey(TensorType1())) + .op(dummySchema, kernel(&called_fallback)); // note: no dispatch key means this is the fallback kernel + + auto op = Dispatcher::singleton().findSchema("_test::dummy", ""); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called_kernel); + EXPECT_FALSE(called_fallback); + callOp(*op, dummyTensor(TensorType2())); + EXPECT_FALSE(called_kernel); + EXPECT_TRUE(called_fallback); } } -- 2.7.4