# Should not stack overflow
scope()
+ @unittest.skipIf(not TEST_CUDA, "need CUDA memory stats")
+ def test_free_unneeded_tensor(self):
+ x = torch.randn(2, 3, 10, 10, device='cuda', requires_grad=True)
+ m = torch.randn(1, 3, 1, 1, device='cuda')
+
+ z = ((x + 2) * m).sum()
+ base_mem = torch.cuda.memory_allocated()
+ z = ((x + 2) * m).sum()
+ end_mem = torch.cuda.memory_allocated()
+
+ # In the end the memory usage should remain equal, because neither (x + 2)
+ # nor ((x + 2) * m) should be kept alive for backward, and the previous
+ # allocation of z has the same size as the current one.
+ self.assertEqual(base_mem, end_mem)
+
+ def test_no_unnecessary_save(self):
+ # If the grad_fn of the multiplication below kept a reference to x, the
+ # backward pass would raise an error complaining that x, which is needed
+ # for gradient computation, has been modified in place.
+ # Since unnecessary saves should be elided, this test should pass.
+ mu = torch.ones(1, requires_grad=True)
+ x = torch.empty(1)
+ loss = 0
+ for i in range(3):
+ x.detach_()
+ x.copy_(mu + i)
+ loss += (x * torch.tensor([float(i)])).sum()
+ loss.backward()
+
def test_no_grad(self):
x = torch.ones(5, 5, requires_grad=True)
y = Variable(torch.ones(5, 5) * 4)
if func is not None and not requires_derivative:
print('WARNING: derivative ignored for {}'.format(name), file=sys.stderr)
+ def emit_save_inputs():
+ setup = []
+ if func is None:
+ return setup
+
+ has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_gradients'])
+
+ # We don't want to save tensors if we know that they will never be used
+ # when computing the derivative, so we add guards to those statements
+ def guard_for(arg):
+ # It's hard to determine the edge offset if we have TensorLists
+ if has_tensorlist_arg:
+ return None
+
+ # Empirical evaluation of the cases where we insert these guards in
+ # backward functions shows that they are somewhat useless. E.g. there's no
+ # need to guard on some values captured from forward, because they had to
+ # require grad if the backward function even gets executed. I don't have
+ # any good ideas for detecting those cases, so I simply disabled the
+ # checks.
+ if 'backward' in func['name']:
+ return None
+
+ # If there's a single derivative we could compute, we already have
+ # a requires_grad check that is sufficient
+ if len(func['args_with_gradients']) <= 1:
+ return None
+
+ # We really only care about trimming down the number of tensors we save
+ if arg['type'] != 'Tensor':
+ return None
+
+ # We want to emit simple guards, so we only add one when checking a single
+ # input is enough to determine whether we need that value
+ used_in = [d for d in func['derivatives'] if arg in d['saved_inputs']]
+ assert len(used_in) > 0
+ if len(used_in) != 1:
+ return None
+ derivative = used_in[0]
+ if len(derivative['var_names']) != 1:
+ return None
+ derivative_var_name = derivative['var_names'][0]
+
+ # Figure out the offset of the edge that uses this variable
+ for edge_off, gradient_arg in enumerate(func['args_with_gradients']):
+ if gradient_arg['name'] == derivative_var_name:
+ break
+ else:
+ assert False
+
+ return 'grad_fn->should_compute_output({})'.format(edge_off)
+
+ setup.extend(save_variables(func['saved_inputs'], False, guard_for))
+ for arg in func['args_with_gradients']:
+ if arg['type'] == 'TensorList':
+ setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))
+
+ return setup
+
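A standalone sketch (not part of the patch) of how guard_for maps the single derivative that uses a saved input back to an output edge offset; the 'self'/'other' entries below are hypothetical stand-ins for func['args_with_gradients']:

    # Hypothetical two-input op whose derivative for the second input is the
    # only one that needs the saved value.
    args_with_gradients = [{'name': 'self'}, {'name': 'other'}]
    derivative_var_name = 'other'
    for edge_off, a in enumerate(args_with_gradients):
        if a['name'] == derivative_var_name:
            break
    # edge_off == 1, so the emitted guard reads:
    #   grad_fn->should_compute_output(1)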
def setup_derivative():
args_with_derivatives = find_args_with_derivatives()
setup = []
setup.extend(ASSIGN_GRAD_FN.substitute(env).split('\n'))
- if func is not None:
- setup.extend(save_variables(func['saved_inputs'], False))
- for arg in func['args_with_gradients']:
- if arg['type'] == 'TensorList':
- setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))
+ setup.extend(emit_save_inputs())
body = []
body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
body.append('check_no_requires_grad({}, "{}");'.format(name, name))
return body
- def save_variables(saved_variables, is_output):
+ def save_variables(saved_variables, is_output, guard_for=lambda arg: None):
# assign the saved variables to the generated grad_fn
stmts = []
for arg in saved_variables:
expr = 'make_saved_variable_list({})'.format(arg['name'])
elif arg['type'] == 'IntArrayRef':
expr = expr + ".vec()"
- stmts.append('grad_fn->{} = {};'.format(name, expr))
+ guard = guard_for(arg)
+ if guard is None:
+ stmts.append('grad_fn->{} = {};'.format(name, expr))
+ else:
+ stmts.append('if ({}) {{'.format(guard))
+ stmts.append(' grad_fn->{} = {};'.format(name, expr))
+ stmts.append('}')
return stmts
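A minimal standalone sketch (not part of the patch) of how the new guard_for hook wraps a save statement; the 'other' name, the SavedVariable expression, and the should_compute_output(1) guard are illustrative assumptions:

    def sketch_save_variables(saved_variables, guard_for=lambda arg: None):
        stmts = []
        for arg in saved_variables:
            # Illustrative save expression; the real generator picks the
            # expression based on the argument's type.
            stmt = 'grad_fn->{}_ = SavedVariable({}, false);'.format(arg['name'], arg['name'])
            guard = guard_for(arg)
            if guard is None:
                stmts.append(stmt)
            else:
                stmts.append('if ({}) {{\n  {}\n}}'.format(guard, stmt))
        return stmts

    # sketch_save_variables([{'name': 'other', 'type': 'Tensor'}],
    #                       lambda arg: 'grad_fn->should_compute_output(1)')
    # -> ['if (grad_fn->should_compute_output(1)) {\n'
    #     '  grad_fn->other_ = SavedVariable(other, false);\n'
    #     '}']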
def reference_args(args):