From a5c4348d54cc7736e5574ac56d890695197339ed Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Thu, 18 Apr 2019 14:51:37 -0700 Subject: [PATCH] Recursively find tensors in DDP module output (#19360) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19360 We'll return the output object verbatim since it is a freeform object. We need to find any tensors in this object, though, because we need to figure out which parameters were used during this forward pass, to ensure we short circuit reduction for any unused parameters. Before this commit only lists were handled and the functionality went untested. This commit adds support for dicts and recursive structures, and also adds a test case. Closes #19354. Reviewed By: mrshenli Differential Revision: D14978016 fbshipit-source-id: 4bb6999520871fb6a9e4561608afa64d55f4f3a8 --- test/test_c10d.py | 90 +++++++++++++++++++++++++++++++++ torch/csrc/distributed/c10d/reducer.cpp | 9 ++++ torch/nn/parallel/distributed.py | 23 ++++++--- 3 files changed, 114 insertions(+), 8 deletions(-) diff --git a/test/test_c10d.py b/test/test_c10d.py index 3f9dcf6..451498a 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -1979,6 +1979,96 @@ class DistributedDataParallelTest(MultiProcessTestCase): # The expected result of the allreduce should be the average self.assertEqual(grads_batch[0], (torch.ones(10) * (self.world_size + 1) * len(devices) / 2.0).chunk(5)) + @skip_if_not_nccl + @skip_if_not_multigpu + def test_arbitrary_forward_return_value(self): + """ + Note: this test can be sped up by only running it on a CPU module + once DistributedDataParallel supports them. + """ + store = c10d.FileStore(self.file.name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + class ForwardReturnValueModule(nn.Module): + def __init__(self): + super(ForwardReturnValueModule, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 4, bias=False) + self.fc3 = nn.Linear(4, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x, fn): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + # The first softmax does NOT include fc3 in its autograd graph + # whereas the second softmax DOES. If we pass only the first + # tensor we see in the output to the reducer, it marks the + # gradient for fc3 as ready (because it doesn't show up). If + # downstream uses of this return value choose to differentiate + # against the second output tensor, it would still receive a + # gradient and a callback for this tensor, resulting in a crash. + return fn( + F.softmax(x, dim=1), + F.softmax(self.fc3(x), dim=1), + ) + + device_id = gpus_for_rank(self.world_size)[self.rank][0] + model = DistributedDataParallel( + ForwardReturnValueModule().float().to(device_id), + device_ids=[device_id], + process_group=process_group, + ) + + batch_size = 4 + criterion = nn.CrossEntropyLoss() + input = torch.rand([batch_size, 2], dtype=torch.float) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id) + + # Always run "backward" to ensure the reducer is called by autograd. + # If we don't correctly capture the output tensors from the return value, + # the reducer won't see a hook for the unused parameter, and throw an error. + # The correct capture is what we're testing in this function. + def test(box, unbox): + output = model(input, fn=box) + loss = criterion(unbox(output), target) + loss.backward() + + # Test with identity return value + test( + box=lambda x, y: (x, y), + unbox=lambda obj: obj[1], + ) + + # Test with list return value + test( + box=lambda x, y: ["foo", x, "bar", y], + unbox=lambda obj: obj[3], + ) + + # Test with tuple return value + test( + box=lambda x, y: ("foo", x, "bar", y), + unbox=lambda obj: obj[3], + ) + + # Test with dict return value + test( + box=lambda x, y: {"foo": "bar", "a": x, "b": y}, + unbox=lambda obj: obj["b"], + ) + + # Test with list with dict return value + test( + box=lambda x, y: ["foo", "bar", {"a": x, "b": y}], + unbox=lambda obj: obj[2]["b"], + ) + + # Test with dict with list return value + test( + box=lambda x, y: {"foo": "bar", "list": [0, x, 1, y]}, + unbox=lambda obj: obj["list"][3], + ) + class ReducerModule(nn.Module): def __init__(self): diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index a7b666c..06f04a9 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -163,6 +163,15 @@ void Reducer::mark_variable_ready( const auto& bucket_index = variable_locators_[variable_index]; auto& bucket = buckets_[bucket_index.bucket_index]; auto& replica = bucket.replicas[replica_index]; + AT_ASSERTM( + replica.pending >= 1, + "Received autograd hook for completed bucket replica ", + "(replica_index: ", + replica_index, + ", variable_index: ", + variable_index, + ")."); + auto& variable = replica.variables[bucket_index.intra_bucket_index]; const auto offset = replica.offsets[bucket_index.intra_bucket_index]; const auto length = replica.lengths[bucket_index.intra_bucket_index]; diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index d2b4fe9..5cb5c53 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1,4 +1,5 @@ import copy +import itertools import torch @@ -15,6 +16,19 @@ from .parallel_apply import parallel_apply from torch.cuda._utils import _get_device_index +def _find_tensors(obj): + r""" + Recursively find all tensors contained in the specified object. + """ + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain(*map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain(*map(_find_tensors, obj.values())) + return [] + + class DistributedDataParallel(Module): r"""Implements distributed data parallelism that is based on ``torch.distributed`` package at the module level. @@ -361,14 +375,7 @@ class DistributedDataParallel(Module): # We need to find any tensors in this object, though, because we need to # figure out which parameters were used during this forward pass, # to ensure we short circuit reduction for any unused parameters. - output_tensors = [] - if isinstance(output, torch.Tensor): - output_tensors = [output] - if isinstance(output, (list, tuple)): - def istensor(obj): - return isinstance(obj, torch.Tensor) - output_tensors = list(filter(istensor, output)) - self.reducer.prepare_for_backward(output_tensors) + self.reducer.prepare_for_backward(list(_find_tensors(output))) return output def scatter(self, inputs, kwargs, device_ids): -- 2.7.4