gradcheck (#14596)
authorWei Yang <weiyang@fb.com>
Fri, 7 Dec 2018 01:58:16 +0000 (17:58 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 7 Dec 2018 02:03:38 +0000 (18:03 -0800)
Summary:
- allow gradcheck to take sparse tensor as input
- sparse output is not allowed yet at gradcheck
- add backward for `to_dense()` to get around sparse output
- calling gradcheck at test_sparse, so that we can use `_gen_sparse()` and also easily cover coalesced / uncoalesced test cases
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14596

Differential Revision: D13271904

Pulled By: weiyangfb

fbshipit-source-id: 5317484104404fd38058884c86e987546011dd86

test/test_autograd.py
test/test_sparse.py
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp
torch/autograd/gradcheck.py

index 68525c4..cde1fbf 100644 (file)
@@ -2704,6 +2704,14 @@ class TestAutograd(TestCase):
         gradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
         gradgradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
 
+    def test_gradcheck_sparse_input(self):
+        def fn(sparse):
+            return torch.sparse.sum(sparse)
+
+        gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=True)
+        with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
+            gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=False)
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
index 5011f5c..c48c299 100644 (file)
@@ -8,6 +8,7 @@ import unittest
 from common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, do_test_empty_full, load_tests
 from common_cuda import TEST_CUDA
 from numbers import Number
+from torch.autograd.gradcheck import gradcheck
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -214,6 +215,11 @@ class TestSparse(TestCase):
             self.assertEqual(res, x.to_dense())
             self.assertEqual(res, self.safeToDense(x))
 
+            def fn(x):
+                return x.to_dense()
+            x.requires_grad_(True)
+            gradcheck(fn, (x,), check_sparse_nnz=True)
+
         i = self.IndexTensor([
             [0, 1, 2, 2],
             [0, 0, 0, 3],
@@ -290,6 +296,11 @@ class TestSparse(TestCase):
             self.assertEqual(res, x.to_dense())
             self.assertEqual(res, self.safeToDense(x))
 
+            def fn(x):
+                return x.to_dense()
+            x.requires_grad_(True)
+            gradcheck(fn, (x,), check_sparse_nnz=True)
+
         i = self.IndexTensor([
             [0, 1, 2, 2],
             [0, 0, 0, 3],
@@ -813,13 +824,10 @@ class TestSparse(TestCase):
             S_dense = S.to_dense().requires_grad_(True)
             S.requires_grad_(True)
             self.assertEqual(torch.sparse.addmm(D1, S, D2), torch.addmm(D1, S_dense, D2))
-            y1 = torch.sparse.addmm(D1, S, D2).sum()
-            y2 = torch.addmm(D1, S_dense, D2).sum()
-            y1.backward()
-            y2.backward()
-            mask = (S_dense == 0)
-            self.assertTrue(S.grad.is_coalesced())
-            self.assertEqual(S.grad.to_dense(), S_dense.grad.masked_fill_(mask, 0))
+
+            def fn(S, D1, D2):
+                return torch.sparse.addmm(D1, S, D2)
+            gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
 
         test_shape(7, 8, 9, 20)
 
@@ -831,13 +839,10 @@ class TestSparse(TestCase):
             S_dense = S.to_dense().requires_grad_(True)
             S.requires_grad_(True)
             self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))
-            y1 = torch.sparse.mm(S, D).sum()
-            y2 = torch.mm(S_dense, D).sum()
-            y1.backward()
-            y2.backward()
-            mask = (S_dense == 0)
-            self.assertTrue(S.grad.is_coalesced())
-            self.assertEqual(S.grad.to_dense(), S_dense.grad.masked_fill_(mask, 0))
+
+            def fn(S, D):
+                return torch.sparse.mm(S, D)
+            gradcheck(fn, (S, D), check_sparse_nnz=True)
 
         test_shape(7, 8, 9, 20)
 
@@ -963,21 +968,25 @@ class TestSparse(TestCase):
                 S_sum = torch.sparse.sum(S)
                 D_sum = D.sum()
                 self.assertEqual(S_sum, D_sum)
-                S_sum.backward()
-                D_sum.backward()
-                D_grad = D.grad.masked_fill_(mask, 0)
-                self.assertEqual(S.grad.to_dense(), D_grad)
+
+                def fn(S):
+                    res = torch.sparse.sum(S)
+                    if res.is_sparse:
+                        res = res.to_dense()
+                    return res
+                gradcheck(fn, (S,), check_sparse_nnz=True)
+
             else:
                 S_sum = torch.sparse.sum(S, td)
                 D_sum = D.sum(td)
                 self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum)
-                S_sum.backward(S_sum.detach())
-                S_grad = S.grad
-                data = S_sum.to_dense().detach() if S_sum.is_sparse else S_sum.detach()
-                D_sum.backward(data)
-                D_grad = D.grad.masked_fill_(mask, 0)
-                S_grad_dense = S_grad.coalesce().to_dense() if S_grad.is_sparse else S_grad
-                self.assertEqual(S_grad_dense, D_grad)
+
+                def fn(S):
+                    res = torch.sparse.sum(S, td)
+                    if res.is_sparse:
+                        res = res.to_dense()
+                    return res
+                gradcheck(fn, (S,), check_sparse_nnz=True)
 
         nnz = 10
         sparse_dims = 2
index ad83012..f907a98 100644 (file)
 - name: trunc(Tensor self)
   self: zeros_like(grad)
 
+- name: to_dense(Tensor self)
+  self: to_dense_backward(grad, self)
+
 - name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step)
   self: unfold_backward(grad, self.sizes(), dimension, size, step)
 
index 3eb0b49..7998900 100644 (file)
@@ -2055,6 +2055,12 @@ Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const T
   return flattened_dense_grad.index_select(0, flattened_indices);
 }
 
+Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) {
+  AT_ASSERT(input_.is_sparse());
+  auto input = input_.coalesce();
+  return grad.sparse_mask(at::SparseTensorRef(input));
+}
+
 // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
 Tensor constant_pad_nd_backward(const Tensor& grad, IntList pad) {
   auto negated_pad = pad.vec();
index e6dedd3..dbb6cad 100644 (file)
@@ -43,12 +43,14 @@ def iter_tensors(x, only_requiring_grad=False):
                 yield result
 
 
-# `input` is input to `fn`
-# `target` is the Tensors wrt whom Jacobians are calculated (default=`input`)
-#
-# Note that `target` may not even be part of `input` to `fn`, so please be
-# **very careful** in this to not clone `target`.
 def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
+    """
+    input: input to `fn`
+    target: the Tensors wrt whom Jacobians are calculated (default=`input`)
+
+    Note that `target` may not even be part of `input` to `fn`, so please be
+    **very careful** in this to not clone `target`.
+    """
     if target is None:
         target = input
     output_size = fn(input).numel()
@@ -64,22 +66,57 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     for x_tensor, d_tensor in zip(x_tensors, j_tensors):
         # need data here to get around the version check because without .data,
         # the following code updates version but doesn't change content
-        x_tensor = x_tensor.data
-        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
-            orig = x_tensor[x_idx].item()
-            x_tensor[x_idx] = orig - eps
-            outa = fn(input).clone()
-            x_tensor[x_idx] = orig + eps
-            outb = fn(input).clone()
-            x_tensor[x_idx] = orig
-
-            r = (outb - outa) / (2 * eps)
-            d_tensor[d_idx] = r.detach().reshape(-1)
+        if x_tensor.is_sparse:
+            def get_stride(size):
+                dim = len(size)
+                tmp = 1
+                stride = [0] * dim
+                for i in reversed(range(dim)):
+                    stride[i] = tmp
+                    tmp *= size[i]
+                return stride
+
+            x_nnz = x_tensor._nnz()
+            x_size = list(x_tensor.size())
+            x_indices = x_tensor._indices().t()
+            x_values = x_tensor._values().data
+            x_stride = get_stride(x_size)
+
+            for i in range(x_nnz):
+                x_value = x_values[i]
+                for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
+                    indices = x_indices[i].tolist() + list(x_idx)
+                    d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
+
+                    orig = x_value[x_idx].item()
+                    x_value[x_idx] = orig - eps
+                    outa = fn(input).clone()
+                    x_value[x_idx] = orig + eps
+                    outb = fn(input).clone()
+                    x_value[x_idx] = orig
+                    r = (outb - outa) / (2 * eps)
+                    d_tensor[d_idx] = r.detach().reshape(-1)
+        else:
+            x_tensor = x_tensor.data
+            for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+                orig = x_tensor[x_idx].item()
+                x_tensor[x_idx] = orig - eps
+                outa = fn(input).clone()
+                x_tensor[x_idx] = orig + eps
+                outb = fn(input).clone()
+                x_tensor[x_idx] = orig
+                r = (outb - outa) / (2 * eps)
+                d_tensor[d_idx] = r.detach().reshape(-1)
 
     return jacobian
 
 
 def get_analytical_jacobian(input, output):
+    # it is easier to call to_dense() on the sparse output than
+    # to modify analytical jacobian
+    if output.is_sparse:
+        raise ValueError('Sparse output is not supported at gradcheck yet. '
+                         'Please call to_dense() on the output of fn for gradcheck.')
     diff_input_list = list(iter_tensors(input, True))
     jacobian = make_jacobian(input, output.numel())
     jacobian_reentrant = make_jacobian(input, output.numel())
@@ -125,7 +162,7 @@ def _differentiable_outputs(x):
     return tuple(o for o in _as_tuple(x) if o.requires_grad)
 
 
-def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True):
+def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True, check_sparse_nnz=False):
     r"""Check gradients computed via small finite differences against analytical
     gradients w.r.t. tensors in :attr:`inputs` that are of floating point type
     and with ``requires_grad=True``.
@@ -154,11 +191,21 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
         raise_exception (bool, optional): indicating whether to raise an exception if
             the check fails. The exception gives more information about the
             exact nature of the failure. This is helpful when debugging gradchecks.
+        check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
+            and for any SparseTensor at input, gradcheck will perform check at nnz positions only.
 
     Returns:
         True if all differences satisfy allclose condition
     """
+    def fail_test(msg):
+        if raise_exception:
+            raise RuntimeError(msg)
+        return False
+
     tupled_inputs = _as_tuple(inputs)
+    if any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)) and not check_sparse_nnz:
+        fail_test('gradcheck expects all tensor inputs '
+                  'are dense when check_sparse_nnz is set to False.')
 
     # Make sure that gradients are saved for all inputs
     any_input_requiring_grad = False
@@ -180,11 +227,6 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
 
     output = _differentiable_outputs(func(*tupled_inputs))
 
-    def fail_test(msg):
-        if raise_exception:
-            raise RuntimeError(msg)
-        return False
-
     for i, o in enumerate(output):
         if not o.requires_grad:
             continue
@@ -220,6 +262,15 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
         for gi, i in zip(grads_input, diff_input_list):
             if gi is None:
                 continue
+            if isinstance(gi, torch.Tensor) and gi.is_sparse:
+                if gi.layout != i.layout:
+                    return fail_test('grad is sparse tensor, but has incorrect layout')
+                if gi.sparse_dim() != i.sparse_dim():
+                    return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
+                if gi.dense_dim() != i.dense_dim():
+                    return fail_test('grad is sparse tensor, but has incorrect dense_dim')
+                gi = gi.to_dense()
+                i = i.to_dense()
             if not gi.eq(0).all():
                 return fail_test('backward not multiplied by grad_output')
             if gi.type() != i.type():