Move THCCachingAllocator to c10_cuda. (#16119)
authorEdward Yang <ezyang@fb.com>
Thu, 24 Jan 2019 20:00:34 +0000 (12:00 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 24 Jan 2019 20:06:56 +0000 (12:06 -0800)
Summary:
Some renaming and renamespacing also took place. I was originally planning not to do anything, but it turns out that it was easier to make HIPify work by using a namespace CUDACachingAllocator:: rather than THCCachingAllocator_, since :: is a word boundary but _ is not.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16119

Reviewed By: smessmer

Differential Revision: D13718768

fbshipit-source-id: 884a481d99027fd3e34471c020f826aa12225656

23 files changed:
aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h [new file with mode: 0644]
aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp [new file with mode: 0644]
aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h [new file with mode: 0644]
aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
aten/src/ATen/native/cudnn/Conv.cpp
aten/src/ATen/native/miopen/Conv_miopen.cpp
aten/src/THC/CMakeLists.txt
aten/src/THC/THC.h
aten/src/THC/THCCachingAllocator.h [deleted file]
aten/src/THC/THCGeneral.cpp
c10/core/Allocator.h
c10/cuda/CMakeLists.txt
c10/cuda/CUDACachingAllocator.cpp [moved from aten/src/THC/THCCachingAllocator.cpp with 95% similarity]
c10/cuda/CUDACachingAllocator.h [new file with mode: 0644]
tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
tools/autograd/templates/python_variable_methods.cpp
torch/csrc/cuda/Module.cpp
torch/csrc/cuda/nccl.cpp
torch/csrc/cuda/python_nccl.cpp
torch/csrc/generic/StorageSharing.cpp
torch/csrc/jit/fuser/cuda/fused_kernel.cpp
torch/lib/THD/base/data_channels/DataChannelNccl.cpp
torch/lib/c10d/ProcessGroupNCCL.cpp

diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
new file mode 100644 (file)
index 0000000..c764cda
--- /dev/null
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <c10/core/Allocator.h>
+#include <c10/core/DeviceType.h>
+
+// Use of c10::hip namespace here makes hipification easier, because
+// I don't have to also fix namespaces.  Sorry!
+namespace c10 { namespace hip {
+
+// Takes a valid HIPAllocator (of any sort) and turns it into
+// an allocator pretending to be a CUDA allocator.  See
+// Note [Masquerading as CUDA]
+class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
+  Allocator* allocator_;
+public:
+  explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
+    : allocator_(allocator) {}
+  DataPtr allocate(size_t size) const override {
+    DataPtr r = allocator_->allocate(size);
+    r.unsafe_set_device(Device(DeviceType::CUDA, r.device().index()));
+    return r;
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return allocator_->raw_deleter();
+  }
+};
+
+}} // namespace c10::hip
diff --git a/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp b/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp
new file mode 100644 (file)
index 0000000..fb76bf2
--- /dev/null
@@ -0,0 +1,12 @@
+#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
+
+namespace c10 { namespace hip {
+namespace HIPCachingAllocatorMasqueradingAsCUDA {
+
+Allocator* get() {
+  static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
+  return &allocator;
+}
+
+} // namespace HIPCachingAllocatorMasqueradingAsCUDA
+}} // namespace c10::hip
diff --git a/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h
new file mode 100644 (file)
index 0000000..dd3af64
--- /dev/null
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <c10/hip/HIPCachingAllocator.h>
+#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
+
+namespace c10 { namespace hip {
+namespace HIPCachingAllocatorMasqueradingAsCUDA {
+
+Allocator* get();
+
+} // namespace HIPCachingAllocatorMasqueradingAsCUDA
+}} // namespace c10::hip
index 1d00b1f..195ce31 100644 (file)
@@ -15,6 +15,8 @@
 // I don't have to also fix namespaces.  Sorry!
 namespace c10 { namespace hip {
 
+// Note [Masquerading as CUDA]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 // HIPGuardImplMasqueradingAsCUDA is like a HIPGuardImpl, but
 // it reports its DeviceType as CUDA (e.g., type() returns CUDA,
 // getDevice() reports the current HIP device as a CUDA device.)
index 2d4c713..2fb265f 100644 (file)
@@ -708,7 +708,7 @@ void findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args,
         // Free the cached blocks in our caching allocator. They are
         // needed here because the above benchmarking uses a huge amount of memory,
         // e.g. a few GBs.
-        at::cuda::THCCachingAllocator_emptyCache();
+        c10::cuda::CUDACachingAllocator::emptyCache();
       }
 
       *algoPerf = perfResults;
index 90ba2dd..19b2980 100644 (file)
@@ -522,7 +522,7 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
   cache.insert(args.params, *algo);
   wsscache.insert(args.params, perfResults.memory);
 
-  cuda::THCCachingAllocator_emptyCache();
+  c10::hip::HIPCachingAllocator::emptyCache();
 }
 
 template<typename algo_t>
index 6b8786a..fddfe7d 100644 (file)
@@ -19,7 +19,6 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
 endforeach()
 
 set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
-  ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingAllocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCGeneral.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cpp
@@ -72,7 +71,6 @@ INSTALL(FILES
           THCScanUtils.cuh
           THCSortUtils.cuh
           THCAllocator.h
-          THCCachingAllocator.h
           THCCachingHostAllocator.h
           THCDeviceUtils.cuh
           THCDeviceTensor.cuh
@@ -100,7 +98,6 @@ INSTALL(FILES
           THCThrustAllocator.cuh
           THCTensorMode.cuh
           THCTensorTopK.cuh
-          THCCachingAllocator.h
           # See Note [TH abstraction violation]
           THCGenerator.hpp
           THCTensor.hpp
index 6481392..79be433 100644 (file)
@@ -4,7 +4,7 @@
 #include <THC/THCGeneral.h>
 #include <THC/THCAllocator.h>
 #include <THC/THCBlas.h>
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
 #include <THC/THCCachingHostAllocator.h>
 #include <THC/THCSleep.h>
 #include <THC/THCStorage.h>
diff --git a/aten/src/THC/THCCachingAllocator.h b/aten/src/THC/THCCachingAllocator.h
deleted file mode 100644 (file)
index c22261b..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#ifndef THC_DEVICE_ALLOCATOR_INC
-#define THC_DEVICE_ALLOCATOR_INC
-
-#include <c10/cuda/CUDAStream.h>
-#include <c10/core/Allocator.h>
-#include <ATen/cuda/ATenCUDAGeneral.h>
-
-#include <mutex>
-
-namespace at {
-namespace cuda {
-
-AT_CUDA_API Allocator* THCCachingAllocator_get(void);
-AT_CUDA_API void THCCachingAllocator_emptyCache(void);
-AT_CUDA_API void THCCachingAllocator_cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
-AT_CUDA_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size);
-AT_CUDA_API void THCCachingAllocator_recordStream(void *ptr, at::cuda::CUDAStream stream);
-AT_CUDA_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device);
-AT_CUDA_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device);
-AT_CUDA_API void     THCCachingAllocator_resetMaxMemoryAllocated(int device);
-AT_CUDA_API uint64_t THCCachingAllocator_currentMemoryCached(int device);
-AT_CUDA_API uint64_t THCCachingAllocator_maxMemoryCached(int device);
-AT_CUDA_API void     THCCachingAllocator_resetMaxMemoryCached(int device);
-
-AT_CUDA_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
-
-AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle);
-
-}} // namespace at::cuda
-
-#endif
index a1a9079..6c7d6ca 100644 (file)
@@ -8,7 +8,7 @@
 #include <c10/cuda/CUDAStream.h>
 #include <ATen/cuda/CUDAContext.h>
 
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
 #include <stdlib.h>
 #include <stdint.h>
 
@@ -41,7 +41,7 @@ THCState* THCState_alloc(void)
 void THCudaInit(THCState* state)
 {
   if (!state->cudaDeviceAllocator) {
-    state->cudaDeviceAllocator = at::cuda::THCCachingAllocator_get();
+    state->cudaDeviceAllocator = c10::cuda::CUDACachingAllocator::get();
   }
   if (!state->cudaHostAllocator) {
     state->cudaHostAllocator = getTHCCachingHostAllocator();
@@ -130,8 +130,8 @@ void THCudaShutdown(THCState* state)
   }
 
   free(state->resourcesPerDevice);
-  if (state->cudaDeviceAllocator == at::cuda::THCCachingAllocator_get()) {
-    at::cuda::THCCachingAllocator_emptyCache();
+  if (state->cudaDeviceAllocator == c10::cuda::CUDACachingAllocator::get()) {
+    c10::cuda::CUDACachingAllocator::emptyCache();
   }
   if (state->cudaHostAllocator == getTHCCachingHostAllocator()) {
     THCCachingHostAllocator_emptyCache();
@@ -421,8 +421,8 @@ cudaError_t THCudaMemGetInfo(THCState *state,  size_t* freeBytes, size_t* totalB
   /* not always true - our optimistic guess here */
   *largestBlock = *freeBytes;
 
-  if (allocator == at::cuda::THCCachingAllocator_get()) {
-    at::cuda::THCCachingAllocator_cacheInfo(device, &cachedBytes, largestBlock);
+  if (allocator == c10::cuda::CUDACachingAllocator::get()) {
+    c10::cuda::CUDACachingAllocator::cacheInfo(device, &cachedBytes, largestBlock);
   }
 
   /* Adjust resulting free bytes number. largesBlock unused for now */
index 37ab55b..ff04cbd 100644 (file)
@@ -59,6 +59,13 @@ class DataPtr {
   Device device() const {
     return device_;
   }
+  // Unsafely mutates the device on a DataPtr.  Under normal use,
+  // you should never actually need to call this function.
+  // We need this for the implementation of the hack detailed
+  // in Note [Masquerading as CUDA]
+  void unsafe_set_device(Device device) {
+    device_ = device;
+  }
 };
 
 // NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a
index a493084..f72b866 100644 (file)
@@ -21,6 +21,7 @@ configure_file(
 # and headers you add
 set(C10_CUDA_SRCS
     CUDAStream.cpp
+    CUDACachingAllocator.cpp
     impl/CUDAGuardImpl.cpp
     impl/CUDATest.cpp
 )
similarity index 95%
rename from aten/src/THC/THCCachingAllocator.cpp
rename to c10/cuda/CUDACachingAllocator.cpp
index ff0e9dc..8e10708 100644 (file)
@@ -1,4 +1,4 @@
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
 
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAException.h>
 #include <unordered_set>
 #include <vector>
 
-namespace at {
+namespace c10 {
 namespace cuda {
 
+namespace CUDACachingAllocator {
+
 //
 // Yet another caching allocator for CUDA device allocations.
 //
@@ -536,30 +538,30 @@ struct CudaCachingAllocator : public Allocator {
 
 CudaCachingAllocator device_allocator;
 
-Allocator* THCCachingAllocator_get(void)
+Allocator* get(void)
 {
   return &device_allocator;
 }
 
-void THCCachingAllocator_emptyCache(void) {
+void emptyCache(void) {
   caching_allocator.emptyCache();
 }
 
-void THCCachingAllocator_cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) {
+void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) {
   caching_allocator.cacheInfo(dev_id, cachedAndFree, largestBlock);
 }
 
-void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size)
+void* getBaseAllocation(void *ptr, size_t *size)
 {
   return caching_allocator.getBaseAllocation(ptr, size);
 }
 
-void THCCachingAllocator_recordStream(void *ptr, cuda::CUDAStream stream)
+void recordStream(void *ptr, cuda::CUDAStream stream)
 {
   caching_allocator.recordStream(ptr, stream);
 }
 
-std::mutex* THCCachingAllocator_getCudaFreeMutex()
+std::mutex* getFreeMutex()
 {
   return &caching_allocator.cuda_free_mutex;
 }
@@ -570,42 +572,42 @@ static inline void assertValidDevice(int device) {
   AT_ASSERTM(0 <= device && device < device_count, "Invalid device argument.");
 }
 
-uint64_t THCCachingAllocator_currentMemoryAllocated(int device)
+uint64_t currentMemoryAllocated(int device)
 {
   assertValidDevice(device);
   return caching_allocator.get_stats_for_device(device).amount_allocated;
 }
 
-uint64_t THCCachingAllocator_maxMemoryAllocated(int device) {
+uint64_t maxMemoryAllocated(int device) {
   assertValidDevice(device);
   return caching_allocator.get_stats_for_device(device).max_amount_allocated;
 }
 
-void THCCachingAllocator_resetMaxMemoryAllocated(int device) {
+void resetMaxMemoryAllocated(int device) {
   assertValidDevice(device);
   DeviceStats& stats = caching_allocator.get_stats_for_device(device);
   stats.max_amount_allocated = stats.amount_allocated;
 }
 
-uint64_t THCCachingAllocator_currentMemoryCached(int device)
+uint64_t currentMemoryCached(int device)
 {
   assertValidDevice(device);
   return caching_allocator.get_stats_for_device(device).amount_cached;
 }
 
-uint64_t THCCachingAllocator_maxMemoryCached(int device) {
+uint64_t maxMemoryCached(int device) {
   assertValidDevice(device);
   return caching_allocator.get_stats_for_device(device).max_amount_cached;
 }
 
-void THCCachingAllocator_resetMaxMemoryCached(int device) {
+void resetMaxMemoryCached(int device) {
   assertValidDevice(device);
   DeviceStats& stats = caching_allocator.get_stats_for_device(device);
   stats.max_amount_cached = stats.amount_cached;
 }
 
 //
-// In CUDA IPC, sender sends a tensor to receiver, THCCaching_CUDAIpcDevptr
+// In CUDA IPC, sender sends a tensor to receiver, getIpcDevPtr
 // is called by the receiving process to map the CUDA memory from the sending
 // process into its own address space.
 //
@@ -625,7 +627,7 @@ namespace {
   std::unordered_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr;
 }
 
-AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle) {
+std::shared_ptr<void> getIpcDevPtr(std::string handle) {
   std::lock_guard<std::mutex> lock(IpcMutex);
 
   auto iter = ipcMemHandle_to_devptr.find(handle);
@@ -658,4 +660,6 @@ AT_CUDA_API std::shared_ptr<void> THCCaching_CUDAIpcDevptr(std::string handle) {
   return sp;
 }
 
-}} // namespace at::cuda
+} // namespace CUDACachingAllocator
+
+}} // namespace c10::cuda
diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h
new file mode 100644 (file)
index 0000000..b010b9f
--- /dev/null
@@ -0,0 +1,48 @@
+#ifndef THC_DEVICE_ALLOCATOR_INC
+#define THC_DEVICE_ALLOCATOR_INC
+
+#include <c10/cuda/CUDAStream.h>
+#include <c10/core/Allocator.h>
+#include <c10/cuda/CUDAMacros.h>
+
+#include <mutex>
+
+namespace c10 {
+namespace cuda {
+
+// TODO: Turn this into an honest to goodness class. I briefly attempted to do
+// this, but it was a bit irritating to figure out how to also correctly
+// apply pimpl pattern so I didn't have to leak any internal implementation
+// details in the header (CUDACachingAllocator could be made a pimpl, but
+// you also need to appropriately define a class which is a subclass
+// of Allocator. Not impossible, but required a bit more surgery than
+// I wanted to do at the time.)
+//
+// Why is this using a namespace rather than old-style THCCachingAllocator_
+// prefix?  Mostly because it made the HIPify rules easier to write; _ is
+// not counted as a word boundary, so you would otherwise have to list each
+// of these functions.
+
+namespace CUDACachingAllocator {
+
+C10_CUDA_API Allocator* get();
+C10_CUDA_API void emptyCache();
+C10_CUDA_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
+C10_CUDA_API void* getBaseAllocation(void *ptr, size_t *size);
+C10_CUDA_API void recordStream(void *ptr, CUDAStream stream);
+C10_CUDA_API uint64_t currentMemoryAllocated(int device);
+C10_CUDA_API uint64_t maxMemoryAllocated(int device);
+C10_CUDA_API void     resetMaxMemoryAllocated(int device);
+C10_CUDA_API uint64_t currentMemoryCached(int device);
+C10_CUDA_API uint64_t maxMemoryCached(int device);
+C10_CUDA_API void     resetMaxMemoryCached(int device);
+
+C10_CUDA_API std::mutex* getFreeMutex();
+
+C10_CUDA_API std::shared_ptr<void> getIpcDevPtr(std::string handle);
+
+} // namespace CUDACachingAllocator
+
+}} // namespace c10::cuda
+
+#endif
index cb45333..30d5504 100644 (file)
@@ -2231,9 +2231,15 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict([
     ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)),
     ("OptionalCUDAStreamGuard", ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)),
 
+    # Only get needs to be transformed this way; all the other ones can go
+    # straight to the normal versions hip::HIPCachingAllocator
+    ("cuda::CUDACachingAllocator::get", ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH)),
+    ("CUDACachingAllocator::get", ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH)),
+
     # TODO: Undo this special-case; see the header for motivation behind this
     # hack.  It's VERY important this is only applied to PyTorch HIPify.
     ("c10/cuda/CUDAGuard.h", ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH)),
+    ("c10/cuda/CUDACachingAllocator.h", ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH)),
 ])
 
 CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([
@@ -2300,6 +2306,7 @@ C10_MAPPINGS = collections.OrderedDict([
     ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)),
     ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)),
     ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)),
+    ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)),
     ("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)),
     ("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)),
     ("c10/cuda/impl/cuda_cmake_macros.h", ("c10/hip/impl/hip_cmake_macros.h", API_C10)),
@@ -2320,6 +2327,8 @@ C10_MAPPINGS = collections.OrderedDict([
     ("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)),
     ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)),
     ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)),
+    ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", API_C10)),
+    ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)),
 ])
 
 # NB: C10 mappings are more specific than Caffe2 mappings, so run them
index de0da43..8c25f54 100644 (file)
@@ -375,7 +375,7 @@ static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg)
     return PyErr_Format(PyExc_TypeError, "expected Stream object");
   }
   void* data = self_.data_ptr();
-  at::cuda::THCCachingAllocator_recordStream(data, at::cuda::CUDAStream::unpack(((THCPStream*)arg)->cdata));
+  c10::cuda::CUDACachingAllocator::recordStream(data, at::cuda::CUDAStream::unpack(((THCPStream*)arg)->cdata));
   Py_RETURN_NONE;
 #else
   throw std::runtime_error("PyTorch compiled without CUDA support");
index 320afb7..e5577c0 100644 (file)
@@ -7,7 +7,7 @@
 #include <TH/TH.h>
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <THC/THCCachingAllocator.h>
+#include <c10/cuda/CUDACachingAllocator.h>
 #ifdef USE_NCCL
 #include <nccl.h>
 #endif
@@ -234,7 +234,7 @@ static PyGILState_STATE cudaMutexGILState;
 
 PyObject * THCPModule_cudaLockMutex(PyObject *module)
 {
-  auto mutex = at::cuda::THCCachingAllocator_getCudaFreeMutex();
+  auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex();
   // This has to be a busy loop because we **absolutely need to** hold the GIL
   // or it's a recipe for a deadlock otherwise (if we let other Python threads
   // run while we have the cudaMutex, but not the GIL, they might try to e.g.
@@ -255,7 +255,7 @@ PyObject * THCPModule_cudaLockMutex(PyObject *module)
 
 PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
 {
-  auto mutex = at::cuda::THCCachingAllocator_getCudaFreeMutex();
+  auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex();
   PyGILState_Release(cudaMutexGILState);
   mutex->unlock();
   Py_RETURN_NONE;
@@ -264,7 +264,7 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
 PyObject * THCPModule_emptyCache(PyObject *_unused)
 {
   HANDLE_TH_ERRORS
-  at::cuda::THCCachingAllocator_emptyCache();
+  c10::cuda::CUDACachingAllocator::emptyCache();
   END_HANDLE_TH_ERRORS
   Py_RETURN_NONE;
 }
@@ -274,7 +274,7 @@ PyObject * THCPModule_memoryAllocated(PyObject *_unused, PyObject *arg)
   HANDLE_TH_ERRORS
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_allocated");
   int device = (int) THPUtils_unpackLong(arg);
-  auto memory_allocated = at::cuda::THCCachingAllocator_currentMemoryAllocated(device);
+  auto memory_allocated = c10::cuda::CUDACachingAllocator::currentMemoryAllocated(device);
   return PyLong_FromUnsignedLongLong(memory_allocated);
   END_HANDLE_TH_ERRORS
 }
@@ -284,7 +284,7 @@ PyObject * THCPModule_maxMemoryAllocated(PyObject *_unused, PyObject *arg)
   HANDLE_TH_ERRORS
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to max_memory_allocated");
   int device = (int) THPUtils_unpackLong(arg);
-  auto max_memory_allocated = at::cuda::THCCachingAllocator_maxMemoryAllocated(device);
+  auto max_memory_allocated = c10::cuda::CUDACachingAllocator::maxMemoryAllocated(device);
   return PyLong_FromUnsignedLongLong(max_memory_allocated);
   END_HANDLE_TH_ERRORS
 }
@@ -294,7 +294,7 @@ PyObject * THCPModule_resetMaxMemoryAllocated(PyObject *_unused, PyObject *arg)
   HANDLE_TH_ERRORS
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_allocated");
   int device = (int) THPUtils_unpackLong(arg);
-  at::cuda::THCCachingAllocator_resetMaxMemoryAllocated(device);
+  c10::cuda::CUDACachingAllocator::resetMaxMemoryAllocated(device);
   END_HANDLE_TH_ERRORS
   Py_RETURN_NONE;
 }
@@ -304,7 +304,7 @@ PyObject * THCPModule_memoryCached(PyObject *_unused, PyObject *arg)
   HANDLE_TH_ERRORS
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_cached");
   int device = (int) THPUtils_unpackLong(arg);
-  auto memory_cached = at::cuda::THCCachingAllocator_currentMemoryCached(device);
+  auto memory_cached = c10::cuda::CUDACachingAllocator::currentMemoryCached(device);
   return PyLong_FromUnsignedLongLong(memory_cached);
   END_HANDLE_TH_ERRORS
 }
@@ -314,7 +314,7 @@ PyObject * THCPModule_maxMemoryCached(PyObject *_unused, PyObject *arg)
   HANDLE_TH_ERRORS
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to max_memory_cached");
   int device = (int) THPUtils_unpackLong(arg);
-  auto max_memory_cached = at::cuda::THCCachingAllocator_maxMemoryCached(device);
+  auto max_memory_cached = c10::cuda::CUDACachingAllocator::maxMemoryCached(device);
   return PyLong_FromUnsignedLongLong(max_memory_cached);
   END_HANDLE_TH_ERRORS
 }
@@ -324,7 +324,7 @@ PyObject * THCPModule_resetMaxMemoryCached(PyObject *_unused, PyObject *arg)
   HANDLE_TH_ERRORS
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_cached");
   int device = (int) THPUtils_unpackLong(arg);
-  at::cuda::THCCachingAllocator_resetMaxMemoryCached(device);
+  c10::cuda::CUDACachingAllocator::resetMaxMemoryCached(device);
   END_HANDLE_TH_ERRORS
   Py_RETURN_NONE;
 }
index c895def..c1867b0 100644 (file)
@@ -237,7 +237,7 @@ void broadcast(
   int64_t numel = tensors[0].numel();
 
   std::lock_guard<std::mutex> free_mutex(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
   const auto comms = user_comms.empty() ? _get_communicators(tensors)
                                         : ArrayRef<ncclComm_t>(user_comms);
 
@@ -284,7 +284,7 @@ void reduce(
   ncclDataType_t data_type = _get_data_type(inputs[0].type());
 
   const auto count = inputs[0].numel();
-  std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+  std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
   auto comms_ref = user_comms.empty() ? _get_communicators(inputs)
                                       : ArrayRef<ncclComm_t>(user_comms);
 
index 691444f..4fb98dd 100644 (file)
@@ -191,7 +191,7 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
     ncclDataType_t data_type = _get_data_type(inputs[0].type());
 
     int64_t count = inputs[0].numel();
-    std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+    std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
     auto comms = user_comms.empty() ? _get_communicators(inputs)
                                     : ArrayRef<ncclComm_t>(user_comms);
     at::cuda::OptionalCUDAGuard device_guard;
@@ -271,7 +271,7 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
     ncclDataType_t data_type = _get_data_type(inputs[0].type());
 
     int64_t count = inputs[0].numel();
-    std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+    std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
     auto comms = user_comms.empty() ? _get_communicators(inputs)
                                     : ArrayRef<ncclComm_t>(user_comms);
     at::cuda::OptionalCUDAGuard device_guard;
@@ -334,7 +334,7 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
     ncclDataType_t data_type = _get_data_type(inputs[0].type());
 
     int64_t count = inputs[0].numel() / len;
-    std::lock_guard<std::mutex> lock(*(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+    std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
     auto comms = user_comms.empty() ? _get_communicators(inputs)
                                     : ArrayRef<ncclComm_t>(user_comms);
     at::cuda::OptionalCUDAGuard device_guard;
index 9191257..d6187e8 100644 (file)
@@ -225,7 +225,7 @@ static PyObject * THPStorage_(shareCuda)(THPStorage *self)
   THPObjectPtr _offset_bytes(PyLong_FromLong(0));
   if (THWStorage_(data)(LIBRARY_STATE storage)) {
     size_t base_size;
-    void *base_ptr = at::cuda::THCCachingAllocator_getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
+    void *base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation(THWStorage_(data)(LIBRARY_STATE storage), &base_size);
     ptrdiff_t offset_bytes = (char*)storage->data<scalar_t>() - (char*)base_ptr;
 
     cudaIpcMemHandle_t handle;
@@ -282,7 +282,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
   }
   THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
   std::string s_handle = std::string(buffer, handle_size);
-  std::shared_ptr<void> basePtr = at::cuda::THCCaching_CUDAIpcDevptr(s_handle);
+  std::shared_ptr<void> basePtr = c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle);
 
   // Offset the basePtr to reconstruct the real storage
   // devPtr = basePtr + storage_offset
index b26d295..0a04a4a 100644 (file)
@@ -100,7 +100,7 @@ FusedKernelCUDA::FusedKernelCUDA(
   TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
   if (!pctx) {
     std::unique_lock<std::mutex> cudaFreeMutexLock(
-        *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
     cudaFree(0);
   }
 
index 131b79c..bc211dd 100644 (file)
@@ -412,7 +412,7 @@ void DataChannelNccl::allReduce(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < data.size(); ++i) {
@@ -468,7 +468,7 @@ void DataChannelNccl::allGather(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < input.size(); ++i) {
@@ -525,7 +525,7 @@ void DataChannelNccl::reduce(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < data.size(); ++i) {
@@ -584,7 +584,7 @@ void DataChannelNccl::broadcast(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < data.size(); ++i) {
index dbf9807..5031e6f 100644 (file)
@@ -370,7 +370,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   C10D_NCCL_CHECK(ncclGroupStart());
 
@@ -417,7 +417,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   C10D_NCCL_CHECK(ncclGroupStart());
 
@@ -465,7 +465,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   C10D_NCCL_CHECK(ncclGroupStart());
 
@@ -534,7 +534,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   std::unique_lock<std::mutex> cudaFreeMutexLock(
-      *(at::cuda::THCCachingAllocator_getCudaFreeMutex()));
+      *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
 
   C10D_NCCL_CHECK(ncclGroupStart());