Make checkpoint_sequential work with multiple arguments (#14278)
author     Andy Chen <andersenchen@fb.com>
Wed, 5 Dec 2018 02:45:45 +0000 (18:45 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 02:47:43 +0000 (18:47 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278

In this commit, we make checkpoint_sequential work for models that take multiple tensor inputs. Previously, it processed only the first input tensor and ignored the rest.

We introduce a new test in test/test_utils.py that reproduces the failure reported in [GitHub issue #11093](https://github.com/pytorch/pytorch/issues/11093), and we make it pass by changing checkpoint_sequential to handle all input tensors.
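For illustration, a minimal sketch of the call pattern this change enables; the layer sizes simply mirror the new test, where the first layer (nn.Bilinear) consumes two tensors and the remaining layers take one:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# First layer takes two tensors; the remaining layers take a single tensor.
layers = nn.ModuleList([
    nn.Bilinear(100, 60, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    nn.ReLU(),
])

a = torch.randn(1, 100, requires_grad=True)
b = torch.randn(1, 60, requires_grad=True)

# Previously only the first tensor reached the first segment; now both do.
out = checkpoint_sequential(list(layers.children()), 2, a, b)
out.sum().backward()
```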

Reviewed By: ezyang

Differential Revision: D13144672

fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7

test/test_utils.py
torch/utils/checkpoint.py

diff --git a/test/test_utils.py b/test/test_utils.py
index 8cba636..dbaa484 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -48,6 +48,52 @@ class RandomDatasetMock(object):
 
 class TestCheckpoint(TestCase):
 
+    # This runs checkpoint_sequential on each of the nets in
+    # module_lists_to_compare, and compares them against the uncheckpointed model.
+    # To compare, it checks outputs as well as input gradients and parameter gradients
+    def _check_checkpoint_sequential(
+        self,
+        model,
+        module_lists_to_compare,
+        num_chunks,
+        *inputs
+    ):
+
+        # not checkpointed
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        out = model(*inputs)
+        out_not_checkpointed = out.data.clone()
+        model.zero_grad()
+        out.sum().backward()
+        grad_not_checkpointed = {
+            name: param.grad.data.clone()
+            for name, param in model.named_parameters()
+        }
+        input_grad_not_checkpointed = [i.grad.data.clone() for i in inputs]
+        for model_to_compare in module_lists_to_compare:
+            # checkpointed model by passing list of modules
+            detached_inputs = [i.detach() for i in inputs]
+            for detached in detached_inputs:
+                detached.requires_grad = True
+
+            # pass list of modules to checkpoint
+            out = checkpoint_sequential(model_to_compare, num_chunks, *detached_inputs)
+            out_checkpointed = out.data.clone()
+            model.zero_grad()
+            out.sum().backward()
+            grad_checkpointed = {
+                name: param.grad.data.clone()
+                for name, param in model.named_parameters()
+            }
+            input_grad_checkpointed = [d.grad.data.clone() for d in detached_inputs]
+            # compare outputs as well as the gradients of input and parameters
+            self.assertEqual(out_checkpointed, out_not_checkpointed)
+            for i, j in zip(input_grad_not_checkpointed, input_grad_checkpointed):
+                self.assertEqual(i, j)
+            for name in grad_checkpointed:
+                self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
+
     # Test whether checkpoint is being triggered or not. For this, we check
     # the number of times forward pass happens
     def test_checkpoint_trigger(self):
@@ -107,55 +153,52 @@ class TestCheckpoint(TestCase):
             nn.ReLU()
         )
 
-        x = torch.randn(1, 100, requires_grad=True)
-
-        # not checkpointed
-        out = model(x)
-        out_not_checkpointed = out.data.clone()
-        model.zero_grad()
-        out.sum().backward()
-        grad_not_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_not_checkpointed[name] = param.grad.data.clone()
-        input_grad = x.grad.data.clone()
+        # Compare uncheckpointed model with its checkpointed counterparts
+        # In addition to running checkpoint_sequential on the nn.Sequential
+        # instance, we also run the function on the list of functions within
+        # the module.
+        self._check_checkpoint_sequential(
+            model,
+            [list(model.children()), model],
+            2,
+            torch.randn(1, 100, requires_grad=True)
+        )
 
-        # checkpointed model by passing list of modules
-        chunks = 2
-        modules = list(model.children())
-        input_var = x.detach()
-        input_var.requires_grad = True
-        # pass list of modules to checkpoint
-        out = checkpoint_sequential(modules, chunks, input_var)
-        out_checkpointed = out.data.clone()
-        model.zero_grad()
-        out.sum().backward()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
-        checkpoint_input_grad = input_var.grad.data.clone()
-        # compare the output, input and parameters gradients
-        self.assertEqual(out_checkpointed, out_not_checkpointed)
-        self.assertEqual(input_grad, checkpoint_input_grad)
-        for name in grad_checkpointed:
-            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
-
-        # checkpointed by passing sequential directly
-        input_var1 = x.detach()
-        input_var1.requires_grad = True
-        # pass the sequential itself
-        out = checkpoint_sequential(model, 2, input_var1)
-        out_checkpointed = out.data.clone()
-        model.zero_grad()
-        out.sum().backward()
-        grad_checkpointed = {}
-        for name, param in model.named_parameters():
-            grad_checkpointed[name] = param.grad.data.clone()
-        checkpoint_input_grad = input_var1.grad.data.clone()
-        # compare the output, input and parameters gradients
-        self.assertEqual(out_checkpointed, out_not_checkpointed)
-        self.assertEqual(input_grad, checkpoint_input_grad)
-        for name in grad_checkpointed:
-            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
+    def test_checkpoint_module_list_multiple_args(self):
+        class ModuleListNet(nn.Module):
+            def __init__(self):
+                super(ModuleListNet, self).__init__()
+                module_list = [
+                    nn.Bilinear(100, 60, 50),
+                    nn.ReLU(),
+                    nn.Linear(50, 20),
+                    nn.ReLU(),
+                    nn.Linear(20, 5),
+                    nn.ReLU(),
+                ]
+                self.module_list = nn.ModuleList(module_list)
+
+            def forward(self, *inputs):
+                for layer in self.module_list:
+                    if isinstance(inputs, tuple):
+                        inputs = layer(*inputs)
+                    else:
+                        inputs = layer(inputs)
+                return inputs
+
+        model = ModuleListNet()
+
+        # Compare uncheckpointed model with its checkpointed counterparts
+        # In addition to running checkpoint_sequential on the nn.ModuleList
+        # instance, we also run the function on the list of functions within
+        # the ModuleList.
+        self._check_checkpoint_sequential(
+            model,
+            [list(model.module_list.children()), model.module_list],
+            2,
+            torch.randn(1, 100, requires_grad=True),
+            torch.randn(1, 60, requires_grad=True)
+        )
 
     def test_checkpoint_rng_cpu(self):
         for i in range(5):
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index d32f53b..557eabd 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -165,10 +165,12 @@ def checkpoint_sequential(functions, segments, *inputs):
 
     def run_function(start, end, functions):
         def forward(*inputs):
-            input = inputs[0]
             for j in range(start, end + 1):
-                input = functions[j](input)
-            return input
+                if isinstance(inputs, tuple):
+                    inputs = functions[j](*inputs)
+                else:
+                    inputs = functions[j](inputs)
+            return inputs
         return forward
 
     if isinstance(functions, torch.nn.Sequential):
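
For context, the rule the patched run_function applies within each segment, restated as a standalone sketch (illustration only, not the library code): while the running value is still a tuple it is unpacked into the next function, and once a layer returns a single tensor it is passed through directly.

```python
# Standalone sketch of the per-segment rule from the patch above
# (illustration only, not the library implementation).
def apply_segment(functions, start, end, *inputs):
    out = inputs  # starts as the tuple of segment inputs
    for j in range(start, end + 1):
        # Unpack while we still hold a tuple (e.g. two tensors into
        # nn.Bilinear); a single tensor is passed through as-is.
        out = functions[j](*out) if isinstance(out, tuple) else functions[j](out)
    return out
```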