Rename away uses of THAllocator and THCDeviceAllocator (#16061)
authorEdward Yang <ezyang@fb.com>
Wed, 16 Jan 2019 13:33:14 +0000 (05:33 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 16 Jan 2019 13:36:47 +0000 (05:36 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16061

I discovered I needed to delete these names in preparation of moving
THCCachingAllocator to c10_cuda; might as well also fix all the other
sites too.

Reviewed By: dzhulgakov

Differential Revision: D13686869

fbshipit-source-id: e8cc55d39ac4bfd3e3a22c761f89a7a111ce5f5e

aten/src/ATen/CPUFixedAllocator.h
aten/src/TH/THAllocator.h
aten/src/TH/generic/THStorage.h
aten/src/THC/THCCachingAllocator.h
aten/src/THC/THCCachingHostAllocator.h
aten/src/THC/THCGeneral.cpp
aten/src/THC/THCGeneral.h.in
torch/csrc/cuda/Module.cpp
torch/csrc/generic/Storage.cpp

index becdeea..bc0918a 100644 (file)
@@ -25,7 +25,7 @@ static cpu_fixed_free(void * state, void * allocation) {
     delete on_release;
 }
 
-static THAllocator CPU_fixed_allocator =
+static Allocator CPU_fixed_allocator =
   { cpu_fixed_malloc, cpu_fixed_realloc, cpu_fixed_free };
 
 }
index 85b07e0..5413da0 100644 (file)
 #define TH_ALLOCATOR_MAPPED_FROMFD 32
 #define TH_ALLOCATOR_MAPPED_UNLINK 64
 
-using THAllocator = at::Allocator;
-
 /* default malloc/free allocator. malloc and realloc raise an error (using
  * THError) on allocation failure.
  */
-TH_API THAllocator* getTHDefaultAllocator(void);
+TH_API c10::Allocator* getTHDefaultAllocator(void);
 
 // Sentinel value/type to help distinguish the file descriptor constructor from
 // the non-file descriptor constructor
index e7769e2..14dbf25 100644 (file)
@@ -51,7 +51,7 @@ TH_API THStorage* THStorage_(newWithSize4)(scalar_t, scalar_t, scalar_t, scalar_
 TH_API THStorage* THStorage_(newWithMapping)(const char *filename, ptrdiff_t size, int flags);
 
 TH_API THStorage* THStorage_(newWithAllocator)(ptrdiff_t size,
-                                               THAllocator* allocator);
+                                               c10::Allocator* allocator);
 TH_API THStorage* THStorage_(newWithDataAndAllocator)(
     at::DataPtr&& data, ptrdiff_t size, at::Allocator* allocator);
 
index 647c9d6..f9ace9b 100644 (file)
@@ -8,7 +8,7 @@
 
 #include <THC/THCGeneral.h>
 
-THC_API THCDeviceAllocator* THCCachingAllocator_get(void);
+THC_API c10::Allocator* THCCachingAllocator_get(void);
 THC_API void THCCachingAllocator_emptyCache(void);
 THC_API void THCCachingAllocator_cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock);
 THC_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size);
index cc175e1..3868884 100644 (file)
@@ -21,7 +21,7 @@
 // Note that this allocator does not split larger allocations into smaller
 // blocks, unlike the caching device allocator.
 //
-THC_API THAllocator* getTHCCachingHostAllocator(void);
+THC_API c10::Allocator* getTHCCachingHostAllocator(void);
 
 // Records an event in the specified stream. The allocation 'ptr' will not be
 // re-used until the event has occurred.
index 047d994..137343f 100644 (file)
@@ -178,7 +178,7 @@ struct THCRNGState* THCState_getRngState(THCState *state)
   return state->rngState;
 }
 
-THAllocator* THCState_getCudaHostAllocator(THCState* state)
+c10::Allocator* THCState_getCudaHostAllocator(THCState* state)
 {
   return state->cudaHostAllocator;
 }
@@ -381,7 +381,7 @@ void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line
 void* THCudaMalloc(THCState *state, size_t size)
 {
   THCudaCheck(cudaGetLastError());
-  THCDeviceAllocator* allocator = state->cudaDeviceAllocator;
+  c10::Allocator* allocator = state->cudaDeviceAllocator;
   return allocator->raw_allocate(size);
 }
 
@@ -392,7 +392,7 @@ void THCudaFree(THCState *state, void* ptr) {
 at::DataPtr THCudaHostAlloc(THCState *state, size_t size)
 {
   THCudaCheck(cudaGetLastError());
-  THAllocator* allocator = state->cudaHostAllocator;
+  c10::Allocator* allocator = state->cudaHostAllocator;
   return allocator->allocate(size);
 }
 
@@ -405,7 +405,7 @@ void THCudaHostRecord(THCState *state, void *ptr) {
 cudaError_t THCudaMemGetInfo(THCState *state,  size_t* freeBytes, size_t* totalBytes, size_t* largestBlock)
 {
   size_t cachedBytes = 0;
-  THCDeviceAllocator* allocator = state->cudaDeviceAllocator;
+  c10::Allocator* allocator = state->cudaDeviceAllocator;
 
   *largestBlock = 0;
   /* get info from CUDA first */
index 297c987..80e71de 100644 (file)
@@ -46,8 +46,6 @@ struct THCRNGState;  /* Random number generator state. */
 typedef struct THCState THCState;
 struct THCState;
 
-typedef THAllocator THCDeviceAllocator;
-
 typedef struct _THCCudaResourcesPerDevice {
   /* cuBLAS handle is lazily initialized */
   cublasHandle_t blasHandle;
@@ -68,7 +66,7 @@ THC_API void THCudaShutdown(THCState* state);
 THC_API int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess);
 
 THC_API struct THCRNGState* THCState_getRngState(THCState* state);
-THC_API THAllocator* THCState_getCudaHostAllocator(THCState* state);
+THC_API c10::Allocator* THCState_getCudaHostAllocator(THCState* state);
 
 THC_API void THCMagma_init(THCState *state);
 
index 5b18d63..bbace72 100644 (file)
@@ -183,7 +183,7 @@ PyObject * THCPModule_initialSeed(PyObject *_unused)
 PyObject * THCPModule_cudaHostAllocator(PyObject *_unused)
 {
   HANDLE_TH_ERRORS
-  THAllocator* allocator = THCState_getCudaHostAllocator(state);
+  c10::Allocator* allocator = THCState_getCudaHostAllocator(state);
   return PyLong_FromVoidPtr(allocator);
   END_HANDLE_TH_ERRORS
 }
index 765efe3..e8fd9bb 100644 (file)
@@ -40,14 +40,14 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
 
   THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0));
   THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
-  THAllocator* allocator = nullptr;
+  c10::Allocator* allocator = nullptr;
 
   // Internally we allow constructing with a keywoard only argument cdata
   if (kwargs != nullptr) {
     PyObject *allocator_ptr = PyDict_GetItemString(kwargs, "allocator");
     if (allocator_ptr) {
       THPUtils_assert(THPUtils_checkLong(allocator_ptr), "invalid allocator");
-      allocator = (THAllocator*) PyLong_AsVoidPtr(allocator_ptr);
+      allocator = static_cast<c10::Allocator*>(PyLong_AsVoidPtr(allocator_ptr));
       PyDict_DelItemString(kwargs, "allocator");
     }