[PyTorch] Don't store multiple kernels per key on mobile (#64447)
authorScott Wolchok <swolchok@fb.com>
Tue, 14 Sep 2021 17:35:04 +0000 (10:35 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 17:36:43 +0000 (10:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64447

As the code comment says, we needn't worry about Jupyter notebooks on mobile.
ghstack-source-id: 137951718

Test Plan: Profiled startup of //caffe2/caffe2/fb/high_perf_models/pytorch/benchmark_framework_overheads:cpp_benchmark on devserver with -niter 0 -nrep 0 and `C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY` defined. Time spent in sherwood_v3_table lookups went way down.

Reviewed By: ezyang, bhosmer

Differential Revision: D30736094

fbshipit-source-id: bcc22cd0d9adceba259a03898c992759d501fe89

aten/src/ATen/core/dispatch/Dispatcher.cpp
aten/src/ATen/core/dispatch/Dispatcher.h
aten/src/ATen/core/dispatch/OperatorEntry.cpp
aten/src/ATen/core/dispatch/OperatorEntry.h

index ebd7c48..589292b 100644 (file)
@@ -214,7 +214,7 @@ RegistrationHandleRAII Dispatcher::registerImpl(
   });
 }
 
-void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, std::list<impl::AnnotatedKernel>::iterator handle) {
+void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, impl::OperatorEntry::AnnotatedKernelContainerIterator handle) {
   std::lock_guard<std::mutex> lock(mutex_);
 
   op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);
index cfa6b74..5fd68ab 100644 (file)
@@ -268,7 +268,7 @@ private:
     const OperatorHandle& op,
     const OperatorName& op_name,
     c10::optional<DispatchKey> dispatch_key,
-    std::list<impl::AnnotatedKernel>::iterator kernel_handle);
+    impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
   void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
   void deregisterFallback_(DispatchKey dispatchKey);
   void deregisterLibrary_(const std::string& ns);
index b82bc06..e5f9b66 100644 (file)
@@ -77,7 +77,7 @@ void OperatorEntry::deregisterSchema() {
   dispatchKeyExtractor_.deregisterSchema();
 }
 
-std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
+OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
   const c10::Dispatcher& dispatcher,
   c10::optional<DispatchKey> dispatch_key,
   KernelFunction kernel,
@@ -119,7 +119,11 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
   // Redirect catchAll registrations to CompositeImplicitAutograd.
   auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
 
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  if (k[0].kernel.isValid()) {
+#else
   if (k.size() > 0) {
+#endif
     TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n",
                "  operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
                "    ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
@@ -129,8 +133,14 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
     );
   }
 
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  k[0].kernel = std::move(kernel);
+  k[0].inferred_function_schema = std::move(inferred_function_schema);
+  k[0].debug = std::move(debug);
+#else
   k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
-  std::list<AnnotatedKernel>::iterator inserted = k.begin();
+#endif
+  AnnotatedKernelContainerIterator inserted = k.begin();
   // update the dispatch table, i.e. re-establish the invariant
   // that the dispatch table points to the newest kernel
   if (dispatch_key.has_value()) {
@@ -144,14 +154,18 @@ std::list<AnnotatedKernel>::iterator OperatorEntry::registerKernel(
 void OperatorEntry::deregisterKernel_(
   const c10::Dispatcher& dispatcher,
   c10::optional<DispatchKey> dispatch_key,
-  std::list<AnnotatedKernel>::iterator kernel
+  AnnotatedKernelContainerIterator kernel
 ) {
   // Redirect catchAll deregistrations to CompositeImplicitAutograd.
   DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::CompositeImplicitAutograd;
   auto found = kernels_.find(dk);
   TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_));
   auto& k = found->second;
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  // We are about to remove the array from the map, no need to do anything.
+#else
   k.erase(kernel);
+#endif
   if (k.empty()) {
     // the invariant says we don't want empty lists but instead remove the list from the map
     kernels_.erase(found);
@@ -455,7 +469,7 @@ std::string OperatorEntry::dumpState() const {
     oss << "schema: (none)\n";
   }
 
-  auto print_kernel = [&](const char* k_desc, const std::list<AnnotatedKernel>& jts, bool is_alias_key=false) {
+  auto print_kernel = [&](const char* k_desc, const AnnotatedKernelContainer& jts, bool is_alias_key=false) {
     int64_t i = 0;
     for (const auto& jt : jts) {
       oss << k_desc
index 64e95ef..d9c7dbf 100644 (file)
 #include <list>
 #include <array>
 
+#ifdef C10_MOBILE
+#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+#endif
+
 namespace c10 {
 
 class Dispatcher;
@@ -101,6 +105,13 @@ public:
     return name_;
   }
 
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>;
+#else
+  using AnnotatedKernelContainer = std::list<AnnotatedKernel>;
+#endif
+  using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator;
+
   // Why are kernels and fallback asymmetric?  It has to do with ownership.
   // Kernels and the computed dispatch tables for them are canonically
   // owned by OperatorEntry, but backend fallbacks are specified once
@@ -114,7 +125,7 @@ public:
 
   // Precondition: Dispatcher::mutex_ is held
   // Postcondition: caller is responsible for disposing of the kernel
-  std::list<AnnotatedKernel>::iterator registerKernel(
+  AnnotatedKernelContainerIterator registerKernel(
     const Dispatcher& dispatcher,
     c10::optional<DispatchKey> dispatch_key,
     KernelFunction kernel,
@@ -127,7 +138,7 @@ public:
   void deregisterKernel_(
     const Dispatcher& dispatcher,
     c10::optional<DispatchKey> dispatch_key,
-    std::list<AnnotatedKernel>::iterator kernel
+    AnnotatedKernelContainerIterator kernel
   );
 
   // Precondition: Dispatcher::mutex_ is held
@@ -215,7 +226,14 @@ private:
   // re-executed and then only allow one kernel here, i.e. error if a kernel
   // is already registered, but that's a lot of effort to implement and
   // currently not high-pri.
-  ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
+  ska::flat_hash_map<DispatchKey,
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+                     // On mobile, we needn't worry about Jupyter notebooks.
+                     std::array<AnnotatedKernel, 1>
+#else
+                     std::list<AnnotatedKernel>
+#endif
+                     > kernels_;
 
   const AnnotatedKernel& missingKernel() const;
   const AnnotatedKernel& ambiguousAutogradOtherKernel() const;