import unittest

from torch.autograd import Variable, Function, detect_anomaly
from torch.autograd.function import once_differentiable
from torch.autograd.function import InplaceFunction
from torch.autograd.gradcheck import gradgradcheck, gradcheck
from torch.autograd.profiler import profile
from torch.testing import make_non_contiguous, randn_like
from torch.utils.checkpoint import checkpoint

from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
                          suppress_warnings, skipIfRocm,
                          prod_single_zero, random_square_matrix_of_rank,
                          random_symmetric_matrix, random_symmetric_psd_matrix,
                          random_symmetric_pd_matrix, make_nonzero_det,
                          random_fullrank_matrix_distinct_singular_value, load_tests)
from common_cuda import TEST_CUDA
with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=False)
+ @unittest.skipIf(not TEST_CUDA, "Requires cuda for multi device")
+ def test_multi_device_reentrant_autograd(self):
+ # Output on gpu so that this task will be associated with the gpu thread
+ def fn_on_gpu(inp):
+ # Artificially increase the priority of the next op to make sure it runs
+ # as soon as we reach it before the ops of branch1.
+ dummy = inp * 2 * 2 * 2 * 2
+ return inp.cuda()
+
+ def parent_on_cpu(inp):
+ # Slow branch of ops on gpu so that the work queue for the gpu thread
+ # won't empty too quickly. They also have smaller priorities than the
+ # ones created by fn_on_gpu
+ branch1 = inp.cuda()
+ branch1 = branch1 / branch1
+ branch1 = branch1 / branch1
+ branch1 = branch1 / branch1
+ # Perform checkpoint on cpu tensors. So the last op performed in the reentrant
+ # autograd is an AccumulateGrad that runs on the cpu thread for the gpu thread.
+ # So the cpu thread will notify the gpu thread with an empty FunctionTask.
+ branch2 = checkpoint(fn_on_gpu, inp)
+ out = branch2 + branch1
+ return out
+
+ inp = torch.rand(2, requires_grad=True)
+ out = parent_on_cpu(inp)
+ # This will segfault if the empty FunctionTask is not handled properly in the
+ # gpu thread ReadyQueue
+ out.sum().backward()
+
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):