Add cuda.reset_max_memory_* (#15985)
authorSsnL <tongzhou.wang.1994@gmail.com>
Mon, 14 Jan 2019 15:28:50 +0000 (07:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 14 Jan 2019 15:31:51 +0000 (07:31 -0800)
Summary:
Addresses #15968
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15985

Differential Revision: D13649916

Pulled By: soumith

fbshipit-source-id: a207aea5709a79dba7a6fc541d0a70103f49efff

aten/src/THC/THCCachingAllocator.cpp
aten/src/THC/THCCachingAllocator.h
docs/source/cuda.rst
docs/source/notes/cuda.rst
test/test_cuda.py
torch/csrc/cuda/Module.cpp
torch/cuda/__init__.py

index 44ebac1..6bf92e1 100644 (file)
@@ -578,6 +578,12 @@ THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device) {
   return caching_allocator.get_stats_for_device(device).max_amount_allocated;
 }
 
+THC_API void THCCachingAllocator_resetMaxMemoryAllocated(int device) {
+  assertValidDevice(device);
+  DeviceStats& stats = caching_allocator.get_stats_for_device(device);
+  stats.max_amount_allocated = stats.amount_allocated;
+}
+
 THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device)
 {
   assertValidDevice(device);
@@ -589,6 +595,12 @@ THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device) {
   return caching_allocator.get_stats_for_device(device).max_amount_cached;
 }
 
+THC_API void THCCachingAllocator_resetMaxMemoryCached(int device) {
+  assertValidDevice(device);
+  DeviceStats& stats = caching_allocator.get_stats_for_device(device);
+  stats.max_amount_cached = stats.amount_cached;
+}
+
 //
 // In CUDA IPC, sender sends a tensor to receiver, THCCaching_CUDAIpcDevptr
 // is called by the receiving process to map the CUDA memory from the sending
index 626694a..491562a 100644 (file)
@@ -21,8 +21,10 @@ THC_API void THCCachingAllocator_recordStream(void *ptr, at::cuda::CUDAStream st
 #endif
 THC_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device);
 THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device);
+THC_API void     THCCachingAllocator_resetMaxMemoryAllocated(int device);
 THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device);
 THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device);
+THC_API void     THCCachingAllocator_resetMaxMemoryCached(int device);
 
 #if (__cplusplus >= 201103L) || (defined(_MSC_VER) && defined(__cplusplus))
 THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
index 6da20ce..4629674 100644 (file)
@@ -46,8 +46,10 @@ Memory management
 .. autofunction:: empty_cache
 .. autofunction:: memory_allocated
 .. autofunction:: max_memory_allocated
+.. autofunction:: reset_max_memory_allocated
 .. autofunction:: memory_cached
 .. autofunction:: max_memory_cached
+.. autofunction:: reset_max_memory_cached
 
 NVIDIA Tools Extension (NVTX)
 -----------------------------
index 212f68e..7cf2fe6 100644 (file)
@@ -74,9 +74,9 @@ You can force synchronous computation by setting environment variable
 operation is actually executed, so the stack trace does not show where it was
 requested.)
 
-As an exception, several functions such as :meth:`~torch.Tensor.to` and 
-:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument, 
-which lets the caller bypass synchronization when it is unnecessary.  
+As an exception, several functions such as :meth:`~torch.Tensor.to` and
+:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument,
+which lets the caller bypass synchronization when it is unnecessary.
 Another exception is CUDA streams, explained below.
 
 CUDA streams
@@ -118,7 +118,7 @@ unused memory managed by the allocator will still show as if used in
 :meth:`~torch.cuda.max_memory_allocated` to monitor memory occupied by
 tensors, and use :meth:`~torch.cuda.memory_cached` and
 :meth:`~torch.cuda.max_memory_cached` to monitor memory managed by the caching
-allocator. Calling :meth:`~torch.cuda.empty_cache` can release all **unused**
+allocator. Calling :meth:`~torch.cuda.empty_cache` releases all **unused**
 cached memory from PyTorch so that those can be used by other GPU applications.
 However, the occupied GPU memory by tensors will not be freed so it can not
 increase the amount of GPU memory available for PyTorch.
index 26eab2a..5130d23 100644 (file)
@@ -667,7 +667,7 @@ class TestCuda(TestCase):
                 #       memory checks below to fail.
                 return torch.cuda.FloatTensor(*size)
 
-        def assert_change(comp=1, empty_cache=False):
+        def assert_change(comp=1, empty_cache=False, reset_max_alloc=False, reset_max_cached=False):
             # comp > 0: increased
             # comp = 0: equal
             # comp < 0: decreased
@@ -702,7 +702,26 @@ class TestCuda(TestCase):
                 self.assertEqual(new_max_c, max_c_arr[0])
                 last_c_arr[0] = new_c
 
+            if reset_max_alloc:
+                torch.cuda.reset_max_memory_allocated(device)
+                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
+                self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
+                max_m_arr[0] = last_m_arr[0]
+                self.assertEqual(torch.cuda.memory_cached(device), last_c_arr[0])
+                self.assertEqual(torch.cuda.max_memory_cached(device), max_c_arr[0])
+
+            if reset_max_cached:
+                torch.cuda.reset_max_memory_cached(device)
+                self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
+                self.assertEqual(torch.cuda.max_memory_allocated(device), max_m_arr[0])
+                self.assertEqual(torch.cuda.memory_cached(device), last_c_arr[0])
+                self.assertEqual(torch.cuda.max_memory_cached(device), last_c_arr[0])
+                max_c_arr[0] = last_c_arr[0]
+
         assert_change(0)
+        assert_change(0, reset_max_alloc=True)
+        assert_change(0, empty_cache=True)
+        assert_change(0, reset_max_cached=True)
         assert_change(0)
         yield
 
@@ -722,7 +741,7 @@ class TestCuda(TestCase):
         for i in range(5, int(N / 2) + 5):
             # large ones
             tensors2.append(alloc(i, i * 7, i * 9, i * 11))
-            assert_change(1)
+            assert_change(1, reset_max_alloc=(i % 2 == 0), reset_max_cached=(i % 2 == 1))
             yield
 
         tensors2.append(alloc(0, 0, 0))
@@ -742,7 +761,7 @@ class TestCuda(TestCase):
         assert_change(0)
         yield
         del permute
-        assert_change(0)
+        assert_change(0, reset_max_alloc=True)
         yield
 
         for i in range(int(N / 2)):
@@ -757,17 +776,19 @@ class TestCuda(TestCase):
             yield
 
         del tensors2
-        assert_change(-1)
+        assert_change(-1, reset_max_cached=True)
         assert_change(0)
         self.assertEqual(torch.cuda.memory_allocated(device), m1)
         yield True
 
         del tensors1
-        assert_change(-1)
+        assert_change(-1, reset_max_alloc=True)
         self.assertEqual(torch.cuda.memory_allocated(device), m0)
 
-        # test empty_cache
+        # test empty_cache and reset_max_memory_*
         assert_change(0, empty_cache=True)
+        assert_change(0, reset_max_cached=True)
+        assert_change(0, reset_max_alloc=True)
 
     def test_memory_stats(self):
         torch.cuda.empty_cache()
index bc77298..5b18d63 100644 (file)
@@ -269,6 +269,16 @@ PyObject * THCPModule_maxMemoryAllocated(PyObject *_unused, PyObject *arg)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_resetMaxMemoryAllocated(PyObject *_unused, PyObject *arg)
+{
+  HANDLE_TH_ERRORS
+  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_allocated");
+  int device = (int) THPUtils_unpackLong(arg);
+  THCCachingAllocator_resetMaxMemoryAllocated(device);
+  END_HANDLE_TH_ERRORS
+  Py_RETURN_NONE;
+}
+
 PyObject * THCPModule_memoryCached(PyObject *_unused, PyObject *arg)
 {
   HANDLE_TH_ERRORS
@@ -289,6 +299,16 @@ PyObject * THCPModule_maxMemoryCached(PyObject *_unused, PyObject *arg)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_resetMaxMemoryCached(PyObject *_unused, PyObject *arg)
+{
+  HANDLE_TH_ERRORS
+  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_cached");
+  int device = (int) THPUtils_unpackLong(arg);
+  THCCachingAllocator_resetMaxMemoryCached(device);
+  END_HANDLE_TH_ERRORS
+  Py_RETURN_NONE;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Cuda module initialization
 ////////////////////////////////////////////////////////////////////////////////
@@ -397,8 +417,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
   {"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache,       METH_NOARGS,  nullptr},
   {"_cuda_memoryAllocated", (PyCFunction) THCPModule_memoryAllocated, METH_O,  nullptr},
   {"_cuda_maxMemoryAllocated", (PyCFunction) THCPModule_maxMemoryAllocated, METH_O,  nullptr},
+  {"_cuda_resetMaxMemoryAllocated", (PyCFunction) THCPModule_resetMaxMemoryAllocated, METH_O,  nullptr},
   {"_cuda_memoryCached", (PyCFunction) THCPModule_memoryCached, METH_O,  nullptr},
   {"_cuda_maxMemoryCached", (PyCFunction) THCPModule_maxMemoryCached, METH_O,  nullptr},
+  {"_cuda_resetMaxMemoryCached", (PyCFunction) THCPModule_resetMaxMemoryCached, METH_O,  nullptr},
   {"_cuda_manualSeed",  (PyCFunction)THCPModule_manualSeed,       METH_O,       nullptr},
   {"_cuda_manualSeedAll", (PyCFunction)THCPModule_manualSeedAll,  METH_O,       nullptr},
   {"_cuda_seed",        (PyCFunction)THCPModule_seed,             METH_NOARGS,  nullptr},
index c6abfc0..6534446 100644 (file)
@@ -375,7 +375,7 @@ def empty_cache():
 
 
 def memory_allocated(device=None):
-    r"""Returns the current GPU memory usage by tensors in bytes for a given
+    r"""Returns the current GPU memory occupied by tensors in bytes for a given
     device.
 
     Arguments:
@@ -394,9 +394,15 @@ def memory_allocated(device=None):
 
 
 def max_memory_allocated(device=None):
-    r"""Returns the maximum GPU memory usage by tensors in bytes for a given
+    r"""Returns the maximum GPU memory occupied by tensors in bytes for a given
     device.
 
+    By default, this returns the peak allocated memory since the beginning of
+    this program. :func:`~torch.cuda.reset_max_memory_allocated` can be used to
+    reset the starting point in tracking this metric. For example, these two
+    functions can measure the peak allocated memory usage of each iteration in a
+    training loop.
+
     Arguments:
         device (torch.device or int, optional): selected device. Returns
             statistic for the current device, given by :meth:`~torch.cuda.current_device`,
@@ -410,6 +416,25 @@ def max_memory_allocated(device=None):
     return torch._C._cuda_maxMemoryAllocated(device)
 
 
+def reset_max_memory_allocated(device=None):
+    r"""Resets the starting point in tracking maximum GPU memory occupied by
+    tensors for a given device.
+
+    See :func:`~torch.cuda.max_memory_allocated` for details.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetMaxMemoryAllocated(device)
+
+
 def memory_cached(device=None):
     r"""Returns the current GPU memory managed by the caching allocator in bytes
     for a given device.
@@ -431,6 +456,12 @@ def max_memory_cached(device=None):
     r"""Returns the maximum GPU memory managed by the caching allocator in bytes
     for a given device.
 
+    By default, this returns the peak cached memory since the beginning of this
+    program. :func:`~torch.cuda.reset_max_memory_cached` can be used to reset
+    the starting point in tracking this metric. For example, these two functions
+    can measure the peak cached memory amount of each iteration in a training
+    loop.
+
     Arguments:
         device (torch.device or int, optional): selected device. Returns
             statistic for the current device, given by :meth:`~torch.cuda.current_device`,
@@ -444,6 +475,25 @@ def max_memory_cached(device=None):
     return torch._C._cuda_maxMemoryCached(device)
 
 
+def reset_max_memory_cached(device=None):
+    r"""Resets the starting point in tracking maximum GPU memory managed by the
+    caching allocator for a given device.
+
+    See :func:`~torch.cuda.max_memory_cached` for details.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetMaxMemoryCached(device)
+
+
 def _host_allocator():
     _lazy_init()
     return torch._C._cuda_cudaHostAllocator()