From: Wei Yang
Date: Fri, 7 Dec 2018 01:58:16 +0000 (-0800)
Subject: gradcheck (#14596)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2408
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1a247f872f6ce9295e849119b234b2de7c0183b2;p=platform%2Fupstream%2Fpytorch.git

gradcheck (#14596)

Summary:
- allow gradcheck to take a sparse tensor as input
- sparse output is not yet allowed in gradcheck
- add a backward for `to_dense()` to get around the sparse-output restriction
- call gradcheck from test_sparse, so that we can use `_gen_sparse()` and easily cover both coalesced and uncoalesced test cases

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14596

Differential Revision: D13271904

Pulled By: weiyangfb

fbshipit-source-id: 5317484104404fd38058884c86e987546011dd86
---

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 68525c4..cde1fbf 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2704,6 +2704,14 @@ class TestAutograd(TestCase):
         gradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
         gradgradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
 
+    def test_gradcheck_sparse_input(self):
+        def fn(sparse):
+            return torch.sparse.sum(sparse)
+
+        gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=True)
+        with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
+            gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=False)
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 5011f5c..c48c299 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -8,6 +8,7 @@ import unittest
 from common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, do_test_empty_full, load_tests
 from common_cuda import TEST_CUDA
 from numbers import Number
+from torch.autograd.gradcheck import gradcheck
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -214,6 +215,11 @@ class TestSparse(TestCase):
             self.assertEqual(res, x.to_dense())
             self.assertEqual(res, self.safeToDense(x))
 
+            def fn(x):
+                return x.to_dense()
+            x.requires_grad_(True)
+            gradcheck(fn, (x,), check_sparse_nnz=True)
+
         i = self.IndexTensor([
             [0, 1, 2, 2],
             [0, 0, 0, 3],
@@ -290,6 +296,11 @@ class TestSparse(TestCase):
             self.assertEqual(res, x.to_dense())
             self.assertEqual(res, self.safeToDense(x))
 
+            def fn(x):
+                return x.to_dense()
+            x.requires_grad_(True)
+            gradcheck(fn, (x,), check_sparse_nnz=True)
+
         i = self.IndexTensor([
             [0, 1, 2, 2],
             [0, 0, 0, 3],
@@ -813,13 +824,10 @@ class TestSparse(TestCase):
             S_dense = S.to_dense().requires_grad_(True)
             S.requires_grad_(True)
             self.assertEqual(torch.sparse.addmm(D1, S, D2), torch.addmm(D1, S_dense, D2))
-            y1 = torch.sparse.addmm(D1, S, D2).sum()
-            y2 = torch.addmm(D1, S_dense, D2).sum()
-            y1.backward()
-            y2.backward()
-            mask = (S_dense == 0)
-            self.assertTrue(S.grad.is_coalesced())
-            self.assertEqual(S.grad.to_dense(), S_dense.grad.masked_fill_(mask, 0))
+
+            def fn(S, D1, D2):
+                return torch.sparse.addmm(D1, S, D2)
+            gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
 
         test_shape(7, 8, 9, 20)
 
@@ -831,13 +839,10 @@ class TestSparse(TestCase):
             S_dense = S.to_dense().requires_grad_(True)
            S.requires_grad_(True)
             self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))
-            y1 = torch.sparse.mm(S, D).sum()
-            y2 = torch.mm(S_dense, D).sum()
-            y1.backward()
-            y2.backward()
-            mask = (S_dense == 0)
-            self.assertTrue(S.grad.is_coalesced())
-            self.assertEqual(S.grad.to_dense(), S_dense.grad.masked_fill_(mask, 0))
+
+            def fn(S, D):
+                return torch.sparse.mm(S, D)
+            gradcheck(fn, (S, D), check_sparse_nnz=True)
 
         test_shape(7, 8, 9, 20)
 
@@ -963,21 +968,25 @@ class TestSparse(TestCase):
                 S_sum = torch.sparse.sum(S)
                 D_sum = D.sum()
                 self.assertEqual(S_sum, D_sum)
-                S_sum.backward()
-                D_sum.backward()
-                D_grad = D.grad.masked_fill_(mask, 0)
-                self.assertEqual(S.grad.to_dense(), D_grad)
+
+                def fn(S):
+                    res = torch.sparse.sum(S)
+                    if res.is_sparse:
+                        res = res.to_dense()
+                    return res
+                gradcheck(fn, (S,), check_sparse_nnz=True)
+
             else:
                 S_sum = torch.sparse.sum(S, td)
                 D_sum = D.sum(td)
                 self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum)
-                S_sum.backward(S_sum.detach())
-                S_grad = S.grad
-                data = S_sum.to_dense().detach() if S_sum.is_sparse else S_sum.detach()
-                D_sum.backward(data)
-                D_grad = D.grad.masked_fill_(mask, 0)
-                S_grad_dense = S_grad.coalesce().to_dense() if S_grad.is_sparse else S_grad
-                self.assertEqual(S_grad_dense, D_grad)
+
+                def fn(S):
+                    res = torch.sparse.sum(S, td)
+                    if res.is_sparse:
+                        res = res.to_dense()
+                    return res
+                gradcheck(fn, (S,), check_sparse_nnz=True)
 
         nnz = 10
         sparse_dims = 2
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ad83012..f907a98 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -811,6 +811,9 @@
 - name: trunc(Tensor self)
   self: zeros_like(grad)
 
+- name: to_dense(Tensor self)
+  self: to_dense_backward(grad, self)
+
 - name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step)
   self: unfold_backward(grad, self.sizes(), dimension, size, step)
 
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 3eb0b49..7998900 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -2055,6 +2055,12 @@ Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const T
   return flattened_dense_grad.index_select(0, flattened_indices);
 }
 
+Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) {
+  AT_ASSERT(input_.is_sparse());
+  auto input = input_.coalesce();
+  return grad.sparse_mask(at::SparseTensorRef(input));
+}
+
 // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
 Tensor constant_pad_nd_backward(const Tensor& grad, IntList pad) {
   auto negated_pad = pad.vec();
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index e6dedd3..dbb6cad 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -43,12 +43,14 @@ def iter_tensors(x, only_requiring_grad=False):
                 yield result
 
 
-# `input` is input to `fn`
-# `target` is the Tensors wrt whom Jacobians are calculated (default=`input`)
-#
-# Note that `target` may not even be part of `input` to `fn`, so please be
-# **very careful** in this to not clone `target`.
 def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
+    """
+    input: input to `fn`
+    target: the Tensors wrt whom Jacobians are calculated (default=`input`)
+
+    Note that `target` may not even be part of `input` to `fn`, so please be
+    **very careful** in this to not clone `target`.
+    """
     if target is None:
         target = input
     output_size = fn(input).numel()
@@ -64,22 +66,57 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     for x_tensor, d_tensor in zip(x_tensors, j_tensors):
         # need data here to get around the version check because without .data,
         # the following code updates version but doesn't change content
-        x_tensor = x_tensor.data
-        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
-            orig = x_tensor[x_idx].item()
-            x_tensor[x_idx] = orig - eps
-            outa = fn(input).clone()
-            x_tensor[x_idx] = orig + eps
-            outb = fn(input).clone()
-            x_tensor[x_idx] = orig
-
-            r = (outb - outa) / (2 * eps)
-            d_tensor[d_idx] = r.detach().reshape(-1)
+        if x_tensor.is_sparse:
+            def get_stride(size):
+                dim = len(size)
+                tmp = 1
+                stride = [0] * dim
+                for i in reversed(range(dim)):
+                    stride[i] = tmp
+                    tmp *= size[i]
+                return stride
+
+            x_nnz = x_tensor._nnz()
+            x_size = list(x_tensor.size())
+            x_indices = x_tensor._indices().t()
+            x_values = x_tensor._values().data
+            x_stride = get_stride(x_size)
+
+            for i in range(x_nnz):
+                x_value = x_values[i]
+                for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
+                    indices = x_indices[i].tolist() + list(x_idx)
+                    d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
+
+                    orig = x_value[x_idx].item()
+                    x_value[x_idx] = orig - eps
+                    outa = fn(input).clone()
+                    x_value[x_idx] = orig + eps
+                    outb = fn(input).clone()
+                    x_value[x_idx] = orig
+                    r = (outb - outa) / (2 * eps)
+                    d_tensor[d_idx] = r.detach().reshape(-1)
+        else:
+            x_tensor = x_tensor.data
+            for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+                orig = x_tensor[x_idx].item()
+                x_tensor[x_idx] = orig - eps
+                outa = fn(input).clone()
+                x_tensor[x_idx] = orig + eps
+                outb = fn(input).clone()
+                x_tensor[x_idx] = orig
+                r = (outb - outa) / (2 * eps)
+                d_tensor[d_idx] = r.detach().reshape(-1)
 
     return jacobian
 
 
 def get_analytical_jacobian(input, output):
+    # it is easier to call to_dense() on the sparse output than
+    # to modify analytical jacobian
+    if output.is_sparse:
+        raise ValueError('Sparse output is not supported at gradcheck yet. '
+                         'Please call to_dense() on the output of fn for gradcheck.')
     diff_input_list = list(iter_tensors(input, True))
     jacobian = make_jacobian(input, output.numel())
     jacobian_reentrant = make_jacobian(input, output.numel())
@@ -125,7 +162,7 @@ def _differentiable_outputs(x):
     return tuple(o for o in _as_tuple(x) if o.requires_grad)
 
 
-def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True):
+def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True, check_sparse_nnz=False):
     r"""Check gradients computed via small finite differences against analytical
     gradients w.r.t. tensors in :attr:`inputs` that are of floating point type
     and with ``requires_grad=True``.
@@ -154,11 +191,21 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
         raise_exception (bool, optional): indicating whether to raise an exception if
             the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
+        check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
+            and for any SparseTensor at input, gradcheck will perform check at nnz positions only.
 
     Returns:
         True if all differences satisfy allclose condition
     """
+    def fail_test(msg):
+        if raise_exception:
+            raise RuntimeError(msg)
+        return False
+
     tupled_inputs = _as_tuple(inputs)
+    if any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)) and not check_sparse_nnz:
+        fail_test('gradcheck expects all tensor inputs '
+                  'are dense when check_sparse_nnz is set to False.')
 
     # Make sure that gradients are saved for all inputs
     any_input_requiring_grad = False
@@ -180,11 +227,6 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
 
     output = _differentiable_outputs(func(*tupled_inputs))
 
-    def fail_test(msg):
-        if raise_exception:
-            raise RuntimeError(msg)
-        return False
-
     for i, o in enumerate(output):
         if not o.requires_grad:
             continue
@@ -220,6 +262,15 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
         for gi, i in zip(grads_input, diff_input_list):
             if gi is None:
                 continue
+            if isinstance(gi, torch.Tensor) and gi.is_sparse:
+                if gi.layout != i.layout:
+                    return fail_test('grad is sparse tensor, but has incorrect layout')
+                if gi.sparse_dim() != i.sparse_dim():
+                    return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
+                if gi.dense_dim() != i.dense_dim():
+                    return fail_test('grad is sparse tensor, but has incorrect dense_dim')
+                gi = gi.to_dense()
+                i = i.to_dense()
             if not gi.eq(0).all():
                 return fail_test('backward not multiplied by grad_output')
             if gi.type() != i.type():
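
For reference, a minimal usage sketch of the new `check_sparse_nnz` flag and the `to_dense()` backward added above. It assumes the gradcheck API exactly as modified in this diff; the tensors and `fn` below are illustrative only and are not part of the patch.

    import torch
    from torch.autograd import gradcheck

    # 1) gradcheck with a sparse input: the checked function must return a
    #    dense result, since sparse outputs are still rejected by gradcheck.
    def fn(sparse):
        return torch.sparse.sum(sparse)

    # double precision keeps the finite-difference check numerically stable
    x = torch.rand(10, dtype=torch.float64).to_sparse().requires_grad_(True)

    # with check_sparse_nnz=True the numerical Jacobian is perturbed only at
    # the nnz positions of the sparse input
    assert gradcheck(fn, (x,), check_sparse_nnz=True)
    # with the default check_sparse_nnz=False the same call fails with
    # "gradcheck expects all tensor inputs are dense ..."

    # 2) the new to_dense() backward: the gradient reaching a sparse leaf is
    #    the incoming dense gradient masked to the leaf's nnz pattern.
    y = torch.tensor([[0., 2.], [3., 0.]], dtype=torch.float64)
    s = y.to_sparse().requires_grad_(True)
    s.to_dense().sum().backward()
    print(s.grad.to_dense())  # ones at the two nnz positions, zeros elsewhere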