--- /dev/null
+#pragma once
+
+#include <c10/core/Allocator.h>
+#include <c10/core/DeviceType.h>
+
+// Use of c10::hip namespace here makes hipification easier, because
+// I don't have to also fix namespaces. Sorry!
+namespace c10 { namespace hip {
+
+// Takes a valid HIPAllocator (of any sort) and turns it into
+// an allocator pretending to be a CUDA allocator. See
+// Note [Masquerading as CUDA]
+class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
+ Allocator* allocator_;
+public:
+ explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
+ : allocator_(allocator) {}
+ DataPtr allocate(size_t size) const override {
+ DataPtr r = allocator_->allocate(size);
+ r.unsafe_set_device(Device(DeviceType::CUDA, r.device().index()));
+ return r;
+ }
+ DeleterFnPtr raw_deleter() const override {
+ return allocator_->raw_deleter();
+ }
+};
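+
+// A minimal usage sketch (illustrative only; the names below are placeholders
+// and the 64-byte allocation is arbitrary):
+//
+//   Allocator* hip_allocator = HIPCachingAllocator::get();
+//   HIPAllocatorMasqueradingAsCUDA masq(hip_allocator);
+//   DataPtr p = masq.allocate(64);
+//   // p.device().type() now reports DeviceType::CUDA, even though the
+//   // memory actually lives on a HIP device.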
+
+}} // namespace c10::hip
--- /dev/null
+#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
+
+namespace c10 { namespace hip {
+namespace HIPCachingAllocatorMasqueradingAsCUDA {
+
+Allocator* get() {
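+  // Function-local static: constructed lazily (and thread-safely) on first
+  // use, wrapping the HIP caching allocator behind the CUDA-masquerading
+  // facade.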
+ static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
+ return &allocator;
+}
+
+} // namespace HIPCachingAllocatorMasqueradingAsCUDA
+}} // namespace c10::hip
--- /dev/null
+#pragma once
+
+#include <c10/hip/HIPCachingAllocator.h>
+#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
+
+namespace c10 { namespace hip {
+namespace HIPCachingAllocatorMasqueradingAsCUDA {
+
+Allocator* get();
+
+} // namespace HIPCachingAllocatorMasqueradingAsCUDA
+}} // namespace c10::hip
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {
+// Note [Masquerading as CUDA]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// HIPGuardImplMasqueradingAsCUDA is like a HIPGuardImpl, but
// it reports its DeviceType as CUDA (e.g., type() returns CUDA,
// getDevice() reports the current HIP device as a CUDA device.)
// Free the cached blocks in our caching allocator. This is
// needed here because the benchmarking above uses a huge amount of memory,
// e.g. a few GBs.
- at::cuda::THCCachingAllocator_emptyCache();
+ c10::cuda::CUDACachingAllocator::emptyCache();
}
*algoPerf = perfResults;
cache.insert(args.params, *algo);
wsscache.insert(args.params, perfResults.memory);
- cuda::THCCachingAllocator_emptyCache();
+ c10::hip::HIPCachingAllocator::emptyCache();
}
template<typename algo_t>
endforeach()
set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
- ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingAllocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THCGeneral.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cpp
THCScanUtils.cuh
THCSortUtils.cuh
THCAllocator.h
- THCCachingAllocator.h
THCCachingHostAllocator.h
THCDeviceUtils.cuh
THCDeviceTensor.cuh
THCThrustAllocator.cuh
THCTensorMode.cuh
THCTensorTopK.cuh
- THCCachingAllocator.h
# See Note [TH abstraction violation]
THCGenerator.hpp
THCTensor.hpp
#include <THC/THCGeneral.h>
#include <THC/THCAllocator.h>
#include <THC/THCBlas.h>
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
#include <THC/THCCachingHostAllocator.h>
#include <THC/THCSleep.h>
#include <THC/THCStorage.h>
+++ /dev/null
-#ifndef THC_DEVICE_ALLOCATOR_INC
-#define THC_DEVICE_ALLOCATOR_INC
-
-#include <c10/cuda/CUDAStream.h>
-#include <c10/core/Allocator.h>
-#include <ATen/cuda/ATenCUDAGeneral.h>
-
-#include <mutex>
-
-namespace at {
-namespace cuda {
-
-AT_CUDA_API Allocator* THCCachingAllocator_get(void);
-AT_CUDA_API void THCCachingAllocator_emptyCache(void);
-AT_CUDA_API void THCCachingAllocator_cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
-AT_CUDA_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size);
-AT_CUDA_API void THCCachingAllocator_recordStream(void *ptr, at::cuda::CUDAStream stream);
-AT_CUDA_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device);
-AT_CUDA_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device);
-AT_CUDA_API void THCCachingAllocator_resetMaxMemoryAllocated(int device);
-AT_CUDA_API uint64_t THCCachingAllocator_currentMemoryCached(int device);
-AT_CUDA_API uint64_t THCCachingAllocator_maxMemoryCached(int device);
-AT_CUDA_API void THCCachingAllocator_resetMaxMemoryCached(int device);
-
-AT_CUDA_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
-
-AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle);
-
-}} // namespace at::cuda
-
-#endif
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAContext.h>
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
#include <stdlib.h>
#include <stdint.h>
void THCudaInit(THCState* state)
{
if (!state->cudaDeviceAllocator) {
- state->cudaDeviceAllocator = at::cuda::THCCachingAllocator_get();
+ state->cudaDeviceAllocator = c10::cuda::CUDACachingAllocator::get();
}
if (!state->cudaHostAllocator) {
state->cudaHostAllocator = getTHCCachingHostAllocator();
}
free(state->resourcesPerDevice);
- if (state->cudaDeviceAllocator == at::cuda::THCCachingAllocator_get()) {
- at::cuda::THCCachingAllocator_emptyCache();
+ if (state->cudaDeviceAllocator == c10::cuda::CUDACachingAllocator::get()) {
+ c10::cuda::CUDACachingAllocator::emptyCache();
}
if (state->cudaHostAllocator == getTHCCachingHostAllocator()) {
THCCachingHostAllocator_emptyCache();
/* not always true - our optimistic guess here */
*largestBlock = *freeBytes;
- if (allocator == at::cuda::THCCachingAllocator_get()) {
- at::cuda::THCCachingAllocator_cacheInfo(device, &cachedBytes, largestBlock);
+ if (allocator == c10::cuda::CUDACachingAllocator::get()) {
+ c10::cuda::CUDACachingAllocator::cacheInfo(device, &cachedBytes, largestBlock);
}
  /* Adjust the resulting free bytes number. largestBlock unused for now */
Device device() const {
return device_;
}
+ // Unsafely mutates the device on a DataPtr. Under normal use,
+ // you should never actually need to call this function.
+ // We need this for the implementation of the hack detailed
+ // in Note [Masquerading as CUDA]
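+  // For instance, HIPAllocatorMasqueradingAsCUDA::allocate does (sketch;
+  // real_allocator stands in for the wrapped HIP allocator):
+  //   DataPtr r = real_allocator->allocate(size);
+  //   r.unsafe_set_device(Device(DeviceType::CUDA, r.device().index()));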
+ void unsafe_set_device(Device device) {
+ device_ = device;
+ }
};
// NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a
# and headers you add
set(C10_CUDA_SRCS
CUDAStream.cpp
+ CUDACachingAllocator.cpp
impl/CUDAGuardImpl.cpp
impl/CUDATest.cpp
)
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <unordered_set>
#include <vector>
-namespace at {
+namespace c10 {
namespace cuda {
+namespace CUDACachingAllocator {
+
//
// Yet another caching allocator for CUDA device allocations.
//
CudaCachingAllocator device_allocator;
-Allocator* THCCachingAllocator_get(void)
+Allocator* get(void)
{
return &device_allocator;
}
-void THCCachingAllocator_emptyCache(void) {
+void emptyCache(void) {
caching_allocator.emptyCache();
}
-void THCCachingAllocator_cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) {
+void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) {
caching_allocator.cacheInfo(dev_id, cachedAndFree, largestBlock);
}
-void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size)
+void* getBaseAllocation(void *ptr, size_t *size)
{
return caching_allocator.getBaseAllocation(ptr, size);
}
-void THCCachingAllocator_recordStream(void *ptr, cuda::CUDAStream stream)
+void recordStream(void *ptr, cuda::CUDAStream stream)
{
caching_allocator.recordStream(ptr, stream);
}
-std::mutex* THCCachingAllocator_getCudaFreeMutex()
+std::mutex* getFreeMutex()
{
return &caching_allocator.cuda_free_mutex;
}
AT_ASSERTM(0 <= device && device < device_count, "Invalid device argument.");
}
-uint64_t THCCachingAllocator_currentMemoryAllocated(int device)
+uint64_t currentMemoryAllocated(int device)
{
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).amount_allocated;
}
-uint64_t THCCachingAllocator_maxMemoryAllocated(int device) {
+uint64_t maxMemoryAllocated(int device) {
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).max_amount_allocated;
}
-void THCCachingAllocator_resetMaxMemoryAllocated(int device) {
+void resetMaxMemoryAllocated(int device) {
assertValidDevice(device);
DeviceStats& stats = caching_allocator.get_stats_for_device(device);
stats.max_amount_allocated = stats.amount_allocated;
}
-uint64_t THCCachingAllocator_currentMemoryCached(int device)
+uint64_t currentMemoryCached(int device)
{
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).amount_cached;
}
-uint64_t THCCachingAllocator_maxMemoryCached(int device) {
+uint64_t maxMemoryCached(int device) {
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).max_amount_cached;
}
-void THCCachingAllocator_resetMaxMemoryCached(int device) {
+void resetMaxMemoryCached(int device) {
assertValidDevice(device);
DeviceStats& stats = caching_allocator.get_stats_for_device(device);
stats.max_amount_cached = stats.amount_cached;
}
//
-// In CUDA IPC, sender sends a tensor to receiver, THCCaching_CUDAIpcDevptr
+// In CUDA IPC, the sender sends a tensor to the receiver; getIpcDevPtr
// is called by the receiving process to map the CUDA memory from the sending
// process into its own address space.
//
std::unordered_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr;
}
-AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle) {
+std::shared_ptr<void> getIpcDevPtr(std::string handle) {
std::lock_guard<std::mutex> lock(IpcMutex);
auto iter = ipcMemHandle_to_devptr.find(handle);
return sp;
}
-}} // namespace at::cuda
+} // namespace CUDACachingAllocator
+
+}} // namespace c10::cuda
--- /dev/null
+#ifndef THC_DEVICE_ALLOCATOR_INC
+#define THC_DEVICE_ALLOCATOR_INC
+
+#include <c10/cuda/CUDAStream.h>
+#include <c10/core/Allocator.h>
+#include <c10/cuda/CUDAMacros.h>
+
+#include <mutex>
+
+namespace c10 {
+namespace cuda {
+
+// TODO: Turn this into an honest-to-goodness class. I briefly attempted to do
+// this, but it was a bit irritating to figure out how to also correctly
+// apply the pimpl pattern so I didn't have to leak any internal implementation
+// details in the header (CUDACachingAllocator could be made a pimpl, but
+// you would also need to define an appropriate subclass of Allocator.
+// Not impossible, but it required a bit more surgery than I wanted to
+// do at the time.)
+//
+// Why is this using a namespace rather than the old-style THCCachingAllocator_
+// prefix? Mostly because it made the HIPify rules easier to write: _ is
+// not counted as a word boundary, so you would otherwise have to list each
+// of these functions separately.
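+//
+// Typical call sites after this change (a sketch; see the updated call
+// sites elsewhere in this diff for the real usages):
+//
+//   c10::Allocator* allocator = c10::cuda::CUDACachingAllocator::get();
+//   c10::cuda::CUDACachingAllocator::emptyCache();
+//   std::lock_guard<std::mutex> lock(
+//       *c10::cuda::CUDACachingAllocator::getFreeMutex());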
+
+namespace CUDACachingAllocator {
+
+C10_CUDA_API Allocator* get();
+C10_CUDA_API void emptyCache();
+C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
+C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size);
+C10_CUDA_API void recordStream(void *ptr, CUDAStream stream);
+C10_CUDA_API uint64_t currentMemoryAllocated(int device);
+C10_CUDA_API uint64_t maxMemoryAllocated(int device);
+C10_CUDA_API void resetMaxMemoryAllocated(int device);
+C10_CUDA_API uint64_t currentMemoryCached(int device);
+C10_CUDA_API uint64_t maxMemoryCached(int device);
+C10_CUDA_API void resetMaxMemoryCached(int device);
+
+C10_CUDA_API std::mutex* getFreeMutex();
+
+C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
+
+} // namespace CUDACachingAllocator
+
+}} // namespace c10::cuda
+
+#endif
("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)),
("OptionalCUDAStreamGuard", ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)),
+    # Only get needs to be transformed this way; all the other functions can go
+    # straight to the normal versions in hip::HIPCachingAllocator.
+ ("cuda::CUDACachingAllocator::get", ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH)),
+ ("CUDACachingAllocator::get", ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH)),
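+    # For example (illustrative), a PyTorch source line containing
+    #   c10::cuda::CUDACachingAllocator::get()
+    # hipifies to
+    #   c10::hip::HIPCachingAllocatorMasqueradingAsCUDA::get()
+    # whereas the other functions (e.g. emptyCache) are rewritten to
+    # c10::hip::HIPCachingAllocator by the C10 mappings below.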
+
# TODO: Undo this special-case; see the header for motivation behind this
# hack. It's VERY important this is only applied to PyTorch HIPify.
("c10/cuda/CUDAGuard.h", ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH)),
+ ("c10/cuda/CUDACachingAllocator.h", ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH)),
])
CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([
("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)),
("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)),
("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)),
+ ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)),
("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)),
("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)),
("c10/cuda/impl/cuda_cmake_macros.h", ("c10/hip/impl/hip_cmake_macros.h", API_C10)),
("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)),
("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)),
("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)),
+ ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", API_C10)),
+ ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)),
])
# NB: C10 mappings are more specific than Caffe2 mappings, so run them
return PyErr_Format(PyExc_TypeError, "expected Stream object");
}
void* data = self_.data_ptr();
- at::cuda::THCCachingAllocator_recordStream(data, at::cuda::CUDAStream::unpack(((THCPStream*)arg)->cdata));
+ c10::cuda::CUDACachingAllocator::recordStream(data, at::cuda::CUDAStream::unpack(((THCPStream*)arg)->cdata));
Py_RETURN_NONE;
#else
throw std::runtime_error("PyTorch compiled without CUDA support");
#include <TH/TH.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
#ifdef USE_NCCL
#include <nccl.h>
#endif
PyObject * THCPModule_cudaLockMutex(PyObject *module)
{
- auto mutex = at::cuda::THCCachingAllocator_getCudaFreeMutex();
+ auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex();
// This has to be a busy loop because we **absolutely need to** hold the GIL
// or it's a recipe for a deadlock otherwise (if we let other Python threads
// run while we have the cudaMutex, but not the GIL, they might try to e.g.
PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
{
- auto mutex = at::cuda::THCCachingAllocator_getCudaFreeMutex();
+ auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex();
PyGILState_Release(cudaMutexGILState);
mutex->unlock();
Py_RETURN_NONE;
PyObject * THCPModule_emptyCache(PyObject *_unused)
{
HANDLE_TH_ERRORS
- at::cuda::THCCachingAllocator_emptyCache();
+ c10::cuda::CUDACachingAllocator::emptyCache();
END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
}
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_allocated");
int device = (int) THPUtils_unpackLong(arg);
- auto memory_allocated = at::cuda::THCCachingAllocator_currentMemoryAllocated(device);
+ auto memory_allocated = c10::cuda::CUDACachingAllocator::currentMemoryAllocated(device);
return PyLong_FromUnsignedLongLong(memory_allocated);
END_HANDLE_TH_ERRORS
}
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to max_memory_allocated");
int device = (int) THPUtils_unpackLong(arg);
- auto max_memory_allocated = at::cuda::THCCachingAllocator_maxMemoryAllocated(device);
+ auto max_memory_allocated = c10::cuda::CUDACachingAllocator::maxMemoryAllocated(device);
return PyLong_FromUnsignedLongLong(max_memory_allocated);
END_HANDLE_TH_ERRORS
}
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_allocated");
int device = (int) THPUtils_unpackLong(arg);
- at::cuda::THCCachingAllocator_resetMaxMemoryAllocated(device);
+ c10::cuda::CUDACachingAllocator::resetMaxMemoryAllocated(device);
END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
}
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_cached");
int device = (int) THPUtils_unpackLong(arg);
- auto memory_cached = at::cuda::THCCachingAllocator_currentMemoryCached(device);
+ auto memory_cached = c10::cuda::CUDACachingAllocator::currentMemoryCached(device);
return PyLong_FromUnsignedLongLong(memory_cached);
END_HANDLE_TH_ERRORS
}
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to max_memory_cached");
int device = (int) THPUtils_unpackLong(arg);
- auto max_memory_cached = at::cuda::THCCachingAllocator_maxMemoryCached(device);
+ auto max_memory_cached = c10::cuda::CUDACachingAllocator::maxMemoryCached(device);
return PyLong_FromUnsignedLongLong(max_memory_cached);
END_HANDLE_TH_ERRORS
}
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_cached");
int device = (int) THPUtils_unpackLong(arg);
- at::cuda::THCCachingAllocator_resetMaxMemoryCached(device);
+ c10::cuda::CUDACachingAllocator::resetMaxMemoryCached(device);
END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
}
int64_t numel = tensors[0].numel();
std::lock_guard<std::mutex> free_mutex(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
const auto comms = user_comms.empty() ? _get_communicators(tensors)
: ArrayRef<ncclComm_t>(user_comms);
ncclDataType_t data_type = _get_data_type(inputs[0].type());
const auto count = inputs[0].numel();
- std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
auto comms_ref = user_comms.empty() ? _get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
ncclDataType_t data_type = _get_data_type(inputs[0].type());
int64_t count = inputs[0].numel();
- std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
auto comms = user_comms.empty() ? _get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
at::cuda::OptionalCUDAGuard device_guard;
ncclDataType_t data_type = _get_data_type(inputs[0].type());
int64_t count = inputs[0].numel();
- std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
auto comms = user_comms.empty() ? _get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
at::cuda::OptionalCUDAGuard device_guard;
ncclDataType_t data_type = _get_data_type(inputs[0].type());
int64_t count = inputs[0].numel() / len;
- std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
auto comms = user_comms.empty() ? _get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
at::cuda::OptionalCUDAGuard device_guard;
THPObjectPtr _offset_bytes(PyLong_FromLong(0));
if (THWStorage_(data)(LIBRARY_STATE storage)) {
size_t base_size;
- void *base_ptr = at::cuda::THCCachingAllocator_getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
+ void *base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
ptrdiff_t offset_bytes = (char*)storage->data<scalar_t>() - (char*)base_ptr;
cudaIpcMemHandle_t handle;
}
THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
std::string s_handle = std::string(buffer, handle_size);
- std::shared_ptr<void> basePtr = at::cuda::THCCaching_CUDAIpcDevptr(s_handle);
+ std::shared_ptr<void> basePtr = c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle);
// Offset the basePtr to reconstruct the real storage
// devPtr = basePtr + storage_offset
TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
cudaFree(0);
}
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < data.size(); ++i) {
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < input.size(); ++i) {
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < data.size(); ++i) {
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < data.size(); ++i) {
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
C10D_NCCL_CHECK(ncclGroupStart());
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
C10D_NCCL_CHECK(ncclGroupStart());
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
C10D_NCCL_CHECK(ncclGroupStart());
at::cuda::OptionalCUDAGuard gpuGuard;
std::unique_lock<std::mutex> cudaFreeMutexLock(
- *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
C10D_NCCL_CHECK(ncclGroupStart());