class TestCheckpoint(TestCase):
+ # This runs checkpoint_sequential on each of the nets in
+ # module_lists_to_compare, and compares them against the uncheckpointed model.
+ # To compare, it checks outputs as well as input gradients and parameter gradients
+ def _check_checkpoint_sequential(
+ self,
+ model,
+ module_lists_to_compare,
+ num_chunks,
+ *inputs
+ ):
+
+ # not checkpointed
+ if not isinstance(inputs, tuple):
+ inputs = (inputs,)
+ out = model(*inputs)
+ out_not_checkpointed = out.data.clone()
+ model.zero_grad()
+ out.sum().backward()
+ grad_not_checkpointed = {
+ name: param.grad.data.clone()
+ for name, param in model.named_parameters()
+ }
+ input_grad_not_checkpointed = [i.grad.data.clone() for i in inputs]
+ for model_to_compare in module_lists_to_compare:
+ # checkpointed model by passing list of modules
+ detached_inputs = [i.detach() for i in inputs]
+ for detached in detached_inputs:
+ detached.requires_grad = True
+
+ # pass list of modules to checkpoint
+ out = checkpoint_sequential(model_to_compare, num_chunks, *detached_inputs)
+ out_checkpointed = out.data.clone()
+ model.zero_grad()
+ out.sum().backward()
+ grad_checkpointed = {
+ name: param.grad.data.clone()
+ for name, param in model.named_parameters()
+ }
+ input_grad_checkpointed = [d.grad.data.clone() for d in detached_inputs]
+ # compare outputs as well as the gradients of input and parameters
+ self.assertEqual(out_checkpointed, out_not_checkpointed)
+ for i, j in zip(input_grad_not_checkpointed, input_grad_checkpointed):
+ self.assertEqual(i, j)
+ for name in grad_checkpointed:
+ self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
+
# Test whether checkpoint is being triggered or not. For this, we check
# the number of times forward pass happens
def test_checkpoint_trigger(self):
nn.ReLU()
)
- x = torch.randn(1, 100, requires_grad=True)
-
- # not checkpointed
- out = model(x)
- out_not_checkpointed = out.data.clone()
- model.zero_grad()
- out.sum().backward()
- grad_not_checkpointed = {}
- for name, param in model.named_parameters():
- grad_not_checkpointed[name] = param.grad.data.clone()
- input_grad = x.grad.data.clone()
+ # Compare uncheckpointed model with its checkpointed counterparts
+ # In addition to running checkpoint_sequential on the nn.Sequential
+ # instance, we also run the function on the list of functions within
+ # the module.
+ self._check_checkpoint_sequential(
+ model,
+ [list(model.children()), model],
+ 2,
+ torch.randn(1, 100, requires_grad=True)
+ )
- # checkpointed model by passing list of modules
- chunks = 2
- modules = list(model.children())
- input_var = x.detach()
- input_var.requires_grad = True
- # pass list of modules to checkpoint
- out = checkpoint_sequential(modules, chunks, input_var)
- out_checkpointed = out.data.clone()
- model.zero_grad()
- out.sum().backward()
- grad_checkpointed = {}
- for name, param in model.named_parameters():
- grad_checkpointed[name] = param.grad.data.clone()
- checkpoint_input_grad = input_var.grad.data.clone()
- # compare the output, input and parameters gradients
- self.assertEqual(out_checkpointed, out_not_checkpointed)
- self.assertEqual(input_grad, checkpoint_input_grad)
- for name in grad_checkpointed:
- self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
-
- # checkpointed by passing sequential directly
- input_var1 = x.detach()
- input_var1.requires_grad = True
- # pass the sequential itself
- out = checkpoint_sequential(model, 2, input_var1)
- out_checkpointed = out.data.clone()
- model.zero_grad()
- out.sum().backward()
- grad_checkpointed = {}
- for name, param in model.named_parameters():
- grad_checkpointed[name] = param.grad.data.clone()
- checkpoint_input_grad = input_var1.grad.data.clone()
- # compare the output, input and parameters gradients
- self.assertEqual(out_checkpointed, out_not_checkpointed)
- self.assertEqual(input_grad, checkpoint_input_grad)
- for name in grad_checkpointed:
- self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
+ def test_checkpoint_module_list_multiple_args(self):
+ class ModuleListNet(nn.Module):
+ def __init__(self):
+ super(ModuleListNet, self).__init__()
+ module_list = [
+ nn.Bilinear(100, 60, 50),
+ nn.ReLU(),
+ nn.Linear(50, 20),
+ nn.ReLU(),
+ nn.Linear(20, 5),
+ nn.ReLU(),
+ ]
+ self.module_list = nn.ModuleList(module_list)
+
+ def forward(self, *inputs):
+ for layer in self.module_list:
+ if isinstance(inputs, tuple):
+ inputs = layer(*inputs)
+ else:
+ inputs = layer(inputs)
+ return inputs
+
+ model = ModuleListNet()
+
+ # Compare uncheckpointed model with its checkpointed counterparts
+ # In addition to running checkpoint_sequential on the nn.ModuleList
+ # instance, we also run the function on the list of functions within
+ # the ModuleList.
+ self._check_checkpoint_sequential(
+ model,
+ [list(model.module_list.children()), model.module_list],
+ 2,
+ torch.randn(1, 100, requires_grad=True),
+ torch.randn(1, 60, requires_grad=True)
+ )
def test_checkpoint_rng_cpu(self):
for i in range(5):