});
}
-void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, std::list<impl::AnnotatedKernel>::iterator handle) {
+void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, impl::OperatorEntry::AnnotatedKernelContainerIterator handle) {
std::lock_guard<std::mutex> lock(mutex_);
op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);
const OperatorHandle& op,
const OperatorName& op_name,
c10::optional<DispatchKey> dispatch_key,
- std::list<impl::AnnotatedKernel>::iterator kernel_handle);
+ impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
void deregisterFallback_(DispatchKey dispatchKey);
void deregisterLibrary_(const std::string& ns);
dispatchKeyExtractor_.deregisterSchema();
}
-std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
+OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
const c10::Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
// Redirect catchAll registrations to CompositeImplicitAutograd.
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+ if (k[0].kernel.isValid()) {
+#else
if (k.size() > 0) {
+#endif
TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n",
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
);
}
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+ k[0].kernel = std::move(kernel);
+ k[0].inferred_function_schema = std::move(inferred_function_schema);
+ k[0].debug = std::move(debug);
+#else
k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
- std::list<AnnotatedKernel>::iterator inserted = k.begin();
+#endif
+ AnnotatedKernelContainerIterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
// that the dispatch table points to the newest kernel
if (dispatch_key.has_value()) {
void OperatorEntry::deregisterKernel_(
const c10::Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
- std::list<AnnotatedKernel>::iterator kernel
+ AnnotatedKernelContainerIterator kernel
) {
// Redirect catchAll deregistrations to CompositeImplicitAutograd.
DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::CompositeImplicitAutograd;
auto found = kernels_.find(dk);
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
auto& k = found->second;
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  // One-kernel-per-dispatch-key mode: the map entry is a single-element
+  // std::array, so there is no individual kernel to erase here.
+  // NOTE(review): std::array<AnnotatedKernel, 1>::empty() is always false,
+  // so the `if (k.empty())` cleanup below never fires in this mode —
+  // confirm the kernels_ entry is removed (or intentionally retained).
+#else
k.erase(kernel);
+#endif
if (k.empty()) {
// the invariant says we don't want empty lists but instead remove the list from the map
kernels_.erase(found);
oss << "schema: (none)\n";
}
- auto print_kernel = [&](const char* k_desc, const std::list<AnnotatedKernel>& jts, bool is_alias_key=false) {
+ auto print_kernel = [&](const char* k_desc, const AnnotatedKernelContainer& jts, bool is_alias_key=false) {
int64_t i = 0;
for (const auto& jt : jts) {
oss << k_desc
#include <list>
#include <array>
+#ifdef C10_MOBILE
+#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+#endif
+
namespace c10 {
class Dispatcher;
return name_;
}
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+ using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>;
+#else
+ using AnnotatedKernelContainer = std::list<AnnotatedKernel>;
+#endif
+ using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator;
+
// Why are kernels and fallback asymmetric? It has to do with ownership.
// Kernels and the computed dispatch tables for them are canonically
// owned by OperatorEntry, but backend fallbacks are specified once
// Precondition: Dispatcher::mutex_ is held
// Postcondition: caller is responsible for disposing of the kernel
- std::list<AnnotatedKernel>::iterator registerKernel(
+ AnnotatedKernelContainerIterator registerKernel(
const Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
void deregisterKernel_(
const Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
- std::list<AnnotatedKernel>::iterator kernel
+ AnnotatedKernelContainerIterator kernel
);
// Precondition: Dispatcher::mutex_ is held
// re-executed and then only allow one kernel here, i.e. error if a kernel
// is already registered, but that's a lot of effort to implement and
// currently not high-pri.
- ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
+ ska::flat_hash_map<DispatchKey,
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  // On mobile, we needn't worry about Jupyter notebooks (i.e. interactive
+  // re-registration of kernels for the same dispatch key), so a single
+  // kernel slot per dispatch key suffices instead of a std::list.
+ std::array<AnnotatedKernel, 1>
+#else
+ std::list<AnnotatedKernel>
+#endif
+ > kernels_;
const AnnotatedKernel& missingKernel() const;
const AnnotatedKernel& ambiguousAutogradOtherKernel() const;