Don't keep unnecessary saved_inputs alive (#16583)
author    Adam Paszke <adam.paszke@gmail.com>
Mon, 11 Feb 2019 21:31:06 +0000 (13:31 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 11 Feb 2019 21:42:09 +0000 (13:42 -0800)
Summary:
Fixes #16577.

This greatly improves the memory efficiency of certain ops like Dropout2d. Previously, they were implemented as `input * mask`, where the mask never requires grad, but we did not use that knowledge in the forward pass and (in the case of an in-place dropout) kept an `input.clone()` for the backward even though it would simply be ignored.
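
As an illustration of the effect (a minimal sketch, assuming a build that includes this change; on earlier builds the in-place add below would trip the saved-variable version check in backward):

```python
import torch

mu = torch.ones(3, requires_grad=True)
w = torch.ones(3)            # mask-like operand, requires_grad == False

x = mu * 2                   # intermediate tensor that requires grad
y = (x * w).sum()            # d y / d x only needs w, so x is not saved
x.add_(1)                    # mutating x in place no longer breaks backward
y.backward()

assert torch.equal(mu.grad, 2 * w)
```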

This patch tries to address the situation by emitting guards around saves like this, but only when the guard is as simple as checking whether a single value requires grad.
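
For example, for `mul(self, other)` the emitted guard is conceptually equivalent to the Python sketch below. The names (`save_inputs_for_mul`, the `SimpleNamespace` stand-in for the grad_fn) are purely illustrative; the real check is the generated `grad_fn->should_compute_output(...)` call shown in the diff.

```python
import torch
from types import SimpleNamespace

def save_inputs_for_mul(self_t, other_t):
    # Illustrative only: each operand is captured for backward only when the
    # *other* operand's gradient can actually be requested.
    grad_fn = SimpleNamespace(self_=None, other_=None)
    if other_t.requires_grad:          # d(mul)/d(other) reads self
        grad_fn.self_ = self_t.detach()
    if self_t.requires_grad:           # d(mul)/d(self) reads other
        grad_fn.other_ = other_t.detach()
    return grad_fn

x = torch.randn(3, requires_grad=True)
m = torch.randn(3)                     # e.g. a dropout mask, no grad needed
fn = save_inputs_for_mul(x, m)
assert fn.self_ is None                # x is not kept alive
assert fn.other_ is not None           # m is still needed for d/dx
```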

Interestingly, the same optimization applies to methods like bmm, baddbmm, etc., but _not to mm or addmm_, because of how their derivatives are defined. Apparently they unnecessarily use `mat1` to compute the derivative of `mat1` just to improve the error message in case `mat1` is sparse. I'd like to apply this optimization to that case as well, but I don't want to lose the nicer error message, so if anyone has any ideas for solutions, please let me know...
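
To see why ops like bmm qualify: each operand's derivative only reads the other operand, so whichever side does not require grad never has to be saved. A quick check with stock PyTorch calls:

```python
import torch

a = torch.randn(2, 3, 4, requires_grad=True)
b = torch.randn(2, 4, 5, requires_grad=True)
grad = torch.randn(2, 3, 5)

a.bmm(b).backward(grad)

# d(bmm)/d(a) reads only b, and d(bmm)/d(b) reads only a.
assert torch.allclose(a.grad, grad.bmm(b.transpose(1, 2)))
assert torch.allclose(b.grad, a.transpose(1, 2).bmm(grad))
```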

Full list of operators affected by this patch:
* _nnpack_spatial_convolution
* addbmm
* addcdiv
* addcmul
* addmv
* addr
* baddbmm
* bmm
* cross
* div
* dot
* fmod
* ger
* index_add_
* mul
* mv
* scatter_add_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16583

Differential Revision: D13900881

Pulled By: gchanan

fbshipit-source-id: dd0aeb2ab58c4b6aa95b37b46d3255b3e014291c

test/test_autograd.py
tools/autograd/gen_variable_type.py

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 24646cc..6d44e78 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -817,6 +817,35 @@ class TestAutograd(TestCase):
         # Should not stack overflow
         scope()
 
+    @unittest.skipIf(not TEST_CUDA, "need CUDA memory stats")
+    def test_free_unneeded_tensor(self):
+        x = torch.randn(2, 3, 10, 10, device='cuda', requires_grad=True)
+        m = torch.randn(1, 3, 1, 1, device='cuda')
+
+        z = x.sum()
+        base_mem = torch.cuda.memory_allocated()
+        z = ((x + 2) * m).sum()
+        end_mem = torch.cuda.memory_allocated()
+
+        # Memory usage should end up the same: neither (x + 2) nor
+        # ((x + 2) * m) should be kept alive for backward, and the previous
+        # allocation of z has the same size as the current one.
+        self.assertEqual(base_mem, end_mem)
+
+    def test_no_unnecessary_save(self):
+        # If we kept x alive in the grad_fn of the multiplication below, the
+        # backward pass would complain that x, which was needed for gradient
+        # computation, has been modified in place. Since unnecessary saves
+        # should be elided, this test should pass.
+        mu = torch.ones(1, requires_grad=True)
+        x = torch.empty(1)
+        loss = 0
+        for i in range(3):
+            x.detach_()
+            x.copy_(mu + i)
+            loss += (x * torch.tensor([float(i)])).sum()
+        loss.backward()
+
     def test_no_grad(self):
         x = torch.ones(5, 5, requires_grad=True)
         y = Variable(torch.ones(5, 5) * 4)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 2453352..f5e5cc6 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -511,6 +511,65 @@ def emit_body(declaration):
     if func is not None and not requires_derivative:
         print('WARNING: derivative ignored for {}'.format(name), file=sys.stderr)
 
+    def emit_save_inputs():
+        setup = []
+        if func is None:
+            return setup
+
+        has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_gradients'])
+
+        # We don't want to save tensors if we know that they will never be used
+        # when computing the derivative, so we add guards to those statements
+        def guard_for(arg):
+            # It's hard to determine the edge offset if we have TensorLists
+            if has_tensorlist_arg:
+                return None
+
+            # Empirical evaluation of the cases where we insert these guards
+            # in backward functions shows that they are mostly useless. E.g.
+            # there is no need to guard some values captured from forward,
+            # because they had to require grad for the backward function to
+            # be executed at all. I don't have a good way of detecting those
+            # cases, so the checks are simply disabled for backward functions.
+            if 'backward' in func['name']:
+                return None
+
+            # If there's a single derivative we could compute, we already have
+            # a requires_grad check that is sufficient
+            if len(func['args_with_gradients']) <= 1:
+                return None
+
+            # We really only care about trimming down the number of tensors we save
+            if arg['type'] != 'Tensor':
+                return None
+
+            # We want to emit simple guards, so we only allow that if checking one
+            # input is enough to determine whether we need that value
+            used_in = [d for d in func['derivatives'] if arg in d['saved_inputs']]
+            assert len(used_in) > 0
+            if len(used_in) != 1:
+                return None
+            derivative = used_in[0]
+            if len(derivative['var_names']) != 1:
+                return None
+            derivative_var_name = derivative['var_names'][0]
+
+            # Figure out the offset of the edge that uses this variable
+            for edge_off, arg in enumerate(func['args_with_gradients']):
+                if arg['name'] == derivative_var_name:
+                    break
+            else:
+                assert False
+
+            return 'grad_fn->should_compute_output({})'.format(edge_off)
+
+        setup.extend(save_variables(func['saved_inputs'], False, guard_for))
+        for arg in func['args_with_gradients']:
+            if arg['type'] == 'TensorList':
+                setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))
+
+        return setup
+
     def setup_derivative():
         args_with_derivatives = find_args_with_derivatives()
 
@@ -533,11 +592,7 @@ def emit_body(declaration):
 
         setup = []
         setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n'))
-        if func is not None:
-            setup.extend(save_variables(func['saved_inputs'], False))
-            for arg in func['args_with_gradients']:
-                if arg['type'] == 'TensorList':
-                    setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))
+        setup.extend(emit_save_inputs())
 
         body = []
         body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
@@ -572,7 +627,7 @@ def emit_body(declaration):
             body.append('check_no_requires_grad({}, "{}");'.format(name, name))
         return body
 
-    def save_variables(saved_variables, is_output):
+    def save_variables(saved_variables, is_output, guard_for=lambda name: None):
         # assign the saved variables to the generated grad_fn
         stmts = []
         for arg in saved_variables:
@@ -592,7 +647,13 @@ def emit_body(declaration):
                 expr = 'make_saved_variable_list({})'.format(arg['name'])
             elif arg['type'] == 'IntArrayRef':
                 expr = expr + ".vec()"
-            stmts.append('grad_fn->{} = {};'.format(name, expr))
+            guard = guard_for(arg)
+            if guard is None:
+                stmts.append('grad_fn->{} = {};'.format(name, expr))
+            else:
+                stmts.append('if ({}) {{'.format(guard))
+                stmts.append('  grad_fn->{} = {};'.format(name, expr))
+                stmts.append('}')
         return stmts
 
     def reference_args(args):