Implement reference counting for shared IPC CUDA tensors (#16854)
authorVitaly Fedyunin <vitalyf@fb.com>
Mon, 25 Mar 2019 17:18:29 +0000 (10:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Mar 2019 17:24:38 +0000 (10:24 -0700)
Summary:
This is to fix #16141 and similar issues.

The idea is to track a reference to every shared CUDA Storage and deallocate memory only after a consumer process deallocates received Storage.

ezyang Done with cleanup. Same (insignificantly better) performance as in file-per-share solution, but handles millions of shared tensors easily. Note [ ] documentation in progress.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16854

Differential Revision: D13994490

Pulled By: VitalyFedyunin

fbshipit-source-id: 565148ec3ac4fafb32d37fde0486b325bed6fbd1

15 files changed:
c10/core/StorageImpl.h
c10/cuda/CUDACachingAllocator.cpp
c10/cuda/CUDACachingAllocator.h
docs/source/multiprocessing.rst
test/test_multiprocessing.py
torch/CMakeLists.txt
torch/csrc/CudaIPCTypes.cpp [new file with mode: 0644]
torch/csrc/CudaIPCTypes.h [new file with mode: 0644]
torch/csrc/Storage.cpp
torch/csrc/cuda/Module.cpp
torch/csrc/cuda/Storage.cpp
torch/csrc/generic/StorageSharing.cpp
torch/cuda/__init__.py
torch/multiprocessing/cuda_multiprocessing.md [new file with mode: 0644]
torch/multiprocessing/reductions.py

index 122fd08..579ef00 100644 (file)
@@ -19,6 +19,7 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
         data_ptr_(std::move(data_ptr)),
         numel_(numel),
         resizable_(resizable),
+        received_cuda_(false),
         allocator_(allocator) {
     if (resizable) {
       AT_ASSERTM(
@@ -210,11 +211,24 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
     resizable_ = false;
   }
 
+  // This method can be used only after storage construction and cannot be used
+  // to modify storage status
+  void set_received_cuda(bool received_cuda) {
+    received_cuda_ = received_cuda;
+  }
+
+  bool received_cuda() {
+    return received_cuda_;
+  }
+
  private:
   caffe2::TypeMeta data_type_;
   DataPtr data_ptr_;
   int64_t numel_;
   bool resizable_;
+  // Identifies that Storage was received from another process and doesn't have
+  // local to process cuda memory allocation
+  bool received_cuda_;
   Allocator* allocator_;
 };
 } // namespace c10
index ff79762..33d26ab 100644 (file)
 #include <vector>
 
 namespace c10 {
-namespace cuda {
 
+C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
+
+namespace cuda {
 namespace CUDACachingAllocator {
 
 //
@@ -47,6 +49,8 @@ namespace CUDACachingAllocator {
 // work.
 //
 
+
+
 namespace {
 
 using stream_set = std::unordered_set<cuda::CUDAStream>;
@@ -154,7 +158,7 @@ struct THCCachingAllocator
   std::vector<DeviceStats> device_stats;
 
   // lock around all operations
-  std::mutex mutex;
+  std::recursive_mutex mutex;
 
   // lock around calls to cudaFree (to prevent deadlocks with NCCL)
   std::mutex cuda_free_mutex;
@@ -186,7 +190,7 @@ struct THCCachingAllocator
   /** allocates a block which is safe to use from the provided stream */
   void malloc(void** devPtr, size_t size, cudaStream_t stream)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
 
     int device;
     C10_CUDA_CHECK(cudaGetDevice(&device));
@@ -201,14 +205,29 @@ struct THCCachingAllocator
     Block search_key(device, stream, size);
     auto& pool = get_pool(size);
 
-    Block* block = nullptr;
-    Block* remaining = nullptr;
-
-    auto it = pool.lower_bound(&search_key);
-    if (it != pool.end() && (*it)->device == device && (*it)->stream == stream) {
-      block = *it;
-      pool.erase(it);
-    } else {
+    auto find_free_block = [&]()->Block*{
+      auto it = pool.lower_bound(&search_key);
+      if (it != pool.end() && (*it)->device == device &&
+          (*it)->stream == stream) {
+        Block* block = *it;
+        pool.erase(it);
+        return block;
+      }
+      return nullptr;
+    };
+
+    Block* block = find_free_block();
+    if (block == nullptr) {
+      bool freed_memory = false;
+      for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) {
+        freed_memory |=
+            FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute();
+      }
+      if (freed_memory) {
+        block = find_free_block();
+      }
+    }
+    if (block == nullptr) {
       void* ptr;
       size_t alloc_size = get_allocation_size(size);
       cudaError_t err = cuda_malloc_retry(device, &ptr, alloc_size);
@@ -253,8 +272,10 @@ struct THCCachingAllocator
       block = new Block(device, stream, alloc_size, &pool, ptr);
     }
 
+    Block* remaining = nullptr;
     AT_ASSERT(block);
     if (should_split(block, size)) {
+
       remaining = block;
 
       block = new Block(device, stream, size, &pool, block->ptr);
@@ -280,7 +301,7 @@ struct THCCachingAllocator
 
   void free(void* ptr)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     if (!ptr) {
       return;
     }
@@ -305,14 +326,14 @@ struct THCCachingAllocator
   /** returns cached blocks to the system allocator */
   void emptyCache()
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     free_blocks(large_blocks, large_blocks.begin(), large_blocks.end());
     free_blocks(small_blocks, small_blocks.begin(), small_blocks.end());
   }
 
   void* getBaseAllocation(void* ptr, size_t* outSize)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     Block* block = find_allocated_block(ptr);
     if (!block) {
       AT_ERROR("invalid device pointer: %p", ptr);
@@ -348,14 +369,14 @@ struct THCCachingAllocator
 
   void cacheInfo(int dev_id, size_t* total, size_t* largest)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     cacheInfoAux(large_blocks, dev_id, total, largest);
     cacheInfoAux(small_blocks, dev_id, total, largest);
   }
 
   void recordStream(void* ptr, cuda::CUDAStream stream)
   {
-    std::lock_guard<std::mutex> lock(mutex);
+    std::lock_guard<std::recursive_mutex> lock(mutex);
     Block* block = find_allocated_block(ptr);
     if (!block) {
       AT_ERROR("invalid device pointer: %p", ptr);
index f146513..2376446 100644 (file)
@@ -4,10 +4,24 @@
 #include <c10/cuda/CUDAStream.h>
 #include <c10/core/Allocator.h>
 #include <c10/cuda/CUDAMacros.h>
+#include <c10/util/Registry.h>
 
 #include <mutex>
 
 namespace c10 {
+
+// Caching allocator will execute every registered callback if it unable to find
+// block inside of already allocated area.
+class C10_CUDA_API FreeMemoryCallback {
+ public:
+  virtual ~FreeMemoryCallback() {};
+  virtual bool Execute() = 0;
+};
+
+C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
+#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
+  C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
+
 namespace cuda {
 
 // TODO: Turn this into an honest to goodness class. I briefly attempted to do
index e5bffec..f76b579 100644 (file)
@@ -28,57 +28,65 @@ Python 2 can only create subprocesses using ``fork``, and it's not supported
 by the CUDA runtime.
 
 Unlike CPU tensors, the sending process is required to keep the original tensor
-as long as the receiving process retains a copy of the tensor.
-This shouldn't be a problem for sharing model parameters (which stay live
-for the entire execution of the model), but passing other
-kinds of data should be done with care.
+as long as the receiving process retains a copy of the tensor. It is implemented
+under the hood but requires users to follow the next best practices.
 
-Here is an example program which handles these requirements correctly:
+1. Release memory ASAP in the consumer.
 
 ::
 
-    import torch
-    import torch.multiprocessing as mp
-
-    torch.set_default_tensor_type(torch.cuda.FloatTensor)
-
-    def sender(q, e):
-        for i in range(10):
-            s_sample = [torch.zeros(1), torch.ones(1)]
-            q.put(s_sample)
-            e.wait()
-            del s_sample
-            e.clear()
-
-    if __name__ == "__main__":
-        ctx = mp.get_context("spawn")
-        q = ctx.Queue()
-        e = ctx.Event()
-        p = ctx.Process(target=sender, args=(q, e))
-        p.start()
-
-        for i in range(10):
-            print('=== ITER {} ===".format(i))
-            r_sample = q.get()
-            del r_sample
-            e.set()
-
-        p.join()
-
-In the example above, calling `e.wait()`
-on sender side ensures tensor `s_sample` doesn't get deleted while
-receiver is working on it.  The receiver signals when it is done
-with the tensor using `e.set()`, being careful to `del` its reference
-to the received tensor first.  It is INSUFFICIENT to promise never to call
-`r_sample` again; while `r_sample` is live, it may be confused with
-any subsequent tensors allocated by the source process at the same address.
-
-If a receiver wants to save the data of `r_sample` for future use while
-letting the source process deallocate the original, it must
-`clone()` it.
-
-This behavior is very confusing, and we are tracking a fix for it
-at https://github.com/pytorch/pytorch/issues/16141
+    ## Good
+    x = queue.get()
+    # do somethings with x
+    del x
+
+::
+
+    ## Bad
+    x = queue.get()
+    # do somethings with x
+    # do everything else (producer have to keep x in memory)
+
+2. Keep producer process running until all consumers exits. This will prevent
+the situation when the producer process releasing memory which is still in use
+by the consumer.
+
+::
+
+    ## producer
+    # send tensors, do something
+    event.wait()
+
+
+::
+
+    ## consumer
+    # receive tensors and use them
+    event.set()
+
+3. Don't pass received tensors.
+
+::
+
+    # not going to work
+    x = queue.get()
+    queue_2.put(x)
+
+
+::
+
+    # you need to create a process-local copy
+    x = queue.get()
+    x_clone = x.clone()
+    queue_2.put(x_clone)
+
+
+::
+
+    # putting and getting from the same queue in the same process will likely end up with segfault
+    queue.put(tensor)
+    x = queue.get()
+
 
 Sharing strategies
 ------------------
index 76b3563..ec989c8 100644 (file)
@@ -12,7 +12,7 @@ import torch.multiprocessing as mp
 import torch.utils.hooks
 from torch.nn import Parameter
 from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN,
-                          load_tests)
+                          load_tests, slowTest)
 from multiprocessing.reduction import ForkingPickler
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -56,6 +56,30 @@ def send_tensor(queue, event, tp):
     event.wait()
 
 
+def send_and_delete_tensors(queue, event, tp, count, size=5):
+    for i in range(count):
+        t = torch.full([size], i).type(tp)
+        queue.put(t)
+        del t
+    event.wait()
+
+
+def receive_and_send_sum(queue, out_queue, event, tp, count, size=5):
+    s = torch.full([size], 0).type(tp)
+    for i in range(count):
+        t = queue.get()
+        s += t
+    out_queue.put(s)
+    event.wait()
+
+
+def receive_and_send(queue, out_queue, event, count):
+    for i in range(count):
+        t = queue.get()
+        out_queue.put(t.clone())
+    event.wait()
+
+
 def call_backward():
     x = torch.randn(3, 3, requires_grad=True)
     x.sum().backward()
@@ -150,6 +174,8 @@ class leak_checker(object):
         return self
 
     def __exit__(self, *args):
+        if torch.cuda.is_available():
+            torch.cuda.ipc_collect()
         if args[0] is None:
             # Check that the 10th available file-descriptor at the end of the
             # test is no more than 4 higher than the 10th available at the
@@ -193,6 +219,11 @@ class leak_checker(object):
 
 class TestMultiprocessing(TestCase):
 
+    def tearDown(self):
+        # This will keep tests isolated from each-other
+        if torch.cuda.is_available():
+            torch.cuda.ipc_collect()
+
     def _test_sharing(self, ctx=mp, type=torch.FloatTensor, repeat=1):
         def test_fill():
             x = torch.zeros(5, 5).type(type)
@@ -222,6 +253,9 @@ class TestMultiprocessing(TestCase):
             t2 = q.get()
             self.assertTrue(t1.eq(1).all())
             self.assertTrue(id(t1.storage()) == id(t2.storage()))
+            # We need to delete this tensors to allow producer (child process)
+            # collect them properly
+            del t1, t2
             e.set()
             p.join(1)
             self.assertFalse(p.is_alive())
@@ -322,13 +356,58 @@ class TestMultiprocessing(TestCase):
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
     @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
-    def test_cuda(self):
+    def test_cuda_simple(self):
         torch.cuda.FloatTensor([1])  # initialize CUDA outside of leak checker
         self._test_sharing(mp.get_context('spawn'), torch.cuda.FloatTensor)
 
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
     @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_cuda_memory_allocation(self):
+        ctx = mp.get_context('spawn')
+        q = ctx.Queue()
+        e = ctx.Event()
+        p = ctx.Process(target=send_and_delete_tensors, args=(q, e, torch.cuda.IntTensor, 5))
+        p.start()
+        t = []
+        for _ in range(5):
+            t.append(q.get())
+        self.assertEqual(t[0], torch.full([5], 0))
+        del t
+        e.set()
+        p.join(1)
+
+    @slowTest
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_cuda_send_many(self, name=None, size=5, count=100000):
+        ctx = mp.get_context('spawn')
+        q1 = ctx.Queue()
+        q2 = ctx.Queue()
+        q3 = ctx.Queue()
+        e1 = ctx.Event()
+        e2 = ctx.Event()
+        e3 = ctx.Event()
+        p1 = ctx.Process(target=send_and_delete_tensors, args=(q1, e1, torch.cuda.LongTensor, count, size))
+        p2 = ctx.Process(target=receive_and_send, args=(q1, q2, e2, count))
+        p3 = ctx.Process(target=receive_and_send_sum, args=(q2, q3, e3, torch.cuda.LongTensor, count, size))
+        p1.start()
+        p2.start()
+        p3.start()
+        result = q3.get()
+        self.assertEqual(result[0], int(count * (count - 1) / 2))
+        del result
+        e1.set()
+        e2.set()
+        e3.set()
+        p1.join(1)
+        p2.join(1)
+        p3.join(1)
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
     @unittest.skipIf(not TEST_MULTIGPU, 'found only 1 GPU')
     def test_cuda_small_tensors(self):
         # Check multiple small tensors which will likely use the same
@@ -355,6 +434,7 @@ class TestMultiprocessing(TestCase):
             self.assertEqual(v, torch.arange(i * 5., (i + 1) * 5).sum())
             self.assertEqual(device, i % 2)
             self.assertEqual(tensor_size, 5)
+
             # You might think this should be the case, but it's not!  After
             # data from the CUDA caching allocator goes through IPC, the
             # size of the storage is the size of the *cached cudaMalloc for
@@ -363,6 +443,15 @@ class TestMultiprocessing(TestCase):
             #
             # self.assertEqual(storage_size, 5)
 
+        # Collect current process (producer) files, make sure nothing holds
+        # ref to the sent tensors
+        del _tensor
+        del tensors
+
+        # We need to collect, as CUDA MP implementation holds one shared
+        # memory 'file' for performance reason
+        torch.cuda.ipc_collect()
+
     @unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)')
     @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
     def test_cuda_bad_call(self):
index bee3ef5..71f9908 100644 (file)
@@ -489,6 +489,7 @@ if (BUILD_PYTHON)
   endif()
 
   set(TORCH_PYTHON_SRCS
+    ${TORCH_SRC_DIR}/csrc/CudaIPCTypes.cpp
     ${TORCH_SRC_DIR}/csrc/DataLoader.cpp
     ${TORCH_SRC_DIR}/csrc/Device.cpp
     ${TORCH_SRC_DIR}/csrc/Dtype.cpp
diff --git a/torch/csrc/CudaIPCTypes.cpp b/torch/csrc/CudaIPCTypes.cpp
new file mode 100644 (file)
index 0000000..c31ff1f
--- /dev/null
@@ -0,0 +1,240 @@
+#ifdef USE_CUDA
+#include <torch/csrc/CudaIPCTypes.h>
+#include <TH/THAllocator.h>
+#include <map>
+#include <mutex>
+#include <random>
+
+#ifdef _MSC_VER
+#include <windows.h>
+#else
+#include <sys/types.h>
+#include <unistd.h>
+#endif
+
+namespace torch {
+
+namespace {
+
+void warnProducerTerminatedBeforeSharedTensorsReleased() {
+  static bool warned = false;
+  if (!warned) {
+    LOG(WARNING)
+        << "Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]";
+    warned = true;
+  }
+}
+
+struct CudaIPCGlobalEntities {
+  std::mutex ref_counters_mutex_;
+  std::atomic<int64_t> sync_events_used_;
+  std::map<std::string, std::shared_ptr<CudaIPCRefCountersFile>>
+      ref_counters_files_;
+  std::shared_ptr<CudaIPCRefCountersFile> next_available_ref_counters_file_;
+  CudaIPCSentDataLimbo CudaIPCSentDataLimbo_;
+  CudaIPCGlobalEntities() : ref_counters_files_() {}
+  ~CudaIPCGlobalEntities() {
+    CudaIPCSentDataLimbo_.collect();
+    safe_clean_current_file();
+    if (next_available_ref_counters_file_) {
+      warnProducerTerminatedBeforeSharedTensorsReleased();
+    }
+  }
+  void safe_clean_current_file() {
+    std::lock_guard<std::mutex> lock(ref_counters_mutex_);
+    if (next_available_ref_counters_file_ &&
+        next_available_ref_counters_file_->offsets_in_use() == 0) {
+      ref_counters_files_.erase(next_available_ref_counters_file_->handle());
+      next_available_ref_counters_file_.reset();
+    }
+  }
+};
+
+CudaIPCGlobalEntities cuda_ipc_global_entities;
+
+CudaIPCSentDataLimbo::~CudaIPCSentDataLimbo() {
+  collect();
+  if (shared_blocks_.size() > 0) {
+    warnProducerTerminatedBeforeSharedTensorsReleased();
+  }
+}
+
+bool CudaIPCSentDataLimbo::collect() {
+  bool freed_memory = false;
+  std::lock_guard<std::mutex> lock(limbo_mutex_);
+  std::vector<std::unique_ptr<CudaIPCSentData>> kept_blocks;
+  for (auto& sd : shared_blocks_) {
+    if (sd->counter_value() > 0) {
+      kept_blocks.push_back(std::move(sd));
+    } else {
+      freed_memory = true;
+      sd.reset();
+    }
+  }
+  shared_blocks_ = std::move(kept_blocks);
+  return freed_memory;
+}
+
+void CudaIPCSentDataLimbo::add(std::unique_ptr<CudaIPCSentData> shared_block) {
+  std::lock_guard<std::mutex> lock(limbo_mutex_);
+  static bool warned = false;
+  if (shared_blocks_.size() > CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO &&
+      !warned) {
+    LOG(WARNING)
+        << "Producer process tried to deallocate over "
+        << CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO
+        << " memory blocks referred by consumer processes. Deallocation might be significantly slowed down. "
+        << "We assume it will never going to be the case, but if it is, please file but to https://github.com/pytorch/pytorch";
+    warned = true;
+  }
+  shared_blocks_.push_back(std::move(shared_block));
+}
+
+void CudaIPCSentDataDelete(void* ptr) {
+  std::unique_ptr<CudaIPCSentData> sent_data(
+      static_cast<CudaIPCSentData*>(ptr));
+  if (sent_data->counter_value() > 0) {
+    cuda_ipc_global_entities.CudaIPCSentDataLimbo_.add(std::move(sent_data));
+  }
+  cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
+}
+
+void ReturnRefCounter(const std::string& handle, uint64_t offset /* unused */) {
+  std::lock_guard<std::mutex> lock(
+      cuda_ipc_global_entities.ref_counters_mutex_);
+  cuda_ipc_global_entities.ref_counters_files_[handle]->return_offset(offset);
+  if (cuda_ipc_global_entities.ref_counters_files_[handle]->offsets_in_use() ==
+          0 &&
+      !cuda_ipc_global_entities.ref_counters_files_[handle]->have_offsets()) {
+    cuda_ipc_global_entities.ref_counters_files_.erase(handle);
+  }
+}
+
+} // namespace
+
+CudaIPCSentData::CudaIPCSentData(
+    std::string handle,
+    int64_t offset,
+    int64_t* counter_ptr,
+    at::Device device)
+    : handle_(handle),
+      offset_(offset),
+      counter_ptr_(counter_ptr),
+      original_ptr_(),
+      device_(device) {
+#ifndef __HIP_PLATFORM_HCC__
+  // CUDA have the unofficial limit on the number of recorded blocking interprocess
+  // events, to prevent using of all events, we are switching to StreamSync
+  // before limit reached.
+  //
+  //  ```python
+  //  import torch
+  //  a = [ torch.cuda.Event(
+  //      enable_timing=False, blocking=True, interprocess=True) for i in range(30000) ]
+  //  [i.record() for i in a]
+  //  ```
+  //
+  if (cuda_ipc_global_entities.sync_events_used_.load() < CUDA_IPC_MAXIMUM_EVENTS_TO_USE) {
+    // TODO: More efficient would be to create event inside of main thread (at
+    // the moment of the queue.put). The reason this is more efficient is
+    // because the main thread may have queued extra work on the stream, which
+    // this event will consequently wait for (uselessly).
+    cuda_ipc_global_entities.sync_events_used_ ++;
+    C10_CUDA_CHECK(cudaEventCreateWithFlags(
+        &event_,
+        cudaEventDisableTiming | cudaEventInterprocess |
+            cudaEventBlockingSync));
+    C10_CUDA_CHECK(cudaEventRecord(
+        event_, c10::cuda::getCurrentCUDAStream(device.index())));
+    event_sync_required_ = true;
+  } else {
+    auto stream = c10::cuda::getCurrentCUDAStream(device.index());
+    C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+    event_sync_required_ = false;
+  }
+#else
+  // cuIpcGetEventHandle with HIP is not supported, so we have to sync
+  // stream instead of passing event
+  auto stream = c10::cuda::getCurrentCUDAStream(device.index());
+  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+  event_sync_required_ = false;
+#endif
+}
+
+CudaIPCSentData::~CudaIPCSentData() {
+  ReturnRefCounter(handle_, offset_);
+#ifndef __HIP_PLATFORM_HCC__
+  try {
+    if (event_sync_required_) {
+      at::cuda::CUDAGuard device_guard(device_.index());
+      cudaEventDestroy(event_);
+      cuda_ipc_global_entities.sync_events_used_ --;
+    }
+  } catch (...) { /* No throw */
+  }
+#endif
+}
+
+int64_t CudaIPCSentData::counter_value() {
+  return *counter_ptr_;
+}
+
+at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) {
+  {
+    std::lock_guard<std::mutex> lock(
+        cuda_ipc_global_entities.ref_counters_mutex_);
+    if (!cuda_ipc_global_entities.next_available_ref_counters_file_) {
+      static std::random_device rd;
+      std::string ref_counter_handle = "/torch_";
+#ifdef _MSC_VER
+      ref_counter_handle += std::to_string(GetCurrentProcessId());
+#else
+      ref_counter_handle += std::to_string(getpid());
+#endif
+      ref_counter_handle += "_";
+      ref_counter_handle += std::to_string(rd());
+
+      int flags = TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_EXCLUSIVE;
+      at::DataPtr sptr = THRefcountedMapAllocator::makeDataPtr(
+          ref_counter_handle.c_str(),
+          flags,
+          sizeof(int64_t) * CUDA_IPC_REF_COUNTER_FILE_SIZE,
+          nullptr);
+      auto rc = std::make_shared<CudaIPCRefCountersFile>(
+          ref_counter_handle, CUDA_IPC_REF_COUNTER_FILE_SIZE, std::move(sptr));
+      cuda_ipc_global_entities.ref_counters_files_[ref_counter_handle] = rc;
+      cuda_ipc_global_entities.next_available_ref_counters_file_ = rc;
+    }
+  }
+  cuda_ipc_global_entities.next_available_ref_counters_file_->set_counter(1);
+  auto sent_data = new CudaIPCSentData(
+      cuda_ipc_global_entities.next_available_ref_counters_file_->handle(),
+      cuda_ipc_global_entities.next_available_ref_counters_file_->get_offset(),
+      cuda_ipc_global_entities.next_available_ref_counters_file_->counter_ptr(),
+      device);
+
+  cuda_ipc_global_entities.next_available_ref_counters_file_->rotate_offset();
+  if (!cuda_ipc_global_entities.next_available_ref_counters_file_
+           ->have_offsets()) {
+    cuda_ipc_global_entities.next_available_ref_counters_file_.reset();
+  }
+  return at::DataPtr(data, sent_data, CudaIPCSentDataDelete, device);
+}
+
+bool CudaIPCCollect() {
+  bool freed_memory = cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
+  if (cuda_ipc_global_entities.CudaIPCSentDataLimbo_.size() == 0) {
+    cuda_ipc_global_entities.safe_clean_current_file();
+  }
+  return freed_memory;
+}
+
+} // namespace torch
+
+namespace c10 {
+namespace {
+REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
+}
+} // namespace c10
+
+#endif
diff --git a/torch/csrc/CudaIPCTypes.h b/torch/csrc/CudaIPCTypes.h
new file mode 100644 (file)
index 0000000..a9d5efd
--- /dev/null
@@ -0,0 +1,146 @@
+#pragma once
+#ifdef USE_CUDA
+#include <c10/core/Allocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#include <c10/util/Logging.h>
+#include <cuda_runtime_api.h>
+#include <cstddef>
+
+namespace torch {
+
+bool CudaIPCCollect();
+
+struct CudaIPCReceivedData final {
+  explicit CudaIPCReceivedData(std::shared_ptr<void> shared_ptr)
+      : shared_ptr_(std::move(shared_ptr)) {}
+  std::shared_ptr<void> shared_ptr_;
+};
+
+struct CudaIPCSentData final {
+  std::string handle_;
+  int64_t offset_;
+  int64_t* counter_ptr_; // Reference counter shared memory block
+  at::DataPtr original_ptr_; // Original mem allocation
+  cudaEvent_t event_; // Sync cuEventDestroy
+  bool event_sync_required_;
+  at::Device device_;
+
+  CudaIPCSentData(
+      std::string handle,
+      int64_t offset,
+      int64_t* counter_ptr,
+      at::Device device);
+  ~CudaIPCSentData();
+
+  int64_t counter_value();
+  std::string handle() {
+    return handle_;
+  }
+  int64_t offset() {
+    return offset_;
+  }
+  void set_original_ptr(at::DataPtr data_ptr) {
+    original_ptr_ = std::move(data_ptr);
+  }
+};
+
+at::DataPtr GetNewRefCountedSentData(void* data, at::Device device);
+
+namespace {
+
+constexpr int64_t CUDA_IPC_REF_COUNTER_FILE_SIZE = 10000;
+constexpr int64_t CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO = 1000;
+// This was determined empirically that CUDA (v10.1 and below) have the limit
+// on the number of recorded blocking interprocess events. It is around ~22,000.
+// And to give us leeway, we picked 1000 as it gives us enough events to share
+// tensors effectively.
+constexpr int64_t CUDA_IPC_MAXIMUM_EVENTS_TO_USE = 1000;
+
+// All to be deleted data blocks with non zero reference counter goes there
+struct CudaIPCSentDataLimbo final {
+  ~CudaIPCSentDataLimbo();
+  bool collect();
+  void add(std::unique_ptr<CudaIPCSentData> shared_block);
+  uint64_t size() {
+    return shared_blocks_.size();
+  }
+
+ private:
+  // TODO: Can be changed to FIFO in order to avoid full traverse on every
+  // collect()
+  std::vector<std::unique_ptr<CudaIPCSentData>> shared_blocks_;
+  std::mutex limbo_mutex_;
+};
+
+struct CudaIPCRefCountersFile final {
+  CudaIPCRefCountersFile(
+      std::string handle,
+      uint64_t size,
+      at::DataPtr data_ptr)
+      : next_offset_(0),
+        size_(size),
+        used_slots_(0),
+        handle_(handle),
+        refcounted_shared_mem_(std::move(data_ptr)) {}
+
+  int64_t* counter_ptr() {
+    return static_cast<int64_t*>(refcounted_shared_mem_.get()) + next_offset_;
+  }
+
+  void set_counter(uint64_t value) {
+    *counter_ptr() = value;
+  }
+
+  bool have_offsets() {
+    return next_offset_ < size_;
+  }
+
+  bool offsets_in_use() {
+    return used_slots_;
+  }
+
+  int64_t get_offset() {
+    return next_offset_;
+  }
+
+  void rotate_offset() {
+    next_offset_++;
+    used_slots_++;
+  }
+
+  void return_offset(uint64_t offset /* unused */) {
+    used_slots_--;
+  }
+
+  std::string handle() {
+    return handle_;
+  }
+
+ private:
+  uint64_t next_offset_;
+  uint64_t size_;
+  uint64_t used_slots_;
+  std::string handle_;
+  at::DataPtr refcounted_shared_mem_;
+};
+
+} // namespace
+} // namespace torch
+
+namespace c10 {
+namespace {
+class CudaIPCCollectCallback : public FreeMemoryCallback {
+ public:
+  ~CudaIPCCollectCallback() {};
+  bool Execute() override {
+    return torch::CudaIPCCollect();
+  }
+};
+} // namespace
+
+} // namespace c10
+
+#endif
index 1d3d4c5..b150dad 100644 (file)
@@ -16,6 +16,7 @@
 #include <torch/csrc/THP.h>
 #include <torch/csrc/copy_utils.h>
 #include <torch/csrc/DynamicTypes.h>
+#include <torch/csrc/CudaIPCTypes.h>
 
 #include <torch/csrc/generic/Storage.cpp>
 #include <TH/THGenerateAllTypes.h>
index a6b3e03..1aa4e65 100644 (file)
@@ -13,7 +13,7 @@
 #endif
 
 #include <torch/csrc/cuda/THCP.h>
-
+#include <torch/csrc/CudaIPCTypes.h>
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/autograd/generated/VariableType.h>
 #include <torch/csrc/utils/python_strings.h>
@@ -217,6 +217,14 @@ PyObject * THCPModule_cudaSynchronize(PyObject *_unused)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_cudaIPCCollect(PyObject *_unused /* unused */)
+{
+  HANDLE_TH_ERRORS
+  torch::CudaIPCCollect();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles)
 {
   HANDLE_TH_ERRORS
@@ -453,6 +461,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
   {"_cuda_initialSeed", (PyCFunction)THCPModule_initialSeed,      METH_NOARGS,  nullptr},
   {"_cuda_cudaHostAllocator", (PyCFunction)THCPModule_cudaHostAllocator, METH_NOARGS, nullptr},
   {"_cuda_synchronize", (PyCFunction)THCPModule_cudaSynchronize, METH_NOARGS, nullptr},
+  {"_cuda_ipc_collect", (PyCFunction)THCPModule_cudaIPCCollect, METH_NOARGS, nullptr},
   {"_cuda_sleep", (PyCFunction)THCPModule_cudaSleep, METH_O, nullptr},
   {"_cuda_lock_mutex",   (PyCFunction)THCPModule_cudaLockMutex,   METH_NOARGS,  nullptr},
   {"_cuda_unlock_mutex", (PyCFunction)THCPModule_cudaUnlockMutex, METH_NOARGS,  nullptr},
index 6a103a7..05c8645 100644 (file)
@@ -12,6 +12,7 @@
 #include <torch/csrc/cuda/override_macros.h>
 #include <torch/csrc/copy_utils.h>
 #include <torch/csrc/DynamicTypes.h>
+#include <torch/csrc/CudaIPCTypes.h>
 
 #define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
 #include <THC/THCGenerateAllTypes.h>
index d6187e8..efe210c 100644 (file)
@@ -216,13 +216,26 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
 {
   HANDLE_TH_ERRORS
   THWStorage *storage = self->cdata;
+
+  if (storage->received_cuda()) {
+    AT_ERROR(
+        "Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending.");
+  }
+
   at::DeviceGuard device_guard(storage->device());
-  THPObjectPtr tuple(PyTuple_New(4));
+  THPObjectPtr tuple(PyTuple_New(8));
   THPObjectPtr device(PyLong_FromLong(storage->device().index()));
   THPObjectPtr _handle(Py_None);
   Py_INCREF(Py_None);
   THPObjectPtr size_bytes(PyLong_FromLong(storage->numel() * sizeof(scalar_t)));
   THPObjectPtr _offset_bytes(PyLong_FromLong(0));
+  THPObjectPtr _ref_counter(Py_None);
+  Py_INCREF(Py_None);
+  THPObjectPtr _ref_counter_offset(PyLong_FromLong(0));
+  THPObjectPtr _event_handle(Py_None);
+  Py_INCREF(Py_None);
+  THPObjectPtr _event_sync_required(Py_None);
+  Py_INCREF(Py_None);
   if (THWStorage_(data)(LIBRARY_STATE storage)) {
     size_t base_size;
     void *base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
@@ -233,9 +246,33 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
 
     _handle = PyBytes_FromStringAndSize((char *)&handle, CUDA_IPC_HANDLE_SIZE);
     _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)offset_bytes);
+
+    // Put Storage Data behind new ref counting context
+    // See Note [CUDA IPC Refcounting implementation explained]
+    at::DataPtr sent_data_ptr = torch::GetNewRefCountedSentData(storage->data(), storage->device());
+    auto old_data_ptr = storage->set_data_ptr(std::move(sent_data_ptr));
+    auto sent_data  =  static_cast<torch::CudaIPCSentData*>(storage->data_ptr().get_context());
+    sent_data->set_original_ptr(std::move(old_data_ptr));
+    _ref_counter = PyBytes_FromString((sent_data->handle()).c_str());
+    _ref_counter_offset = PyLong_FromLong(sent_data->offset());
+
+
+    cudaIpcEventHandle_t ipc_event_handle;
+
+#ifndef __HIP_PLATFORM_HCC__
+    if (sent_data->event_sync_required_) {
+      THCudaCheck(cudaIpcGetEventHandle(&ipc_event_handle, sent_data->event_));
+    }
+#else
+    // ipc_event_handle unused in storage receiver, we can leave it uninitialized.
+#endif
+
+    _event_handle = PyBytes_FromStringAndSize((char *)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE);
+    _event_sync_required = PyBool_FromLong(sent_data->event_sync_required_);
+
   }
 
-  if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes) {
+  if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes || !_event_handle) {
     return nullptr;
   }
   PyTuple_SET_ITEM(tuple.get(), 0, device.release());
@@ -248,40 +285,111 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
   //     as key in shared_cache(multiprocessing/reduction.py).
   //     Offset in numel cannot uniquely represent a storage.
   PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release());
+  PyTuple_SET_ITEM(tuple.get(), 4, _ref_counter.release());
+  PyTuple_SET_ITEM(tuple.get(), 5, _ref_counter_offset.release());
+  PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release());
+  PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release());
   return tuple.release();
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * THPStorage_(releaseIPCCounter)(PyObject *_unused, PyObject *args)
+{
+  HANDLE_TH_ERRORS
+  THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
+  PyObject *_ref_counter = PyTuple_GET_ITEM(args, 0);
+  PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 1);
+  if (!(PyBytes_Check(_ref_counter) &&
+        THPUtils_checkLong(_ref_counter_offset))) {
+    THPUtils_invalidArguments(
+        args,
+        nullptr,
+        "_release_ipc_counter in CUDA mode",
+        1,
+        "(bytes _ref_counter, int _ref_counter_offset)");
+    return nullptr;
+  }
+  std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
+  ptrdiff_t ref_counter_offset =
+      (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
+  // We don't want to break existing code, so resource deletion is best
+  // effort basis. Exception expected if producer process terminated
+  // before consumer released data.
+  int flags =
+      TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_NOCREATE;
+  try {
+    auto sptr = THRefcountedMapAllocator::makeDataPtr(
+        ref_counter_handle.c_str(),
+        flags,
+        sizeof(int64_t) * torch::CUDA_IPC_REF_COUNTER_FILE_SIZE,
+        nullptr);
+    *(static_cast<int64_t*>(sptr.get()) + ref_counter_offset) -= 1;
+  } catch (c10::Error) {
+    // Already warned inside of producer process
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static std::string THPStorage_(bytesAsHandleString)(PyObject *handle) {
+  char* buffer;
+  Py_ssize_t handle_size;
+  if (PyBytes_AsStringAndSize(handle, &buffer, &handle_size) == -1) {
+    return nullptr;
+  }
+  THPUtils_assert(
+      handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
+  return std::string(buffer, handle_size);
+}
+
 static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
 {
   HANDLE_TH_ERRORS
-  THPUtils_assert(PyTuple_GET_SIZE(args) == 4, "tuple of 4 items expected");
+  THPUtils_assert(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected");
   PyObject *_device = PyTuple_GET_ITEM(args, 0);
   PyObject *_handle = PyTuple_GET_ITEM(args, 1);
   PyObject *_size_bytes = PyTuple_GET_ITEM(args, 2);
   PyObject *_offset_bytes = PyTuple_GET_ITEM(args, 3);
-  if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes)
-      && (_handle != Py_None && PyBytes_Check(_handle))
-      && THPUtils_checkLong(_offset_bytes))) {
-    THPUtils_invalidArguments(args, nullptr, "_new_shared in CUDA mode", 1,
-        "(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes)");
+  PyObject *_ref_counter = PyTuple_GET_ITEM(args, 4);
+  PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 5);
+  PyObject *_event_handle = PyTuple_GET_ITEM(args, 6);
+  PyObject *_event_sync_required = PyTuple_GET_ITEM(args, 7);
+  if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes) &&
+        PyBytes_Check(_handle) && PyBytes_Check(_ref_counter) &&
+        PyBytes_Check(_event_handle) && THPUtils_checkLong(_offset_bytes) &&
+        THPUtils_checkLong(_ref_counter_offset) && PyBool_Check(_event_sync_required))) {
+    THPUtils_invalidArguments(
+        args,
+        nullptr,
+        "_new_shared in CUDA mode",
+        1,
+        "(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes, bytes _ref_counter, int _ref_counter_offset, bytes event_handle, bool event_sync_required)");
     return nullptr;
   }
 
-  // Storage constructor requires size in numel.
   size_t storage_size = (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(scalar_t);
   ptrdiff_t storage_offset_bytes = (ptrdiff_t)THPUtils_unpackLong(_offset_bytes);
 
   int64_t device = THPUtils_unpackLong(_device);
   at::cuda::CUDAGuard device_guard(device);
 
-  char *buffer;
-  Py_ssize_t handle_size;
-  if (PyBytes_AsStringAndSize(_handle, &buffer, &handle_size) == -1) {
-    return nullptr;
+#ifndef __HIP_PLATFORM_HCC__
+  if (PyObject_IsTrue(_event_sync_required)) {
+    // Ensure that producer prepared all tensor's data
+    std::string s_ipc_event_handle =
+        THPStorage_(bytesAsHandleString)(_event_handle);
+    auto ipc_event_handle = reinterpret_cast<const cudaIpcEventHandle_t*>(
+        s_ipc_event_handle.c_str());
+    cudaEvent_t event;
+    cudaIpcOpenEventHandle(&event, *ipc_event_handle);
+    AT_CUDA_CHECK(
+        cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0));
   }
-  THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
-  std::string s_handle = std::string(buffer, handle_size);
+#else
+  // Already synchronized inside producer stream
+#endif
+
+  std::string s_handle = THPStorage_(bytesAsHandleString)(_handle);
   std::shared_ptr<void> basePtr = c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle);
 
   // Offset the basePtr to reconstruct the real storage
@@ -289,11 +397,50 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
   void* devPtr = basePtr.get();
   devPtr = (char*)devPtr + storage_offset_bytes;
 
+  std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
+  ptrdiff_t ref_counter_offset = (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
+
+  auto c = new torch::CudaIPCReceivedData(std::move(basePtr));
+  auto sp = std::shared_ptr<void>(
+      (void*)c, [ref_counter_handle, ref_counter_offset, device](void* ptr) {
+        delete static_cast<torch::CudaIPCReceivedData*>(ptr);
+        // Sync default stream to make sure all operations related to the storage is
+        // finished (otherwise another process may reuse memory and corrupt
+        // data)
+
+        // Ideally all shared memory reference counting could be replaced by
+        // sending untriggered CUDA event from the producer to consumer and
+        // using this event as the criteria of memory release. However, CUDA (atm 10.1)
+        // does not support the creation of untriggered events and performance
+        // impact of having thousands of shared events is unknown.
+
+        // TODO: Instead of cudaStreamSynchronize it is possible to add Stream
+        // Callback and release counter inside of it (need to check performance impact)
+        cudaStreamSynchronize(c10::cuda::getCurrentCUDAStream(device));
+
+        // We don't want to break existing code, so resource deletion is best
+        // effort basis. Exception expected if producer process terminated
+        // before consumer released data.
+        int flags =
+            TH_ALLOCATOR_MAPPED_SHAREDMEM | TH_ALLOCATOR_MAPPED_NOCREATE;
+        try {
+          auto sptr = THRefcountedMapAllocator::makeDataPtr(
+              ref_counter_handle.c_str(),
+              flags,
+              sizeof(int64_t) * torch::CUDA_IPC_REF_COUNTER_FILE_SIZE,
+              nullptr);
+          *(static_cast<int64_t*>(sptr.get()) + ref_counter_offset) -= 1;
+        } catch (c10::Error) {
+          // Already warned inside of producer process
+        }
+      });
+
   THWStoragePtr base(THWStorage_(newWithDataAndAllocator)(
       LIBRARY_STATE
-      THCIpcDeleter::makeDataPtr(std::move(basePtr), devPtr),
+      THCIpcDeleter::makeDataPtr(std::move(sp), devPtr),
       storage_size, /* allocator */ nullptr));
   base->set_resizable(false);
+  base->set_received_cuda(true);
 
   return THPStorage_(New)(base.release());
   END_HANDLE_TH_ERRORS
@@ -382,6 +529,7 @@ static PyMethodDef THPStorage_(sharingMethods)[] = {
 #ifdef THC_GENERIC_FILE
   {"_share_cuda_", (PyCFunction)THPStorage_(shareCuda), METH_NOARGS, nullptr},
   {"_new_shared_cuda", (PyCFunction)THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr},
+  {"_release_ipc_counter", (PyCFunction)THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr},
 #else
   {"_share_fd_", (PyCFunction)THPStorage_(shareFd), METH_NOARGS, nullptr},
   {"_new_shared_fd", (PyCFunction)THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr},
index 08d8565..da7149e 100644 (file)
@@ -358,6 +358,19 @@ def synchronize():
     return torch._C._cuda_synchronize()
 
 
+def ipc_collect():
+    r"""Force collects GPU memory after it has been released by CUDA IPC.
+
+    .. note::
+        Checks if any sent CUDA tensors could be cleaned from the memory. Force
+        closes shared memory file used for reference counting if there is no
+        active counters. Useful when the producer process stopped actively sending
+        tensors and want to release unused memory.
+    """
+    _lazy_init()
+    return torch._C._cuda_ipc_collect()
+
+
 def current_stream(device=None):
     r"""Returns the currently selected :class:`Stream` for a given device.
 
diff --git a/torch/multiprocessing/cuda_multiprocessing.md b/torch/multiprocessing/cuda_multiprocessing.md
new file mode 100644 (file)
index 0000000..f6b3283
--- /dev/null
@@ -0,0 +1,32 @@
+# CUDA IPC Refcounting implementation explained
+
+Since shared CUDA memory belongs to the producer process, we need to take special precautions to make sure that it is stays allocated for entire shared tensor life-span.
+
+It could be done manually by syncing on an event:
+
+```python
+# Producer
+queue.put(tensor)
+event.wait()
+
+# Consumer
+tensor = queue.get()
+safe_to_use_tensor = tensor.clone()
+event.set()
+```
+
+However, this requires blocking producer process (and gets overcomplicated in case of multiple consumers and handling various race-conditions).
+
+Instead, we implement cross-process reference counting for shared CUDA (and HIP) tensors, which will take care of keeping producers memory allocated for entire tensor's life-span.
+
+Details of implementation follow.
+
+At the moment of sending tensor, we are wrapping DataPtr of the tensor with additional structure CudaIPCSentData. It still points to the same memory, but have other behavior on destruction.
+
+Instead of simply removing the allocated block, it checks if there are any active references to this block (references are stored in shared memory files described by CudaIPCRefCountersFile structure). If such exists, instead of deleting blocks DataPtr it is moved to the global state CudaIPCSentDataLimbo.
+
+Each individual CudaIPCRefCountersFile contains multiple reference counters for multiple tensors. Current implementation sequentially provides next available reference counter by increasing offset.
+
+CudaIPCSentDataLimbo is keeping references to data blocks which are not in use by producer process (i.e., tensor when out of scope), but still in use (or will be in use) by a consumer. It also tries to reduce the number of stored blocks by scanning the limbo list for blocks whose ref count has gone to zero on various events such as CudaCaching allocator haven't found any suitable block for the next allocation, the attempt of any shared block deletion, explicit call of cuda_ipc_collect.
+
+Consumer's side wraps received data into the different structure CudaIPCReceivedData. On destruction, it takes care of decreasing reference count to the received tensor.
index 79c01ba..1bf2268 100644 (file)
@@ -87,7 +87,7 @@ def rebuild_tensor(cls, storage, metadata):
 
 def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
                         storage_cls, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
-                        requires_grad):
+                        requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required):
     # If storage_handle is None, storage points to nullptr.
     if storage_handle is None or storage_size_bytes == 0:
         storage = storage_cls(0)
@@ -99,8 +99,15 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
                 storage_device,
                 storage_handle,
                 storage_size_bytes,
-                storage_offset_bytes)
+                storage_offset_bytes,
+                ref_counter_handle,
+                ref_counter_offset,
+                event_handle,
+                event_sync_required)
             shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
+        else:
+            # We already ref counting this Storage, but producer needs new ref-counters to be released.
+            storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
 
     t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
     if tensor_cls == torch.nn.parameter.Parameter:
@@ -211,11 +218,16 @@ def reduce_tensor(tensor):
     # thing.
     #
     if storage.is_cuda:
-        (device, handle, storage_size_bytes, storage_offset_bytes) = storage._share_cuda_()
+        (device,
+         handle,
+         storage_size_bytes,
+         storage_offset_bytes,
+         ref_counter_handle,
+         ref_counter_offset,
+         event_handle,
+         event_sync_required) = storage._share_cuda_()
         tensor_offset = tensor.storage_offset()
-
         shared_cache[handle] = StorageWeakRef(storage)
-
         # _backward_hooks purposely omitted here, see
         # Note [Don't serialize hooks]
         return (rebuild_cuda_tensor,
@@ -228,7 +240,11 @@ def reduce_tensor(tensor):
                  handle,  # identifier which CUDA allocation is the storage in.
                  storage_size_bytes,  # size(in bytes) of the storage
                  storage_offset_bytes,  # offset(in bytes) of the storage in the CUDA allocation
-                 tensor.requires_grad))
+                 tensor.requires_grad,
+                 ref_counter_handle,
+                 ref_counter_offset,
+                 event_handle,
+                 event_sync_required))
 
     # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
     metadata = (tensor.storage_offset(), tensor.size(), tensor.stride(), tensor.requires_grad)