From: Andy Chen
Date: Wed, 5 Dec 2018 02:45:45 +0000 (-0800)
Subject: Make checkpoint_sequential work with multiple arguments (#14278)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2470
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=33ea7eafefb6a74d4d87b4e02f8d182640051ffc;p=platform%2Fupstream%2Fpytorch.git

Make checkpoint_sequential work with multiple arguments (#14278)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278

In this commit, we make checkpoint_sequential work for models with multiple
tensor inputs. Previously, it only processed the first tensor and ignored the
rest. We introduce a new test in test/test_utils.py that replicates the issue
referenced in this [GitHub issue](https://github.com/pytorch/pytorch/issues/11093),
and we make sure that the test passes by changing the behavior of
checkpoint_sequential to process all input tensors.

Reviewed By: ezyang

Differential Revision: D13144672

fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
---

diff --git a/test/test_utils.py b/test/test_utils.py
index 8cba636..dbaa484 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -48,6 +48,52 @@ class RandomDatasetMock(object):
 
 class TestCheckpoint(TestCase):
 
+    # This runs checkpoint_sequential on each of the nets in
+    # module_lists_to_compare, and compares them against the uncheckpointed model.
+    # To compare, it checks outputs as well as input gradients and parameter gradients
+    def _check_checkpoint_sequential(
+            self,
+            model,
+            module_lists_to_compare,
+            num_chunks,
+            *inputs
+    ):
+
+        # not checkpointed
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        out = model(*inputs)
+        out_not_checkpointed = out.data.clone()
+        model.zero_grad()
+        out.sum().backward()
+        grad_not_checkpointed = {
+            name: param.grad.data.clone()
+            for name, param in model.named_parameters()
+        }
+        input_grad_not_checkpointed = [i.grad.data.clone() for i in inputs]
+        for model_to_compare in module_lists_to_compare:
+            # checkpointed model by passing list of modules
+            detached_inputs = [i.detach() for i in inputs]
+            for detached in detached_inputs:
+                detached.requires_grad = True
+
+            # pass list of modules to checkpoint
+            out = checkpoint_sequential(model_to_compare, num_chunks, *detached_inputs)
+            out_checkpointed = out.data.clone()
+            model.zero_grad()
+            out.sum().backward()
+            grad_checkpointed = {
+                name: param.grad.data.clone()
+                for name, param in model.named_parameters()
+            }
+            input_grad_checkpointed = [d.grad.data.clone() for d in detached_inputs]
+            # compare outputs as well as the gradients of input and parameters
+            self.assertEqual(out_checkpointed, out_not_checkpointed)
+            for i, j in zip(input_grad_not_checkpointed, input_grad_checkpointed):
+                self.assertEqual(i, j)
+            for name in grad_checkpointed:
+                self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
+
     # Test whether checkpoint is being triggered or not. For this, we check
     # the number of times forward pass happens
     def test_checkpoint_trigger(self):
@@ -107,55 +153,52 @@ class TestCheckpoint(TestCase):
             nn.ReLU()
         )
 
-        x = torch.randn(1, 100, requires_grad=True)
-
-        # not checkpointed
-        out = model(x)
-        out_not_checkpointed = out.data.clone()
-        model.zero_grad()
-        out.sum().backward()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
-        input_grad = x.grad.data.clone()
+        # Compare uncheckpointed model with its checkpointed counterparts
+        # In addition to running checkpoint_sequential on the nn.Sequential
+        # instance, we also run the function on the list of functions within
+        # the module.
+        self._check_checkpoint_sequential(
+            model,
+            [list(model.children()), model],
+            2,
+            torch.randn(1, 100, requires_grad=True)
+        )
 
-        # checkpointed model by passing list of modules
-        chunks = 2
-        modules = list(model.children())
-        input_var = x.detach()
-        input_var.requires_grad = True
-        # pass list of modules to checkpoint
-        out = checkpoint_sequential(modules, chunks, input_var)
-        out_checkpointed = out.data.clone()
-        model.zero_grad()
-        out.sum().backward()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
-        checkpoint_input_grad = input_var.grad.data.clone()
-        # compare the output, input and parameters gradients
-        self.assertEqual(out_checkpointed, out_not_checkpointed)
-        self.assertEqual(input_grad, checkpoint_input_grad)
-        for name in grad_checkpointed:
-            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
-
-        # checkpointed by passing sequential directly
-        input_var1 = x.detach()
-        input_var1.requires_grad = True
-        # pass the sequential itself
-        out = checkpoint_sequential(model, 2, input_var1)
-        out_checkpointed = out.data.clone()
-        model.zero_grad()
-        out.sum().backward()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
-        checkpoint_input_grad = input_var1.grad.data.clone()
-        # compare the output, input and parameters gradients
-        self.assertEqual(out_checkpointed, out_not_checkpointed)
-        self.assertEqual(input_grad, checkpoint_input_grad)
-        for name in grad_checkpointed:
-            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
+    def test_checkpoint_module_list_multiple_args(self):
+        class ModuleListNet(nn.Module):
+            def __init__(self):
+                super(ModuleListNet, self).__init__()
+                module_list = [
+                    nn.Bilinear(100, 60, 50),
+                    nn.ReLU(),
+                    nn.Linear(50, 20),
+                    nn.ReLU(),
+                    nn.Linear(20, 5),
+                    nn.ReLU(),
+                ]
+                self.module_list = nn.ModuleList(module_list)
+
+            def forward(self, *inputs):
+                for layer in self.module_list:
+                    if isinstance(inputs, tuple):
+                        inputs = layer(*inputs)
+                    else:
+                        inputs = layer(inputs)
+                return inputs
+
+        model = ModuleListNet()
+
+        # Compare uncheckpointed model with its checkpointed counterparts
+        # In addition to running checkpoint_sequential on the nn.ModuleList
+        # instance, we also run the function on the list of functions within
+        # the ModuleList.
+        self._check_checkpoint_sequential(
+            model,
+            [list(model.module_list.children()), model.module_list],
+            2,
+            torch.randn(1, 100, requires_grad=True),
+            torch.randn(1, 60, requires_grad=True)
+        )
 
     def test_checkpoint_rng_cpu(self):
         for i in range(5):
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index d32f53b..557eabd 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -165,10 +165,12 @@ def checkpoint_sequential(functions, segments, *inputs):
 
     def run_function(start, end, functions):
         def forward(*inputs):
-            input = inputs[0]
             for j in range(start, end + 1):
-                input = functions[j](input)
-            return input
+                if isinstance(inputs, tuple):
+                    inputs = functions[j](*inputs)
+                else:
+                    inputs = functions[j](inputs)
+            return inputs
         return forward
 
     if isinstance(functions, torch.nn.Sequential):
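
A brief usage sketch (editorial note, not part of the patch), mirroring the new
test above: it shows the call pattern this change enables, assuming the patched
torch.utils.checkpoint is installed. The layer sizes, tensor names, and chunk
count below are illustrative only.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    # A small stack whose first layer consumes two tensors (nn.Bilinear),
    # like the ModuleListNet used in the new test; later layers take one.
    layers = [
        nn.Bilinear(100, 60, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
    ]

    x1 = torch.randn(1, 100, requires_grad=True)
    x2 = torch.randn(1, 60, requires_grad=True)

    # Before this patch only the first input tensor reached the first
    # checkpointed segment; with the patch every input tensor is forwarded.
    out = checkpoint_sequential(layers, 2, x1, x2)
    out.sum().backward()

The patched run_function keeps the per-segment contract simple: when the running
value is a tuple it is unpacked into the next function, otherwise it is passed as
a single argument, so single-input models behave as before.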