Revert D13440858: [pytorch][PR] Use a pool of per-thread cudnn handles for each devic...
authorEdward Yang <ezyang@fb.com>
Fri, 14 Dec 2018 22:23:13 +0000 (14:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 22:35:01 +0000 (14:35 -0800)
Differential Revision:
D13440858

Original commit changeset: 1c6af5c53538

fbshipit-source-id: fda42ea75000d4a4e9c4a8eeaaa5518f7ad9c298

aten/src/ATen/cudnn/Handle.cpp
test/test_nn.py

index e0ad2ab..62f595b 100644 (file)
@@ -3,38 +3,22 @@
 #include <ATen/cuda/Exceptions.h>
 
 #include <unordered_map>
-#include <vector>
-#include <utility>
 #include <mutex>
 
+// TODO: Get rid of the mutex, and just initialize these
+// handles in at::Context along with lazy CUDA initialization
+
 namespace at { namespace native {
 
 namespace {
 
 struct Handle {
   cudnnHandle_t handle;
-  Handle(bool create = false) : handle(nullptr)
-  {
-    if(create)
-      AT_CUDNN_CHECK(cudnnCreate(&handle));
+  Handle() : handle(NULL) {
+    AT_CUDNN_CHECK(cudnnCreate(&handle));
   }
-  // std::vector.emplace() and push_back() may route through temporaries and call
-  // copy/move constructors along the way.  If this is the case, we don't want
-  // the destructors of temporaries to call cudnnDestroy on the handle.
-  // We can achieve safety (for the narrow case of stashing within std::vectors)
-  // by making Handle moveable but not copyable, and transferring handle ownership
-  // to the latest constructed object.  This is not a substitute for full-blown 
-  // reference counting, but reference counting may be overkill here.
-  // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
-  // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
-  Handle(const Handle& rhs) = delete;
-  // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
-  Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
-  // operator= takes argument by value
-  Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
   ~Handle() {
-    if(handle)
-    {
+    if (handle) {
 // this is because of something dumb in the ordering of
 // destruction. Sometimes atexit, the cuda context (or something)
 // would already be destroyed by the time this gets destroyed. It
@@ -50,79 +34,8 @@ struct Handle {
 };
 
 std::mutex mutex;
+std::unordered_map<int, Handle> handles;
 
-// Handles are lazily created as different threads request them,
-// but are never destroyed until the end of the process.
-// The maximum number of handles this process will create is equal to the high-water
-// mark of the number of concurrently active threads that have requested handles.
-// When threads terminate, they release their handles back into the pool for reuse.
-// Otherwise, new handles would be created every time new threads were spawned,
-// resulting in poor performance for Python modules that repeatedly or frequently
-// spawned new sets of threads (like DataParallel, which creates a new set of threads
-// for each forward pass).
-//
-// To prevent potential deadlocks, we explicitly choose not to cap the number
-// of handles that are created per device.
-// Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
-// only 4 can make forward progress at any time. The other 4 will not release their
-// handles until they exit, so the fifth cannot make progress until then.  This is
-// not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
-// intermediate point (ie, before any of them have exited).  We have no way to anticipate
-// or enforce that user threads will not attempt such intermediate synchronization.
-// The only way to ensure safety is to avoid imposing a cap on the number of handles.
-std::unordered_map<int, std::vector<Handle>> created_handles;
-std::unordered_map<int, std::vector<cudnnHandle_t>> available_handles;
-
-// PoolWindow lazily creates and caches the handles that a particular thread is using,
-// so in the common case handle access doesn't incur either handle creation or a mutex lock.
-class PoolWindow
-{
-  public:
-  PoolWindow(){}
-  ~PoolWindow(){ release(); }
-
-  cudnnHandle_t reserve(int device)
-  {
-    // If this thread already has a handle for this device, return it
-    if(my_handles.find(device) != my_handles.end())
-      return my_handles[device];
-
-    // otherwise, either grab a handle from the pool if one is available,
-    // or if not, create a new one.
-    std::lock_guard<std::mutex> guard(mutex);
-
-    if(available_handles[device].size() > 0)
-    {
-      my_handles[device] = available_handles[device].back();
-      available_handles[device].pop_back();
-    }
-    else
-    {
-      // In local testing, I do observe that emplace_back sometimes routes through temporaries
-      // that incur move-constructor and destructor calls.  See comments in Handle above.
-      created_handles[device].emplace_back(true /*create*/);
-      my_handles[device] = created_handles[device].back().handle;
-    }
-
-    return my_handles[device];
-  }
-
-  private:
-  // Stores the per-device handles currently owned by this thread
-  std::unordered_map<int, cudnnHandle_t> my_handles;
-
-  // Called by the destructor.  Releases this thread's handles back into the pool.
-  void release()
-  {
-    std::lock_guard<std::mutex> guard(mutex);
-    for(auto d_h : my_handles)
-      available_handles[d_h.first].push_back(d_h.second);
-  }
-};
-
-// This will be destroyed when the thread terminates,
-// releasing its reserved handles back to the pool.
-thread_local PoolWindow myPoolWindow;
 }  // namespace
 
 
@@ -131,7 +44,8 @@ cudnnHandle_t getCudnnHandle()
   int device;
   AT_CUDA_CHECK(cudaGetDevice(&device));
 
-  return myPoolWindow.reserve(device);
+  std::lock_guard<std::mutex> guard(mutex);
+  return handles[device].handle;
 }
 
 }} // namespace at::cudnn
index 1237440..a404cb0 100644 (file)
@@ -13,7 +13,6 @@ from operator import mul
 from collections import OrderedDict
 import hashlib
 import os
-import threading
 
 import torch
 from torch._six import inf, nan
@@ -3708,55 +3707,6 @@ class TestNN(NNTestCase):
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
-    def test_cudnn_multiple_threads_same_device(self):
-        # This function is intended to test the lazy creation and reuse of per-thread
-        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
-        # Failure here likely indicates something wrong with that logic.
-        weight = torch.ones((1, 1, 2, 2), device='cuda')
-
-        results = {}
-
-        num_threads = 2
-        trials = 2
-        test_iters = 100
-
-        with torch.backends.cudnn.flags(enabled=True):
-            def _worker(t, input):
-                my_stream = torch.cuda.Stream()
-                results[t] = input
-                with torch.cuda.stream(my_stream):
-                    for i in range(test_iters):
-                        # If all threads are sharing the same cudnn handle,
-                        # the following sequence may occur:
-                        # thread 0 calls setCuDNNStreamToCurrent()
-                        # thread 1 calls setCuDNNStreamToCurrent()
-                        # thread 0 launches its raw convolution, which it thinks is in
-                        #          its own stream, but is actually in thread 1's stream.
-                        # thread 0 enqueues its div_, which IS is its own stream,
-                        #          but now races with its convolution.
-                        results[t] = torch.nn.functional.conv2d(results[t], weight, padding=0)
-                        results[t].div_(4.0)
-                torch.cuda.current_stream().wait_stream(my_stream)
-
-            for trial in range(trials):
-                for t in range(num_threads):
-                    results[t] = torch.ones((1, 1, 2048, 2048), device='cuda')
-
-                threads = [threading.Thread(target=_worker,
-                                            args=(t, results[t])) for t in range(num_threads)]
-
-                for thread in threads:
-                    thread.start()
-                for thread in threads:
-                    thread.join()
-
-                for t in range(num_threads):
-                    self.assertEqual(results[t].sum().item(),
-                                     (2048 - test_iters) * (2048 - test_iters))
-
-    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
-    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     @repeat_test_for_types(ALL_TENSORTYPES)
     @skipIfRocm
     def test_Conv2d_deterministic_cudnn(self, dtype=torch.float):