};
namespace detail {
+// Renders a dispatch key as "<backend>[<id>]" for use in error messages.
+// Lives at namespace scope (rather than as a class member) so every class
+// in detail can share it.
+inline std::string dispatch_key_to_string(TensorTypeId id) {
+ // TODO Find better way to stringify tensor type ids without relying on backend
+ std::string name = "";
+ try {
+ name = toString(tensorTypeIdToBackend(id));
+ } catch (const std::exception&) {
+ // This can fail if the tensor type id is not one of the preregistered backends.
+ // However, dispatch_key_to_string is used to generate error reports, that
+ // means an error already has happened when entering this function.
+ // We don't want inner errors during generation of a report for an
+ // outer error. Just report an empty name instead.
+ }
+ return name + "[" + toString(id) + "]";
+}
+
/// Kernel implementations in a thread-safe hash table.
class ThreadsafeOperatorTable_ final {
public:
});
if (!res) {
AT_ERROR("Tried to register multiple kernels with same dispatch key '",
- dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
+ detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
}
}
assert(num_removed <= 1); // This is not a multi-map
if (num_removed == 0) {
AT_ERROR("Tried to deregister a kernel with dispatch key '",
- dispatch_key_to_string(key), "' for operator '", operator_name,
+ detail::dispatch_key_to_string(key), "' for operator '", operator_name,
"' but that kernel isn't registered. Registered dispatch keys are: ",
list_all_dispatch_keys(map));
}
if (found != map.end()) {
return &found->second;
} else {
- AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name,
- "'. Tried to look up kernel for dispatch key '", dispatch_key_to_string(key),
- "'. Registered dispatch keys are: ", list_all_dispatch_keys(map));
+ return nullptr;
}
});
}
});
}
+ // Returns a comma-separated list of all currently registered dispatch keys.
+ // Takes a read view of the kernel table; used to build error messages.
+ std::string list_all_dispatch_keys() const {
+ return map_.read([&](const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) -> std::string {
+ return list_all_dispatch_keys(map);
+ });
+ }
+
private:
+ // Overload taking an already-accessible map; joins all keys with ", ".
static std::string list_all_dispatch_keys(const ska::flat_hash_map<TensorTypeId, DispatchTableEntry>& map) {
if (map.size() == 0) {
return "";
}
std::ostringstream str;
- str << dispatch_key_to_string(map.begin()->first);
+ str << detail::dispatch_key_to_string(map.begin()->first);
for (auto iter = ++map.begin(); iter != map.end(); ++iter) {
- str << ", " << dispatch_key_to_string(iter->first);
+ str << ", " << detail::dispatch_key_to_string(iter->first);
}
return str.str();
}
- static std::string dispatch_key_to_string(TensorTypeId id) {
- // TODO Find better way to stringify tensor type ids without relying on backend
- std::string name = "";
- try {
- name = toString(tensorTypeIdToBackend(id));
- } catch (const std::exception&) {
- // This can fail if the tensor type id is not one of the preregistered backends.
- // However, dispatch_key_to_string is used to generate error reports, that
- // means an error already has happened when entering this function.
- // We don't want inner errors during generation of a report for an
- // outer error. Just report an empty name instead.
- }
- return name + "[" + toString(id) + "]";
- }
-
LeftRight<ska::flat_hash_map<TensorTypeId, DispatchTableEntry>> map_;
};
} // namespace detail
+ // Constructs an empty table: no kernels and no fallback kernel registered.
explicit DispatchTable(const FunctionSchema& schema)
: kernels_()
, dispatch_strategy_(get_dispatch_strategy_(schema))
- , operator_name_(schema.name()) {}
+ , operator_name_(schema.name())
+ , fallback_kernel_(c10::nullopt) {}
DispatchTable(DispatchTable&&) = delete;
DispatchTable& operator=(DispatchTable&&) = delete;
/**
* Register a kernel in the table at some dispatch key.
- * @param func Concrete kernel function implementation to register
* @param dispatch_key Dispatch key to define when this kernel is selected
+ * @param kernel Concrete kernel function implementation to register
*/
void registerKernel(
TensorTypeId dispatch_key,
}
/**
+ * Register a fallback kernel. This kernel will be returned from lookup
+ * whenever no other kernel matches the dispatch key.
+ * @param kernel Concrete kernel function implementation to register
+ */
+ void registerFallbackKernel(const DispatchTableEntry& kernel) {
+ // Plain assignment: overwrites any previously registered fallback kernel.
+ fallback_kernel_ = kernel;
+ }
+
+ /**
+ * Deregister the fallback kernel.
+ * Without a fallback kernel, lookup of a dispatch key that doesn't match
+ * a kernel will fail again.
+ */
+ void deregisterFallbackKernel() {
+ fallback_kernel_.reset();
+ }
+
+ /**
* Perform a dynamic dispatch on this table and find the kernel to call
* for the given arguments.
*
* @param args Arguments to invoke the function with
- * @return Kernel function pointing to the right kernel for the given arguments
+ * @return Kernel function pointing to the right kernel for the given arguments,
+ *         or the fallback kernel if no kernel matches the dispatch key.
+ *         Throws if neither is available.
*/
const DispatchTableEntry& lookup(const Stack* stack) const {
TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack);
- return *kernels_.lookup(dispatch_key, operator_name_);
+ // 1. An exact match on the dispatch key wins.
+ auto found = kernels_.lookup(dispatch_key, operator_name_);
+ if (nullptr != found) {
+ return *found;
+ }
+ // 2. Otherwise use the fallback kernel if one is registered.
+ if (fallback_kernel_.has_value()) {
+ return *fallback_kernel_;
+ }
+ // 3. Neither matched: report the error here (the kernel table now returns
+ // nullptr on a miss instead of erroring itself).
+ AT_ERROR("Didn't find kernel to dispatch to for operator '", operator_name_,
+ "'. Tried to look up kernel for dispatch key '", detail::dispatch_key_to_string(dispatch_key),
+ "'. Registered dispatch keys are: ", kernels_.list_all_dispatch_keys());
}
bool isEmpty() const {
detail::ThreadsafeOperatorTable_ kernels_;
DispatchStrategy dispatch_strategy_;
std::string operator_name_;
+ c10::optional<DispatchTableEntry> fallback_kernel_;
};
} // namespace c10
void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
- op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
+ // NOTE(review): std::move on a by-value TensorTypeId is a no-op if the id is
+ // trivially copyable — presumably done for consistency; confirm.
+ op.operatorDefIterator_->dispatchTable.deregisterKernel(std::move(dispatch_key));
+}
+
+// Registers the kernel that is used whenever no dispatch key matches;
+// forwards to DispatchTable::registerFallbackKernel.
+void Dispatcher::registerFallbackKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func) {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ op.operatorDefIterator_->dispatchTable.registerFallbackKernel(DispatchTableEntry{kernel_func, std::move(cache_creator_func)});
+}
+
+// Removes the fallback kernel; lookups with an unknown dispatch key will
+// fail again afterwards.
+void Dispatcher::deregisterFallbackKernel(const OperatorHandle& op) {
+ // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+ op.operatorDefIterator_->dispatchTable.deregisterFallbackKernel();
+}
void Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
class CAFFE2_API Dispatcher final {
private:
struct OperatorDef final {
- explicit OperatorDef(FunctionSchema schema_)
+ explicit OperatorDef(FunctionSchema&& schema_)
: dispatchTable(schema_)
, schema(std::move(schema_))
, refcount(0) {}
c10::optional<OperatorHandle> findSchema(const char* operator_name, const char* overload_name);
/**
- * Register an operator to the dispatch table for an operator.
+ * Register a kernel to the dispatch table for an operator.
+ * To register a fallback kernel that is used when no dispatch key matches,
+ * use registerFallbackKernel() instead.
*/
void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func);
/**
- * Remove an operator from the dispatch table for an operator.
+ * Remove a kernel from the dispatch table for an operator.
+ * To remove the fallback kernel instead, use deregisterFallbackKernel().
*/
void deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key);
/**
+ * Register a fallback kernel for an operator.
+ * After this, when trying to lookup a kernel for an unknown dispatch key,
+ * it will not fail anymore, but return the fallback kernel instead.
+ */
+ void registerFallbackKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func);
+
+ /**
+ * Remove the fallback kernel for an operator.
+ * After this, if trying to lookup a kernel for an unknown dispatch key,
+ * the lookup will fail.
+ */
+ void deregisterFallbackKernel(const OperatorHandle& op);
+
+ /**
* Perform a dynamic dispatch and get the kernel for an operator.
*/
OpKernel lookup(const OperatorHandle& op, const Stack* stack) const;
// table deregisters it in the destructor.
class RegisterOperators::OperatorRegistrar final {
public:
+ // A nullopt dispatch_key registers `kernel` as the operator's fallback kernel
+ // instead of a keyed kernel.
- explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
+ explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
: op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
- Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator));
+ if (dispatch_key_.has_value()) {
+ Dispatcher::singleton().registerKernel(op_, *dispatch_key_, kernel, std::move(cache_creator));
+ } else {
+ Dispatcher::singleton().registerFallbackKernel(op_, kernel, std::move(cache_creator));
+ }
}
OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
~OperatorRegistrar() {
if (owns_registration_) {
+ // Mirrors the constructor: deregister either the keyed kernel or the
+ // fallback kernel, then the schema.
- Dispatcher::singleton().deregisterKernel(op_, dispatch_key_);
+ if (dispatch_key_.has_value()) {
+ Dispatcher::singleton().deregisterKernel(op_, *dispatch_key_);
+ } else {
+ Dispatcher::singleton().deregisterFallbackKernel(op_);
+ }
Dispatcher::singleton().deregisterSchema(op_);
}
}
private:
const OperatorHandle op_;
- const TensorTypeId dispatch_key_;
+ const c10::optional<TensorTypeId> dispatch_key_;
bool owns_registration_;
};
void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
- // TODO We need to support registrations without a dispatch key,
- // at least for the deprecated APIs. Maybe also for new ones.
- AT_CHECK(config.dispatch_key.has_value(),
- "Tried to register an operator with function schema ", toString(schema),
- ", but didn't specify a dispatch key. Please add a c10::dispatchKey(...) parameter to the registration call.");
-
// TODO Should we allow this and only register a schema without a kernel?
AT_CHECK(config.kernel_func != nullptr,
"Tried to register an operator with function schema ", toString(schema),
assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
}
- registrars_.emplace_back(std::move(schema), *config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
+ registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
}
}
using c10::Argument;
using c10::kernel;
using c10::dispatchKey;
+using c10::Dispatcher;
using at::Tensor;
namespace {
C10_DECLARE_TENSOR_TYPE(TensorType1);
C10_DEFINE_TENSOR_TYPE(TensorType1);
+C10_DECLARE_TENSOR_TYPE(TensorType2);
+C10_DEFINE_TENSOR_TYPE(TensorType2);
+// Kernel that does nothing; used by tests that only care about registration.
struct DummyKernel final : OperatorKernel {
void operator()(Tensor) {}
};
+// Kernel that records, via a caller-owned flag, whether it was invoked.
+struct MockKernel final : OperatorKernel {
+ MockKernel(bool* called): called_(called) {}
+
+ void operator()(Tensor) {
+ *called_ = true;
+ }
+private:
+ bool* called_;  // non-owning; must outlive the kernel
+};
+
FunctionSchema dummySchema(
"_test::dummy",
"",
c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
}
-TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutDispatchKey_thenFails) {
- // make sure it crashes when dispatch key is absent
+// Without a fallback kernel, dispatching on an unregistered key must throw.
+TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) {
+ auto registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
EXPECT_THROW(
- c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>()),
+ callOp(*op, dummyTensor(TensorType2())),
c10::Error
);
+}
- // but make sure it doesn't crash when dispatch key is present
- c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+// A fallback registration that went out of scope must no longer apply.
+TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) {
+ auto registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+ {
+ auto inner_registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>());
+ // this registered a fallback kernel, but now that registration goes out of scope and deregisters it
+ }
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_THROW(
+ callOp(*op, dummyTensor(TensorType2())),
+ c10::Error
+ );
+}
+
+// The fallback kernel alone should serve any dispatch key.
+TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) {
+ bool called = false;
+ auto registrar = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_FALSE(called);
+ callOp(*op, dummyTensor(TensorType2()));
+ EXPECT_TRUE(called);
+}
+
+// A keyed kernel that has been deregistered must not shadow the fallback.
+TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernelAndOtherKernelOutOfScope_whenCallingOp_thenCallsFallbackKernel) {
+ bool called = false;
+ bool other_called = false;
+ auto registrar = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
+ {
+ auto inner_registrar = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&other_called), dispatchKey(TensorType2()));
+ }
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_FALSE(called);
+ callOp(*op, dummyTensor(TensorType2()));
+ EXPECT_TRUE(called);
+ EXPECT_FALSE(other_called);
+}
+
+// A matching keyed kernel takes precedence over the fallback
+// (fallback registered first).
+TEST(OperatorRegistrationTest, givenOpWithFirstFallbackAndThenOtherKernel_whenCallingWithCorrectDispatchKey_thenCallsCorrectKernel) {
+ bool called_kernel = false;
+ bool called_fallback = false;
+ auto registrar = c10::RegisterOperators()
+ .op(dummySchema, kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
+ .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_FALSE(called_kernel);
+ EXPECT_FALSE(called_fallback);
+ callOp(*op, dummyTensor(TensorType1()));
+ EXPECT_TRUE(called_kernel);
+ EXPECT_FALSE(called_fallback);
+}
+
+// A non-matching key falls through to the fallback even though a keyed
+// kernel exists (fallback registered first).
+TEST(OperatorRegistrationTest, givenOpWithFirstFallbackAndThenOtherKernel_whenCallingWithWrongDispatchKey_thenCallsFallbackKernel) {
+ bool called_kernel = false;
+ bool called_fallback = false;
+ auto registrar = c10::RegisterOperators()
+ .op(dummySchema, kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
+ .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_FALSE(called_kernel);
+ EXPECT_FALSE(called_fallback);
+ callOp(*op, dummyTensor(TensorType2()));
+ EXPECT_FALSE(called_kernel);
+ EXPECT_TRUE(called_fallback);
+}
+
+
+// Same as above but with the registration order reversed: the keyed kernel
+// still wins for its own key.
+TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCallingWithCorrectDispatchKey_thenCallsCorrectKernel) {
+ bool called_kernel = false;
+ bool called_fallback = false;
+ auto registrar = c10::RegisterOperators()
+ .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()))
+ .op(dummySchema, kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_FALSE(called_kernel);
+ EXPECT_FALSE(called_fallback);
+ callOp(*op, dummyTensor(TensorType1()));
+ EXPECT_TRUE(called_kernel);
+ EXPECT_FALSE(called_fallback);
+}
+
+// Registration order reversed: a non-matching key still reaches the fallback.
+TEST(OperatorRegistrationTest, givenOpWithFirstOtherAndThenFallbackKernel_whenCallingWithWrongDispatchKey_thenCallsFallbackKernel) {
+ bool called_kernel = false;
+ bool called_fallback = false;
+ auto registrar = c10::RegisterOperators()
+ .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()))
+ .op(dummySchema, kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
+
+ auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
+ ASSERT_TRUE(op.has_value());
+ EXPECT_FALSE(called_kernel);
+ EXPECT_FALSE(called_fallback);
+ callOp(*op, dummyTensor(TensorType2()));
+ EXPECT_FALSE(called_kernel);
+ EXPECT_TRUE(called_fallback);
}
}