Fix cuda multiprocessing cached memory (#14736)
authorAiling Zhang <ailzhang@fb.com>
Wed, 5 Dec 2018 18:52:39 +0000 (10:52 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 18:55:43 +0000 (10:55 -0800)
Summary:
This PR fixes #11422

In the old world of CUDA IPC, when we want to share a tensor T from A to B, we have to share the whole CUDA mem allocation where T's storage sit in. And we casted it to the same type of storage of T's.

This causes problem when two different types of storage got allocated to the same CUDA mem block. When we try to reconstruct the second tensor, it will complain about wrong storage type.

In this PR we reconstruct the storage only (not the entire mem block). However, CUDA only allows one open memHandle once per process, we have to save the device pointer in a global cache so that we can reconstruct tensors as they come.

Thanks a ton to ezyang who helped design the solution and debugged the issue!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14736

Differential Revision: D13335899

Pulled By: ailzhang

fbshipit-source-id: cad69db392ed6f8fdc2b93a9dc2899f6d378c371

aten/src/THC/THCAllocator.cpp
aten/src/THC/THCAllocator.h
aten/src/THC/THCCachingAllocator.cpp
aten/src/THC/THCCachingAllocator.h
test/test_multiprocessing.py
torch/csrc/generic/StorageSharing.cpp
torch/multiprocessing/reductions.py

index 78b650a..29a25cd 100644 (file)
@@ -1,24 +1,30 @@
 #include "THCAllocator.h"
 
-THCIpcDeleter::~THCIpcDeleter() {
-  int prev_device;
-  THCudaCheck(cudaGetDevice(&prev_device));
-  THCudaCheck(cudaSetDevice(device_));
-  THCudaCheck(cudaIpcCloseMemHandle(data_));
-  THCudaCheck(cudaSetDevice(prev_device));
-}
+THCIpcDeleter::~THCIpcDeleter() {}
 
 void deleteTHCIpcDeleter(void* ptr) {
   delete static_cast<THCIpcDeleter*>(ptr);
 }
 
-at::DataPtr THCIpcDeleter::makeDataPtr(void* data, int device) {
+// Refer to Note [CUDA IPC and the caching allocator] for more details
+// basePtr - device ptr of a single cudaMalloc allocation; this may be a large
+//           block of memory which is managed by the caching allocator.
+// data    - ptr to where the storage (of a single type) should start.
+// Invariant: data must lie within the CUDA memory allocation represented by
+//   basePtr.
+// Here basePtr should be saved in the struct, while data should be used to
+// construct the new storage.
+// Every time a storage referring to the IPC memory region goes out of scope,
+// the reference count on the memory region will be decreased by one, until
+// it's zero, at which point IPC memory region is closed (by calling
+// cudaIpcCloseMemHandle).
+at::DataPtr THCIpcDeleter::makeDataPtr(std::shared_ptr<void> basePtr, void* data) {
   // The dynamic allocation here is a bit unfortunate
   int cur_device;
   THCudaCheck(cudaGetDevice(&cur_device));
-  auto* context = new THCIpcDeleter(data, device);
+  auto* context = new THCIpcDeleter(std::move(basePtr));
   return {data, context, &deleteTHCIpcDeleter, at::Device(at::DeviceType::CUDA, cur_device)};
 }
 
-THCIpcDeleter::THCIpcDeleter(void* data, int device)
-    : data_(data), device_(device) {}
+THCIpcDeleter::THCIpcDeleter(std::shared_ptr<void> basePtr)
+    : basePtr_(std::move(basePtr)) {}
index 5ff8de1..177d3ba 100644 (file)
@@ -8,12 +8,11 @@
 #ifdef __cplusplus
 class CAFFE2_API THCIpcDeleter {
  public:
-  THCIpcDeleter(void* data, int device);
+  THCIpcDeleter(std::shared_ptr<void> basePtr);
   ~THCIpcDeleter();
-  static at::DataPtr makeDataPtr(void* data, int device);
+  static at::DataPtr makeDataPtr(std::shared_ptr<void> basePtr, void* data);
 private:
-  void* data_;
-  int device_;
+  std::shared_ptr<void> basePtr_;
 };
 #endif
 
index 802d7c0..ccaf1da 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <ATen/Context.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAGuard.h>
 #include <ATen/cuda/Exceptions.h>
 
 #include <cuda_runtime_api.h>
@@ -587,3 +588,58 @@ THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device) {
   assertValidDevice(device);
   return caching_allocator.get_stats_for_device(device).max_amount_cached;
 }
+
+//
+// In CUDA IPC, sender sends a tensor to receiver, THCCaching_CUDAIpcDevptr
+// is called by the receiving process to map the CUDA memory from the sending
+// process into its own address space.
+//
+// CUDA IPC only allows sharing a big memory block associated with a cudaIpcMemHandle_t
+// and it can be opened only **once** per context per process. There can be
+// multiple types of storage in the same IPC mem block, so we must cache the
+// device ptr to construct typed storage as it comes.
+//
+// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in the process
+// that can be used to access the memory block in the sender process.
+// It only saves a weak_ptr of the device pointer in the map, the shared_ptr
+// will be used to reconstruct all storages in this CudaMalloc allocation.
+// And it will deleted in cudaIpcCloseMemHandle when its reference count is 0.
+//
+namespace {
+  std::mutex IpcMutex;
+  std::unordered_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr;
+}
+
+AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle) {
+  std::lock_guard<std::mutex> lock(IpcMutex);
+
+  auto iter = ipcMemHandle_to_devptr.find(handle);
+  if (iter != ipcMemHandle_to_devptr.end()) {
+    auto devptr = iter->second.lock();
+    if (devptr) return devptr;
+  }
+  // This ipcMemHandle hasn't been opened, or already expired, open it to
+  // enable IPC access to that mem block.
+  void *dev = nullptr;
+  auto ipc_handle = reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str());
+  THCudaCheck(cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
+  // devPtr has to be deleted in same device when created.
+  int curr_device;
+  THCudaCheck(cudaGetDevice(&curr_device));
+  auto sp = std::shared_ptr<void>(
+      dev,
+      [handle, curr_device](void *ptr) {
+        at::cuda::CUDAGuard device_guard(curr_device);
+        std::lock_guard<std::mutex> lock(IpcMutex);
+        THCudaCheck(cudaIpcCloseMemHandle(ptr));
+        ipcMemHandle_to_devptr.erase(handle);});
+  std::weak_ptr<void> wp = sp;
+  // To eliminate an additional search, we can use insert().
+  // It doesn't overwrite when key already exists(ptr expired).
+  // But in the deleter for sp we erased the entry,
+  // this should be safe to do now.
+  ipcMemHandle_to_devptr.insert(iter, {handle, wp});
+
+  return sp;
+}
+
index 188f4f2..41991df 100644 (file)
@@ -27,4 +27,5 @@ THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device);
 THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
 #endif
 
+AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle);
 #endif
index b2b9cb5..c509d37 100644 (file)
@@ -118,6 +118,17 @@ def autograd_sharing(queue, ready, master_modified, device, is_parameter):
     queue.put(is_ok)
 
 
+def mixed_type_producer(queue, event):
+    for _ in range(10):
+        float_tensor = torch.ones(2, 2).float().cuda()
+        byte_tensor = torch.zeros(2, 2).byte().cuda()
+
+        queue.put(float_tensor)
+        queue.put(byte_tensor)
+        event.wait()
+        event.clear()
+
+
 @contextlib.contextmanager
 def fs_sharing():
     prev_strategy = mp.get_sharing_strategy()
@@ -441,6 +452,29 @@ class TestMultiprocessing(TestCase):
         p.join(1)
         self.assertFalse(p.is_alive())
 
+    # Check sharing a cudaMalloc allocation with different types of storage.
+    # (Issue #11422)
+    def _test_mixed_types_cuda_sharing(self, ctx=mp):
+        all_ones = torch.ones(2, 2).float()
+        all_zeros = torch.zeros(2, 2).byte()
+        queue = ctx.Queue()
+        event = ctx.Event()
+
+        p = ctx.Process(target=mixed_type_producer, args=(queue, event))
+
+        p.start()
+
+        for _ in range(10):
+            float_tensor = queue.get()
+            byte_tensor = queue.get()
+            self.assertEqual(float_tensor, all_ones)
+            self.assertEqual(byte_tensor, all_zeros)
+            del float_tensor, byte_tensor
+            event.set()
+
+        time.sleep(5)
+        p.join()
+
     def test_variable_sharing(self):
         for requires_grad in [True, False]:
             var = torch.arange(1., 26).view(5, 5).requires_grad_(requires_grad)
@@ -483,6 +517,12 @@ class TestMultiprocessing(TestCase):
             var = torch.arange(1., 26, device='cuda').view(5, 5).requires_grad_(requires_grad)
             self._test_autograd_sharing(var, mp.get_context('spawn'))
 
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_mixed_types_cuda_sharing(self):
+        self._test_mixed_types_cuda_sharing(mp.get_context('spawn'))
+
     def test_parameter_sharing(self):
         param = Parameter(torch.arange(1., 26).view(5, 5))
         self._test_autograd_sharing(param, is_parameter=True)
index 723f4fd..0b73880 100644 (file)
@@ -221,27 +221,33 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
   THPObjectPtr device(PyLong_FromLong(storage->device().index()));
   THPObjectPtr _handle(Py_None);
   Py_INCREF(Py_None);
-  THPObjectPtr size(PyLong_FromLong(storage->numel()));
-  THPObjectPtr _offset(PyLong_FromLong(0));
+  THPObjectPtr size_bytes(PyLong_FromLong(storage->numel() * sizeof(scalar_t)));
+  THPObjectPtr _offset_bytes(PyLong_FromLong(0));
   if (THWStorage_(data)(LIBRARY_STATE storage)) {
     size_t base_size;
     void *base_ptr = THCCachingAllocator_getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
-    ptrdiff_t offset = (char*)storage->data<scalar_t>() - (char*)base_ptr;
+    ptrdiff_t offset_bytes = (char*)storage->data<scalar_t>() - (char*)base_ptr;
 
     cudaIpcMemHandle_t handle;
     THCudaCheck(cudaIpcGetMemHandle(&handle, base_ptr));
 
     _handle = PyBytes_FromStringAndSize((char *)&handle, CUDA_IPC_HANDLE_SIZE);
-    _offset = PyLong_FromSsize_t((Py_ssize_t)offset / sizeof(scalar_t));
-    size = PyLong_FromSize_t(base_size / sizeof(scalar_t));
+    _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)offset_bytes);
   }
-  if (!tuple || !device || !_handle || !size || !_offset) {
+
+  if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes) {
     return nullptr;
   }
   PyTuple_SET_ITEM(tuple.get(), 0, device.release());
+  // cudaIpcMemHandle_t(of basePtr)
   PyTuple_SET_ITEM(tuple.get(), 1, _handle.release());
-  PyTuple_SET_ITEM(tuple.get(), 2, size.release());
-  PyTuple_SET_ITEM(tuple.get(), 3, _offset.release());
+  // Size(in bytes) of the real storage, note this is not the size of basePtr memory block.
+  PyTuple_SET_ITEM(tuple.get(), 2, size_bytes.release());
+  // Offset(in bytes) of the real storage in the basePtr memory block.
+  // NB: this offset MUST be in bytes instead of numel, since we use (storage_handle, offset)
+  //     as key in shared_cache(multiprocessing/reduction.py).
+  //     Offset in numel cannot uniquely represent a storage.
+  PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release());
   return tuple.release();
   END_HANDLE_TH_ERRORS
 }
@@ -249,18 +255,22 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
 static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
 {
   HANDLE_TH_ERRORS
-  THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected");
+  THPUtils_assert(PyTuple_GET_SIZE(args) == 4, "tuple of 4 items expected");
   PyObject *_device = PyTuple_GET_ITEM(args, 0);
   PyObject *_handle = PyTuple_GET_ITEM(args, 1);
-  PyObject *_size = PyTuple_GET_ITEM(args, 2);
-  if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size)
-      && (_handle == Py_None || PyBytes_Check(_handle)))) {
+  PyObject *_size_bytes = PyTuple_GET_ITEM(args, 2);
+  PyObject *_offset_bytes = PyTuple_GET_ITEM(args, 3);
+  if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes)
+      && (_handle != Py_None && PyBytes_Check(_handle))
+      && THPUtils_checkLong(_offset_bytes))) {
     THPUtils_invalidArguments(args, nullptr, "_new_shared in CUDA mode", 1,
-        "(int device, bytes handle, int storage_size)");
+        "(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes)");
     return nullptr;
   }
 
-  size_t storage_size = (size_t)THPUtils_unpackLong(_size);
+  // Storage constructor requires size in numel.
+  size_t storage_size = (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(scalar_t);
+  ptrdiff_t storage_offset_bytes = (ptrdiff_t)THPUtils_unpackLong(_offset_bytes);
 
   int64_t device = THPUtils_unpackLong(_device);
   at::cuda::CUDAGuard device_guard(device);
@@ -271,14 +281,17 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
     return nullptr;
   }
   THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
-  cudaIpcMemHandle_t handle = *(cudaIpcMemHandle_t*)buffer;
+  std::string s_handle = std::string(buffer, handle_size);
+  std::shared_ptr<void> basePtr = THCCaching_CUDAIpcDevptr(s_handle);
 
-  void *devPtr = nullptr;
-  THCudaCheck(cudaIpcOpenMemHandle(&devPtr, handle, cudaIpcMemLazyEnablePeerAccess));
+  // Offset the basePtr to reconstruct the real storage
+  // devPtr = basePtr + storage_offset
+  void* devPtr = basePtr.get();
+  devPtr = (char*)devPtr + storage_offset_bytes;
 
   THWStoragePtr base(THWStorage_(newWithDataAndAllocator)(
       LIBRARY_STATE
-      THCIpcDeleter::makeDataPtr(devPtr, device),
+      THCIpcDeleter::makeDataPtr(std::move(basePtr), devPtr),
       storage_size, /* allocator */ nullptr));
   base->set_resizable(false);
 
index 9ce9b7c..b55e4ae 100644 (file)
@@ -85,13 +85,21 @@ def rebuild_tensor(cls, storage, metadata):
 
 
 def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
-                        storage_cls, storage_device, storage_handle, storage_size, requires_grad):
-
-    storage = storage_from_cache(storage_cls, storage_handle)
-    if storage is None:
-        torch.cuda._lazy_init()
-        storage = storage_cls._new_shared_cuda(storage_device, storage_handle, storage_size)
-        shared_cache[storage_handle] = StorageWeakRef(storage)
+                        storage_cls, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
+                        requires_grad):
+    # If storage_handle is None, storage points to nullptr.
+    if storage_handle is None or storage_size_bytes == 0:
+        storage = storage_cls(0)
+    else:
+        storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
+        if storage is None:
+            torch.cuda._lazy_init()
+            storage = storage_cls._new_shared_cuda(
+                storage_device,
+                storage_handle,
+                storage_size_bytes,
+                storage_offset_bytes)
+            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
 
     t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
     if tensor_cls == torch.nn.parameter.Parameter:
@@ -125,24 +133,67 @@ def reduce_tensor(tensor):
     # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
     # the storage 0xA100 (because that is what CUDA supports).  So, on the
     # other end, there simply isn't any way to say, "Wait, you gave me
-    # a bigger region (0xA000) than the one I wanted (0xA100)"; we have
-    # to just make a storage for the entire caching allocator block.
+    # a bigger region (0xA000) than the one I wanted (0xA100)".
+    #
+    # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
+    # one storage itself? No, because this cudaMalloc allocation might contain
+    # storages of mixed types: float, bytes, double... If you make the entire
+    # allocation a single storage of a type A, we'll hit an error when constructing
+    # a tensor of type B on the storage.
     #
-    # This is fine, because all we need to do is just adjust the offset
-    # on the tensor itself: instead of:
+    # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
+    # receiver side. However, cudaIpcMemHandles from each device in a given process may
+    # only be opened by one context per device per other process.
+    # If we open and close a memory handle multiples times in a process, CUDA is allowed
+    # to give it a different address; similarly, once we close the memory, we're not
+    # allowed to access it(and the storage/tensor built on top of it), even if it is
+    # still live in the original process. As we cannot make a cudaMalloc allocation
+    # to a single storage in one go, this requires us to cache the device pointer for
+    # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keep
+    # the old ones alives.
+    # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
     #
-    #   Tensor(size=0x100, offset=0x020, storage=Storage(data=0xA100, size=0x0100))
+    # This is fine, because all we need to do is to save our position in the allocaiton,
+    # and reconstruct storage and tensor from it.
+    # 0xA000 ->  -------CUDA Allocation------
+    #           |                            |
+    #           |                            |
+    #           |                            |
+    #           |                            |
+    # 0xA100 ->  --------storage1 begin------
+    #           |                            |
+    # 0xA120 ->  --------tensor1 begin ------
+    #           |                            |
+    #           |                            |
+    #           |                            |
+    #           |                            |
+    #           |                            |
+    # 0xA160 ->  --------tensor1 end---------
+    #           |                            |
+    #           |                            |
+    #           |                            |
+    # 0xA200 ->  --------storage1 end--------
+    #           |                            |
+    # 0xE000 ->  --------CUDA allocation-----
     #
-    # we have
+    # To send tensor1, the following info are required from sender to receiver for
+    # storage recontruction.
+    #   1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process).
+    #      basePtr may not be exactly 0xA000 since it's a different process.
+    #   2. offset(0xA100) of storage1 in the CUDA allocation.
+    #   3. size of storage1(0x100).
     #
-    #   Tensor(size=0x100, offset=0x120, storage=Storage(data=0xA000, size=0x4000))
+    # On receiver side:
+    #   1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
+    #      of the same type using (basePtr, offset, size).
+    #   2. we can reconstruct the tensor on top of the recontructed storage
+    #   Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
     #
     # This strategy has a few implications:
     #
-    # 1. When we serialize a CUDA tensor for IPC, we have to do it all in one
-    #    go (non-compositionally), instead of first serializing storage, and
-    #    then serializing tensor.  This is because the base address of the
-    #    storage allocation affects what offset we write into the tensor.
+    # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
+    #    go (non-compositionally), and this requires to have a global map
+    #    memHandle -> devPtr for each process.
     #
     # 2. We MUST NOT let the new IPC tensor be resizable.  Originally, a resize
     #    of the storage beyond 0x100 would merely have caused us to do a
@@ -159,7 +210,7 @@ def reduce_tensor(tensor):
     # thing.
     #
     if storage.is_cuda:
-        (device, handle, storage_size, storage_offset) = storage._share_cuda_()
+        (device, handle, storage_size_bytes, storage_offset_bytes) = storage._share_cuda_()
         tensor_offset = tensor.storage_offset()
 
         shared_cache[handle] = StorageWeakRef(storage)
@@ -170,11 +221,12 @@ def reduce_tensor(tensor):
                 (type(tensor),
                  tensor.size(),
                  tensor.stride(),
-                 tensor_offset + storage_offset,
+                 tensor_offset,  # tensor offset in its storage
                  type(storage),
                  device,
-                 handle,
-                 storage_size,
+                 handle,  # identifier which CUDA allocation is the storage in.
+                 storage_size_bytes,  # size(in bytes) of the storage
+                 storage_offset_bytes,  # offset(in bytes) of the storage in the CUDA allocation
                  tensor.requires_grad))
 
     # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]