return caching_allocator.get_stats_for_device(device).max_amount_allocated;
}
+THC_API void THCCachingAllocator_resetMaxMemoryAllocated(int device) {
+ assertValidDevice(device);
+ DeviceStats& stats = caching_allocator.get_stats_for_device(device);
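+ // Restart peak tracking: collapse the recorded maximum to the current usage.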
+ stats.max_amount_allocated = stats.amount_allocated;
+}
+
THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device)
{
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).amount_cached;
}
+THC_API void THCCachingAllocator_resetMaxMemoryCached(int device) {
+ assertValidDevice(device);
+ DeviceStats& stats = caching_allocator.get_stats_for_device(device);
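+ // Restart peak tracking: collapse the recorded maximum to the current cached amount.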
+ stats.max_amount_cached = stats.amount_cached;
+}
+
//
// In CUDA IPC, a sender sends a tensor to a receiver; THCCaching_CUDAIpcDevptr
// is called by the receiving process to map the CUDA memory from the sending
#endif
THC_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device);
THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device);
+THC_API void THCCachingAllocator_resetMaxMemoryAllocated(int device);
THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device);
THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device);
+THC_API void THCCachingAllocator_resetMaxMemoryCached(int device);
#if (__cplusplus >= 201103L) || (defined(_MSC_VER) && defined(__cplusplus))
THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
.. autofunction:: empty_cache
.. autofunction:: memory_allocated
.. autofunction:: max_memory_allocated
+.. autofunction:: reset_max_memory_allocated
.. autofunction:: memory_cached
.. autofunction:: max_memory_cached
+.. autofunction:: reset_max_memory_cached
NVIDIA Tools Extension (NVTX)
-----------------------------
operation is actually executed, so the stack trace does not show where it was
requested.)
As an exception, several functions such as :meth:`~torch.Tensor.to` and
:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument,
which lets the caller bypass synchronization when it is unnecessary.
Another exception is CUDA streams, explained below.
CUDA streams
:meth:`~torch.cuda.max_memory_allocated` to monitor memory occupied by
tensors, and use :meth:`~torch.cuda.memory_cached` and
:meth:`~torch.cuda.max_memory_cached` to monitor memory managed by the caching
-allocator. Calling :meth:`~torch.cuda.empty_cache` can release all **unused**
+allocator. Calling :meth:`~torch.cuda.empty_cache` releases all **unused**
cached memory from PyTorch so that it can be used by other GPU applications.
However, GPU memory occupied by tensors will not be freed, so this cannot
increase the amount of GPU memory available to PyTorch.
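+
+For example, a minimal sketch of tracking the per-iteration peak (``loader``
+and ``train_step`` are placeholders, not PyTorch APIs)::
+
+    for batch in loader:
+        train_step(batch)
+        # peak bytes occupied by tensors since the last reset
+        peak = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_max_memory_allocated()  # restart peak tracking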
# memory checks below to fail.
return torch.cuda.FloatTensor(*size)
- def assert_change(comp=1, empty_cache=False):
+ def assert_change(comp=1, empty_cache=False, reset_max_alloc=False, reset_max_cached=False):
# comp > 0: increased
# comp = 0: equal
# comp < 0: decreased
self.assertEqual(new_max_c, max_c_arr[0])
last_c_arr[0] = new_c
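+ # resetting the allocated peak must collapse it to the current usage
+ # while leaving the cached stats untouched (and vice versa below)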
+ if reset_max_alloc:
+ torch.cuda.reset_max_memory_allocated(device)
+ self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
+ self.assertEqual(torch.cuda.max_memory_allocated(device), last_m_arr[0])
+ max_m_arr[0] = last_m_arr[0]
+ self.assertEqual(torch.cuda.memory_cached(device), last_c_arr[0])
+ self.assertEqual(torch.cuda.max_memory_cached(device), max_c_arr[0])
+
+ if reset_max_cached:
+ torch.cuda.reset_max_memory_cached(device)
+ self.assertEqual(torch.cuda.memory_allocated(device), last_m_arr[0])
+ self.assertEqual(torch.cuda.max_memory_allocated(device), max_m_arr[0])
+ self.assertEqual(torch.cuda.memory_cached(device), last_c_arr[0])
+ self.assertEqual(torch.cuda.max_memory_cached(device), last_c_arr[0])
+ max_c_arr[0] = last_c_arr[0]
+
assert_change(0)
+ assert_change(0, reset_max_alloc=True)
+ assert_change(0, empty_cache=True)
+ assert_change(0, reset_max_cached=True)
assert_change(0)
yield
for i in range(5, int(N / 2) + 5):
# large ones
tensors2.append(alloc(i, i * 7, i * 9, i * 11))
- assert_change(1)
+ assert_change(1, reset_max_alloc=(i % 2 == 0), reset_max_cached=(i % 2 == 1))
yield
tensors2.append(alloc(0, 0, 0))
assert_change(0)
yield
del permute
- assert_change(0)
+ assert_change(0, reset_max_alloc=True)
yield
for i in range(int(N / 2)):
yield
del tensors2
- assert_change(-1)
+ assert_change(-1, reset_max_cached=True)
assert_change(0)
self.assertEqual(torch.cuda.memory_allocated(device), m1)
yield True
del tensors1
- assert_change(-1)
+ assert_change(-1, reset_max_alloc=True)
self.assertEqual(torch.cuda.memory_allocated(device), m0)
- # test empty_cache
+ # test empty_cache and reset_max_memory_*
assert_change(0, empty_cache=True)
+ assert_change(0, reset_max_cached=True)
+ assert_change(0, reset_max_alloc=True)
def test_memory_stats(self):
torch.cuda.empty_cache()
END_HANDLE_TH_ERRORS
}
+PyObject * THCPModule_resetMaxMemoryAllocated(PyObject *_unused, PyObject *arg)
+{
+ HANDLE_TH_ERRORS
+ THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_allocated");
+ int device = (int) THPUtils_unpackLong(arg);
+ THCCachingAllocator_resetMaxMemoryAllocated(device);
+ END_HANDLE_TH_ERRORS
+ Py_RETURN_NONE;
+}
+
PyObject * THCPModule_memoryCached(PyObject *_unused, PyObject *arg)
{
HANDLE_TH_ERRORS
END_HANDLE_TH_ERRORS
}
+PyObject * THCPModule_resetMaxMemoryCached(PyObject *_unused, PyObject *arg)
+{
+ HANDLE_TH_ERRORS
+ THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_max_memory_cached");
+ int device = (int) THPUtils_unpackLong(arg);
+ THCCachingAllocator_resetMaxMemoryCached(device);
+ END_HANDLE_TH_ERRORS
+ Py_RETURN_NONE;
+}
+
////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////
{"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, nullptr},
{"_cuda_memoryAllocated", (PyCFunction) THCPModule_memoryAllocated, METH_O, nullptr},
{"_cuda_maxMemoryAllocated", (PyCFunction) THCPModule_maxMemoryAllocated, METH_O, nullptr},
+ {"_cuda_resetMaxMemoryAllocated", (PyCFunction) THCPModule_resetMaxMemoryAllocated, METH_O, nullptr},
{"_cuda_memoryCached", (PyCFunction) THCPModule_memoryCached, METH_O, nullptr},
{"_cuda_maxMemoryCached", (PyCFunction) THCPModule_maxMemoryCached, METH_O, nullptr},
+ {"_cuda_resetMaxMemoryCached", (PyCFunction) THCPModule_resetMaxMemoryCached, METH_O, nullptr},
{"_cuda_manualSeed", (PyCFunction)THCPModule_manualSeed, METH_O, nullptr},
{"_cuda_manualSeedAll", (PyCFunction)THCPModule_manualSeedAll, METH_O, nullptr},
{"_cuda_seed", (PyCFunction)THCPModule_seed, METH_NOARGS, nullptr},
def memory_allocated(device=None):
- r"""Returns the current GPU memory usage by tensors in bytes for a given
+ r"""Returns the current GPU memory occupied by tensors in bytes for a given
device.
Arguments:
def max_memory_allocated(device=None):
- r"""Returns the maximum GPU memory usage by tensors in bytes for a given
+ r"""Returns the maximum GPU memory occupied by tensors in bytes for a given
device.
+ By default, this returns the peak allocated memory since the beginning of
+ this program. :func:`~torch.cuda.reset_max_memory_allocated` can be used to
+ reset the starting point in tracking this metric. For example, these two
+ functions can measure the peak allocated memory usage of each iteration in a
+ training loop.
+
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :meth:`~torch.cuda.current_device`,
return torch._C._cuda_maxMemoryAllocated(device)
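+
+# A minimal sketch of the per-iteration pattern mentioned in the docstring
+# above (`train_step` is a hypothetical function, not a PyTorch API):
+#
+#   torch.cuda.reset_max_memory_allocated(device)
+#   train_step()
+#   peak = torch.cuda.max_memory_allocated(device)  # peak bytes since the reset
+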
+def reset_max_memory_allocated(device=None):
+ r"""Resets the starting point in tracking maximum GPU memory occupied by
+ tensors for a given device.
+
+ See :func:`~torch.cuda.max_memory_allocated` for details.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Resets
+ statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_resetMaxMemoryAllocated(device)
+
+
def memory_cached(device=None):
r"""Returns the current GPU memory managed by the caching allocator in bytes
for a given device.
r"""Returns the maximum GPU memory managed by the caching allocator in bytes
for a given device.
+ By default, this returns the peak cached memory since the beginning of this
+ program. :func:`~torch.cuda.reset_max_memory_cached` can be used to reset
+ the starting point in tracking this metric. For example, these two functions
+ can measure the peak cached memory of each iteration in a training loop.
+
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :meth:`~torch.cuda.current_device`,
return torch._C._cuda_maxMemoryCached(device)
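+
+# A minimal sketch of the per-iteration pattern mentioned in the docstring
+# above; calling torch.cuda.empty_cache() beforehand would also shrink the
+# current cached amount (`train_step` is a hypothetical function):
+#
+#   torch.cuda.reset_max_memory_cached(device)
+#   train_step()
+#   peak = torch.cuda.max_memory_cached(device)  # peak cached bytes since reset
+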
+def reset_max_memory_cached(device=None):
+ r"""Resets the starting point in tracking maximum GPU memory managed by the
+ caching allocator for a given device.
+
+ See :func:`~torch.cuda.max_memory_cached` for details.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Resets
+ statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_resetMaxMemoryCached(device)
+
+
def _host_allocator():
_lazy_init()
return torch._C._cuda_cudaHostAllocator()