#include "THCAllocator.h"
-THCIpcDeleter::~THCIpcDeleter() {
- int prev_device;
- THCudaCheck(cudaGetDevice(&prev_device));
- THCudaCheck(cudaSetDevice(device_));
- THCudaCheck(cudaIpcCloseMemHandle(data_));
- THCudaCheck(cudaSetDevice(prev_device));
-}
+THCIpcDeleter::~THCIpcDeleter() {}
void deleteTHCIpcDeleter(void* ptr) {
delete static_cast<THCIpcDeleter*>(ptr);
}
-at::DataPtr THCIpcDeleter::makeDataPtr(void* data, int device) {
+// Refer to Note [CUDA IPC and the caching allocator] for more details.
+// basePtr - device ptr of a single cudaMalloc allocation; this may be a large
+// block of memory which is managed by the caching allocator.
+// data - ptr to where the storage (of a single type) should start.
+// Invariant: data must lie within the CUDA memory allocation represented by
+// basePtr.
+// Here basePtr should be saved in the struct, while data should be used to
+// construct the new storage.
+// Every time a storage referring to the IPC memory region goes out of scope,
+// the reference count on the memory region is decreased by one; once it
+// reaches zero, the IPC memory region is closed (by calling
+// cudaIpcCloseMemHandle).
+at::DataPtr THCIpcDeleter::makeDataPtr(std::shared_ptr<void> basePtr, void* data) {
// The dynamic allocation here is a bit unfortunate
int cur_device;
THCudaCheck(cudaGetDevice(&cur_device));
- auto* context = new THCIpcDeleter(data, device);
+ auto* context = new THCIpcDeleter(std::move(basePtr));
return {data, context, &deleteTHCIpcDeleter, at::Device(at::DeviceType::CUDA, cur_device)};
}
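+
+// A minimal usage sketch (hypothetical, for illustration only; `handle` is an
+// assumed std::string containing a cudaIpcMemHandle_t): two storages carved
+// out of one IPC allocation each hold a reference to basePtr, so
+// cudaIpcCloseMemHandle runs only after the last one is destroyed.
+//
+//   std::shared_ptr<void> base = THCCaching_CUDAIpcDevptr(handle);
+//   at::DataPtr s1 = THCIpcDeleter::makeDataPtr(base, (char*)base.get() + 0x100);
+//   at::DataPtr s2 = THCIpcDeleter::makeDataPtr(base, (char*)base.get() + 0x200);
+//   base.reset();  // s1 and s2 still keep the mapping alive
+//   // destroying s1 and s2 drops the refcount to zero and closes the handle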
-THCIpcDeleter::THCIpcDeleter(void* data, int device)
- : data_(data), device_(device) {}
+THCIpcDeleter::THCIpcDeleter(std::shared_ptr<void> basePtr)
+ : basePtr_(std::move(basePtr)) {}
#ifdef __cplusplus
class CAFFE2_API THCIpcDeleter {
public:
- THCIpcDeleter(void* data, int device);
+ THCIpcDeleter(std::shared_ptr<void> basePtr);
~THCIpcDeleter();
- static at::DataPtr makeDataPtr(void* data, int device);
+ static at::DataPtr makeDataPtr(std::shared_ptr<void> basePtr, void* data);
private:
- void* data_;
- int device_;
+ std::shared_ptr<void> basePtr_;
};
#endif
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).max_amount_cached;
}
+
+//
+// In CUDA IPC, when a sender sends a tensor to a receiver,
+// THCCaching_CUDAIpcDevptr is called by the receiving process to map the CUDA
+// memory from the sending process into its own address space.
+//
+// CUDA IPC only allows sharing the big memory block associated with a
+// cudaIpcMemHandle_t, and that handle can be opened only **once** per context
+// per process. There can be multiple types of storage in the same IPC mem
+// block, so we must cache the device ptr to construct typed storages as they
+// come.
+//
+// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in this
+// process that can be used to access the memory block in the sender process.
+// The map only holds a weak_ptr to the device pointer; the shared_ptr is used
+// to reconstruct all storages in this cudaMalloc allocation, and the memory is
+// closed (via cudaIpcCloseMemHandle) when its reference count reaches zero.
+//
+namespace {
+ std::mutex IpcMutex;
+ std::unordered_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr;
+}
+
+AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle) {
+ std::lock_guard<std::mutex> lock(IpcMutex);
+
+ auto iter = ipcMemHandle_to_devptr.find(handle);
+ if (iter != ipcMemHandle_to_devptr.end()) {
+ auto devptr = iter->second.lock();
+ if (devptr) return devptr;
+ }
+  // This ipcMemHandle hasn't been opened, or has already expired; open it to
+  // enable IPC access to that mem block.
+ void *dev = nullptr;
+ auto ipc_handle = reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str());
+ THCudaCheck(cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
+  // devPtr has to be deleted on the same device on which it was created.
+ int curr_device;
+ THCudaCheck(cudaGetDevice(&curr_device));
+ auto sp = std::shared_ptr<void>(
+ dev,
+ [handle, curr_device](void *ptr) {
+ at::cuda::CUDAGuard device_guard(curr_device);
+ std::lock_guard<std::mutex> lock(IpcMutex);
+ THCudaCheck(cudaIpcCloseMemHandle(ptr));
+        ipcMemHandle_to_devptr.erase(handle);
+      });
+ std::weak_ptr<void> wp = sp;
+  // To eliminate an additional search, we can use insert().
+  // insert() doesn't overwrite the value when the key already exists (i.e.
+  // when its cached ptr has expired), but since the deleter of sp erases the
+  // entry from the map, inserting here is safe.
+ ipcMemHandle_to_devptr.insert(iter, {handle, wp});
+
+ return sp;
+}
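+
+// For illustration (hypothetical usage): while a mapping is alive, repeated
+// lookups of the same handle are cache hits that return the same device
+// pointer; the handle is only reopened after all shared_ptrs are released.
+//
+//   auto p1 = THCCaching_CUDAIpcDevptr(handle);  // opens the handle
+//   auto p2 = THCCaching_CUDAIpcDevptr(handle);  // cache hit: p1.get() == p2.get()
+//   p1.reset();  // mapping stays open, p2 still holds it
+//   p2.reset();  // refcount reaches zero; cudaIpcCloseMemHandle is called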
+
THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
#endif
+AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle);
#endif
queue.put(is_ok)
+def mixed_type_producer(queue, event):
+ for _ in range(10):
+ float_tensor = torch.ones(2, 2).float().cuda()
+ byte_tensor = torch.zeros(2, 2).byte().cuda()
+
+ queue.put(float_tensor)
+ queue.put(byte_tensor)
+ event.wait()
+ event.clear()
+
+
@contextlib.contextmanager
def fs_sharing():
prev_strategy = mp.get_sharing_strategy()
p.join(1)
self.assertFalse(p.is_alive())
+ # Check sharing a cudaMalloc allocation with different types of storage.
+ # (Issue #11422)
+ def _test_mixed_types_cuda_sharing(self, ctx=mp):
+ all_ones = torch.ones(2, 2).float()
+ all_zeros = torch.zeros(2, 2).byte()
+ queue = ctx.Queue()
+ event = ctx.Event()
+
+ p = ctx.Process(target=mixed_type_producer, args=(queue, event))
+
+ p.start()
+
+ for _ in range(10):
+ float_tensor = queue.get()
+ byte_tensor = queue.get()
+ self.assertEqual(float_tensor, all_ones)
+ self.assertEqual(byte_tensor, all_zeros)
+ del float_tensor, byte_tensor
+ event.set()
+
+ time.sleep(5)
+ p.join()
+
def test_variable_sharing(self):
for requires_grad in [True, False]:
var = torch.arange(1., 26).view(5, 5).requires_grad_(requires_grad)
var = torch.arange(1., 26, device='cuda').view(5, 5).requires_grad_(requires_grad)
self._test_autograd_sharing(var, mp.get_context('spawn'))
+ @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+ don't support multiprocessing with spawn start method")
+ @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+ def test_mixed_types_cuda_sharing(self):
+ self._test_mixed_types_cuda_sharing(mp.get_context('spawn'))
+
def test_parameter_sharing(self):
param = Parameter(torch.arange(1., 26).view(5, 5))
self._test_autograd_sharing(param, is_parameter=True)
THPObjectPtr device(PyLong_FromLong(storage->device().index()));
THPObjectPtr _handle(Py_None);
Py_INCREF(Py_None);
- THPObjectPtr size(PyLong_FromLong(storage->numel()));
- THPObjectPtr _offset(PyLong_FromLong(0));
+ THPObjectPtr size_bytes(PyLong_FromLong(storage->numel() * sizeof(scalar_t)));
+ THPObjectPtr _offset_bytes(PyLong_FromLong(0));
if (THWStorage_(data)(LIBRARY_STATE storage)) {
size_t base_size;
void *base_ptr = THCCachingAllocator_getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
- ptrdiff_t offset = (char*)storage->data<scalar_t>() - (char*)base_ptr;
+ ptrdiff_t offset_bytes = (char*)storage->data<scalar_t>() - (char*)base_ptr;
cudaIpcMemHandle_t handle;
THCudaCheck(cudaIpcGetMemHandle(&handle, base_ptr));
_handle = PyBytes_FromStringAndSize((char *)&handle, CUDA_IPC_HANDLE_SIZE);
- _offset = PyLong_FromSsize_t((Py_ssize_t)offset / sizeof(scalar_t));
- size = PyLong_FromSize_t(base_size / sizeof(scalar_t));
+ _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)offset_bytes);
}
- if (!tuple || !device || !_handle || !size || !_offset) {
+
+ if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), 0, device.release());
+  // cudaIpcMemHandle_t (of basePtr)
PyTuple_SET_ITEM(tuple.get(), 1, _handle.release());
- PyTuple_SET_ITEM(tuple.get(), 2, size.release());
- PyTuple_SET_ITEM(tuple.get(), 3, _offset.release());
+  // Size (in bytes) of the real storage; note this is not the size of the basePtr memory block.
+ PyTuple_SET_ITEM(tuple.get(), 2, size_bytes.release());
+  // Offset (in bytes) of the real storage in the basePtr memory block.
+  // NB: this offset MUST be in bytes instead of numel, since we use
+  // (storage_handle, offset) as the key in shared_cache
+  // (multiprocessing/reduction.py). An offset in numel cannot uniquely
+  // identify a storage.
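+  // For example (hypothetical layout): a float storage at byte offset 256 and
+  // a byte storage at byte offset 64 both map to numel offset 64, so the key
+  // (storage_handle, numel_offset) could not tell them apart.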
+ PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release());
return tuple.release();
END_HANDLE_TH_ERRORS
}
static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
- THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected");
+ THPUtils_assert(PyTuple_GET_SIZE(args) == 4, "tuple of 4 items expected");
PyObject *_device = PyTuple_GET_ITEM(args, 0);
PyObject *_handle = PyTuple_GET_ITEM(args, 1);
- PyObject *_size = PyTuple_GET_ITEM(args, 2);
- if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size)
- && (_handle == Py_None || PyBytes_Check(_handle)))) {
+ PyObject *_size_bytes = PyTuple_GET_ITEM(args, 2);
+ PyObject *_offset_bytes = PyTuple_GET_ITEM(args, 3);
+ if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes)
+ && (_handle != Py_None && PyBytes_Check(_handle))
+ && THPUtils_checkLong(_offset_bytes))) {
THPUtils_invalidArguments(args, nullptr, "_new_shared in CUDA mode", 1,
- "(int device, bytes handle, int storage_size)");
+ "(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes)");
return nullptr;
}
- size_t storage_size = (size_t)THPUtils_unpackLong(_size);
+ // Storage constructor requires size in numel.
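+  // (e.g. 512 bytes with scalar_t == float: 512 / 4 = 128 elements)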
+ size_t storage_size = (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(scalar_t);
+ ptrdiff_t storage_offset_bytes = (ptrdiff_t)THPUtils_unpackLong(_offset_bytes);
int64_t device = THPUtils_unpackLong(_device);
at::cuda::CUDAGuard device_guard(device);
return nullptr;
}
THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
- cudaIpcMemHandle_t handle = *(cudaIpcMemHandle_t*)buffer;
+ std::string s_handle = std::string(buffer, handle_size);
+ std::shared_ptr<void> basePtr = THCCaching_CUDAIpcDevptr(s_handle);
- void *devPtr = nullptr;
- THCudaCheck(cudaIpcOpenMemHandle(&devPtr, handle, cudaIpcMemLazyEnablePeerAccess));
+ // Offset the basePtr to reconstruct the real storage
+ // devPtr = basePtr + storage_offset
+ void* devPtr = basePtr.get();
+ devPtr = (char*)devPtr + storage_offset_bytes;
THWStoragePtr base(THWStorage_(newWithDataAndAllocator)(
LIBRARY_STATE
- THCIpcDeleter::makeDataPtr(devPtr, device),
+ THCIpcDeleter::makeDataPtr(std::move(basePtr), devPtr),
storage_size, /* allocator */ nullptr));
base->set_resizable(false);
def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
- storage_cls, storage_device, storage_handle, storage_size, requires_grad):
-
- storage = storage_from_cache(storage_cls, storage_handle)
- if storage is None:
- torch.cuda._lazy_init()
- storage = storage_cls._new_shared_cuda(storage_device, storage_handle, storage_size)
- shared_cache[storage_handle] = StorageWeakRef(storage)
+ storage_cls, storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
+ requires_grad):
+ # If storage_handle is None, storage points to nullptr.
+ if storage_handle is None or storage_size_bytes == 0:
+ storage = storage_cls(0)
+ else:
+ storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
+ if storage is None:
+ torch.cuda._lazy_init()
+ storage = storage_cls._new_shared_cuda(
+ storage_device,
+ storage_handle,
+ storage_size_bytes,
+ storage_offset_bytes)
+ shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
if tensor_cls == torch.nn.parameter.Parameter:
# *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
# the storage 0xA100 (because that is what CUDA supports). So, on the
# other end, there simply isn't any way to say, "Wait, you gave me
- # a bigger region (0xA000) than the one I wanted (0xA100)"; we have
- # to just make a storage for the entire caching allocator block.
+ # a bigger region (0xA000) than the one I wanted (0xA100)".
+ #
+ # OK, so if you sent the whole cudaMalloc allocation, can you just wrap that
+ # up as one storage itself? No, because this cudaMalloc allocation might
+ # contain storages of mixed types: float, byte, double... If you make the
+ # entire allocation a single storage of type A, you'll hit an error when
+ # constructing a tensor of type B on top of it.
#
- # This is fine, because all we need to do is just adjust the offset
- # on the tensor itself: instead of:
+ # cudaIpcMemHandle is an identifier used to access the sender's cudaMalloc
+ # allocation on the receiver side. However, cudaIpcMemHandles from each device
+ # in a given process may only be opened by one context per device per other
+ # process.
+ # If we open and close a memory handle multiple times in a process, CUDA is
+ # allowed to give it a different address; similarly, once we close the memory,
+ # we're not allowed to access it (or the storage/tensor built on top of it),
+ # even if it is still live in the original process. Since we cannot map the
+ # cudaMalloc allocation to a single storage in one go, this requires us to
+ # cache the device pointer for each cudaIpcMemHandle on the C++ side in order
+ # to reconstruct storages of different types, while keeping the existing ones
+ # alive.
+ # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
#
- # Tensor(size=0x100, offset=0x020, storage=Storage(data=0xA100, size=0x0100))
+ # This is fine, because all we need to do is save our position in the
+ # allocation, and reconstruct the storage and tensor from it.
+ # 0xA000 -> -------CUDA Allocation------
+ # | |
+ # | |
+ # | |
+ # | |
+ # 0xA100 -> --------storage1 begin------
+ # | |
+ # 0xA120 -> --------tensor1 begin ------
+ # | |
+ # | |
+ # | |
+ # | |
+ # | |
+ # 0xA160 -> --------tensor1 end---------
+ # | |
+ # | |
+ # | |
+ # 0xA200 -> --------storage1 end--------
+ # | |
+ # 0xE000 -> --------CUDA allocation end-
#
- # we have
+ # To send tensor1, the following info is required from sender to receiver for
+ # storage reconstruction.
+ # 1. cudaIpcMemHandle of 0xA000 (which can be mapped to a basePtr in the
+ #    receiver process). basePtr may not be exactly 0xA000 since it's a
+ #    different process.
+ # 2. offset (0x100) of storage1 within the CUDA allocation (0xA100 - 0xA000).
+ # 3. size of storage1 (0x100).
#
- # Tensor(size=0x100, offset=0x120, storage=Storage(data=0xA000, size=0x4000))
+ # On the receiver side:
+ # 1. Get the devPtr from the MemHandle to access the memory, and reconstruct
+ #    a storage of the same type using (basePtr, offset, size).
+ # 2. Reconstruct the tensor on top of the reconstructed storage:
+ #    Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0x100, size=0x0100))
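+ #
+ # A hypothetical sketch (names illustrative, not the actual API): two storages
+ # of different types reconstructed from the same allocation:
+ #
+ #   base = open_mem_handle_once(handle)   # maps 0xA000 into this process
+ #   float_storage = FloatStorage(data=base + 0x100, size=0x100)
+ #   byte_storage  = ByteStorage(data=base + 0x300, size=0x080)
+ #   # both keep base alive; the handle closes when the last storage is freed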
#
# This strategy has a few implications:
#
- # 1. When we serialize a CUDA tensor for IPC, we have to do it all in one
- # go (non-compositionally), instead of first serializing storage, and
- # then serializing tensor. This is because the base address of the
- # storage allocation affects what offset we write into the tensor.
+ # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
+ #    go (non-compositionally), and this requires us to maintain a global map
+ #    memHandle -> devPtr in each process.
#
# 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
# of the storage beyond 0x100 would merely have caused us to do a
# thing.
#
if storage.is_cuda:
- (device, handle, storage_size, storage_offset) = storage._share_cuda_()
+ (device, handle, storage_size_bytes, storage_offset_bytes) = storage._share_cuda_()
tensor_offset = tensor.storage_offset()
shared_cache[handle] = StorageWeakRef(storage)
(type(tensor),
tensor.size(),
tensor.stride(),
- tensor_offset + storage_offset,
+ tensor_offset, # tensor offset in its storage
type(storage),
device,
- handle,
- storage_size,
+ handle, # identifies which CUDA allocation the storage lives in
+ storage_size_bytes, # size (in bytes) of the storage
+ storage_offset_bytes, # offset (in bytes) of the storage in the CUDA allocation
tensor.requires_grad))
# _backward_hooks purposely omitted here, see Note [Don't serialize hooks]