import math
import multiprocessing
import os
+import random
import sys
import tempfile
import time
import unittest
from datetime import timedelta
+from itertools import groupby
from functools import wraps
from collections import namedtuple
from torch import nn
import torch.nn.functional as F
import torch.distributed as c10d
+import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from common_utils import TestCase, load_tests, run_tests
self.assertEqual(grads_batch[0], (torch.ones(10) * (self.world_size + 1) * len(devices) / 2.0).chunk(5))
+class ReducerModule(nn.Module):
+ def __init__(self):
+ super(ReducerModule, self).__init__()
+ self.fc1 = nn.Linear(2, 10, bias=False)
+ self.fc2 = nn.Linear(10, 4, bias=False)
+ self.fc3 = nn.Linear(4, 4, bias=False)
+ self.relu = nn.ReLU()
+
+ def forward(self, x, use_fc3=True):
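+ # fc3 is optional so the tests below can simulate a parameter that is unused in a given iteration.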
+ x = self.relu(self.fc1(x)).float()
+ x = self.relu(self.fc2(x)).float()
+ if use_fc3:
+ x = self.fc3(x).float()
+ return F.softmax(x, dim=1)
+
+
+class ReducerTest(TestCase):
+ def setUp(self):
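+ # A single-process Gloo group (rank 0, world size 1) is enough to construct and exercise the reducer locally.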
+ self.store = c10d.FileStore("/dev/null", 1)
+ self.process_group = c10d.ProcessGroupGloo(self.store, 0, 1)
+
+ def test_single_dtype_single_bucket(self):
+ model = ReducerModule()
+ parameters = list(model.parameters())
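+ # All parameters share the same dtype, so a single bucket holding every parameter is valid.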
+ buckets = [list(range(len(parameters)))]
+ dist.Reducer([parameters], buckets, self.process_group)
+
+ def _create_mixed_precision_model(self):
+ model = ReducerModule()
+ model.float()
+ model.fc1.double()
+ return model
+
+ def test_multi_dtype_single_bucket(self):
+ model = self._create_mixed_precision_model()
+
+ # Raise if there are multiple types per bucket.
+ # In this case we create one bucket for all parameters.
+ with self.assertRaises(RuntimeError):
+ parameters = [list(model.parameters())]
+ buckets = [list(range(len(parameters[0])))]
+ dist.Reducer(parameters, buckets, self.process_group)
+
+ def test_multi_dtype_multi_bucket(self):
+ model = self._create_mixed_precision_model()
+ parameters = [list(model.parameters())]
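+ # Group parameter indices by dtype; consecutive parameters with the same dtype share a bucket.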
+ group_by_type = groupby(
+ range(len(parameters[0])),
+ key=lambda i: parameters[0][i].type())
+ buckets = [list(indices) for _, indices in group_by_type]
+ dist.Reducer(parameters, buckets, self.process_group)
+
+ def _create_reducer_for_models(self, models):
+ parameters = [list(model.parameters()) for model in models]
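+ # Bucket assignment is computed from the first replica only; it is expected to be identical across replicas.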
+ group_by_type = groupby(
+ range(len(parameters[0])),
+ key=lambda i: parameters[0][i].type())
+ buckets = [list(indices) for _, indices in group_by_type]
+ return dist.Reducer(parameters, buckets, self.process_group)
+
+ def test_forward_backward_single_replica(self):
+ batch_size = 10
+ model = self._create_mixed_precision_model()
+ reducer = self._create_reducer_for_models([model])
+ loss = nn.CrossEntropyLoss()
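+ # The input is double because fc1 of the mixed precision model was converted to double.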
+ input = torch.rand([batch_size, 2], dtype=torch.double)
+ target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+ output = loss(model(input), target)
+ reducer.prepare_for_backward(output)
+ output.backward()
+
+ def test_forward_backward_multi_replica(self):
+ batch_size = 10
+ num_replicas = 2
+ models = [self._create_mixed_precision_model() for _ in range(num_replicas)]
+ reducer = self._create_reducer_for_models(models)
+ loss = nn.CrossEntropyLoss()
+ input = torch.rand([batch_size, 2], dtype=torch.double).chunk(num_replicas)
+ target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+ outputs = [models[i](input[i]) for i in range(num_replicas)]
+ output = loss(torch.cat(outputs), target)
+ reducer.prepare_for_backward(output)
+ output.backward()
+
+ # The reducer will have reduced the gradients for all model replicas.
+ # Verify that they are equal across model replicas.
+ for parameters in zip(*[model.parameters() for model in models]):
+ for parameter in parameters:
+ self.assertEqual(parameters[0].grad, parameter.grad)
+
+ def test_forward_backward_unused_parameters(self):
+ batch_size = 10
+ model = self._create_mixed_precision_model()
+ reducer = self._create_reducer_for_models([model])
+ loss = nn.CrossEntropyLoss()
+ input = torch.rand([batch_size, 2], dtype=torch.double)
+ target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+ output = loss(model(input, use_fc3=False), target)
+
+ # Check that the grad of fc3 is not set.
+ self.assertEqual(None, model.fc3.weight.grad)
+
+ # Compute and accumulate gradients.
+ reducer.prepare_for_backward(output)
+ output.backward()
+
+ # The reducer will have marked the grad of fc3 as ready, because
+ # it doesn't show up in the autograd graph of `output`.
+ # This should result in its contents being equal to zero.
+ self.assertEqual(torch.zeros(model.fc3.weight.size()), model.fc3.weight.grad)
+
+ def test_forward_backward_optimizer(self):
+ batch_size = 10
+ model = self._create_mixed_precision_model()
+ reducer = self._create_reducer_for_models([model])
+ loss = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(model.parameters())
+ for i in range(3):
+ input = torch.rand([batch_size, 2], dtype=torch.double)
+ target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+
+ # The `zero_grad` function calls `detach_` and `zero_` on the grad
+ # tensors of model parameters. If we tried to set the grad tensors
+ # to a view of the reducer's bucket tensors, this would blow up.
+ optimizer.zero_grad()
+
+ # Unused parameter only in the first iteration.
+ output = loss(model(input, use_fc3=(i > 0)), target)
+ reducer.prepare_for_backward(output)
+ output.backward()
+ optimizer.step()
+
+
if __name__ == '__main__':
assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"
} // namespace
Reducer::Reducer(
- std::vector<std::vector<torch::autograd::Variable>> inputs,
+ std::vector<std::vector<torch::autograd::Variable>> replicas,
+ std::vector<std::vector<size_t>> bucket_indices,
std::shared_ptr<c10d::ProcessGroup> process_group)
- : process_group_(std::move(process_group)),
+ : replicas_(std::move(replicas)),
+ process_group_(std::move(process_group)),
expect_autograd_hooks_(false),
has_queued_final_callback_(false),
next_bucket_(0),
backward_stats_base_(0) {
- AT_ASSERTM(inputs.size() >= 1, "Expected at least one model replica.");
- AT_ASSERTM(inputs[0].size() >= 1, "Expected at least one parameter.");
+ AT_ASSERTM(replicas_.size() >= 1, "Expected at least one model replica.");
+ AT_ASSERTM(replicas_[0].size() >= 1, "Expected at least one parameter.");
// Verify that all specified variables require gradients,
// and that they have the same size across replicas.
{
- const auto replica_count = inputs.size();
- variables_.resize(replica_count);
+ const auto replica_count = replicas_.size();
for (size_t replica_index = 0; replica_index < replica_count;
replica_index++) {
- const auto variable_count = inputs[replica_index].size();
- variables_[replica_index].resize(variable_count);
+ const auto variable_count = replicas_[replica_index].size();
AT_ASSERTM(
- variables_[replica_index].size() == variables_[0].size(),
+ replicas_[replica_index].size() == replicas_[0].size(),
"Model replicas must have an equal number of parameters.");
for (size_t variable_index = 0; variable_index < variable_count;
variable_index++) {
- variables_[replica_index][variable_index] =
- inputs[replica_index][variable_index];
AT_ASSERTM(
- variables_[replica_index][variable_index].requires_grad(),
+ replicas_[replica_index][variable_index].requires_grad(),
"Variables must require gradients (have `requires_grad` set).");
AT_ASSERTM(
- variables_[replica_index][variable_index].sizes() ==
- variables_[0][variable_index].sizes(),
+ replicas_[replica_index][variable_index].sizes() ==
+ replicas_[0][variable_index].sizes(),
"Variables across model replicas must have identical sizes.");
AT_ASSERTM(
- variables_[replica_index][variable_index].dtype() ==
- variables_[0][variable_index].dtype(),
+ replicas_[replica_index][variable_index].dtype() ==
+ replicas_[0][variable_index].dtype(),
"Variables across model replicas must have identical dtype.");
}
}
}
+ // Initialize variable bucketing.
+ // This can be reinitialized later after capturing runtime information.
+ initialize_buckets(std::move(bucket_indices));
+
// All variables are expected to have their `grad_fn` set to the gradient
// accumulation function (since they are leaves in the autograd graph).
// We store pointers to these functions such that we can check if they are
// used in an autograd pass. If they are not, we know their grad tensors
// can be marked as ready for reduction.
{
- const auto replica_count = variables_.size();
+ const auto replica_count = replicas_.size();
grad_accumulators_.resize(replica_count);
for (size_t replica_index = 0; replica_index < replica_count;
replica_index++) {
- const auto variable_count = variables_[replica_index].size();
+ const auto variable_count = replicas_[replica_index].size();
grad_accumulators_[replica_index].resize(variable_count);
for (size_t variable_index = 0; variable_index < variable_count;
variable_index++) {
- auto& variable = variables_[replica_index][variable_index];
+ auto& variable = replicas_[replica_index][variable_index];
// The gradient accumulator function is lazily initialized once.
// Therefore we can use its presence in the autograd graph as
// evidence that the parameter has participated in an iteration.
auto grad_accumulator = variable.grad_accumulator();

// Hook to execute after the gradient accumulator has executed.
grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
- this->mark_variable_ready(replica_index, variable_index);
+ this->mark_variable_ready(
+ replica_index, variable_index, /* called_from_autograd= */ true);
}));
// Map raw function pointer to replica index and parameter index.
}
}
- // Initialize bucketing with naive approach where all variables
- // are put in a single bucket. This is the equivalent of sequencing
- // autograd and reduction. The user is expected to override this with
- // something more clever, if applicable, by a call to `initialize_buckets`.
- {
- std::vector<size_t> variable_indices(variables_[0].size());
- std::iota(std::begin(variable_indices), std::end(variable_indices), 0);
- initialize_buckets(
- std::vector<std::vector<size_t>>{std::move(variable_indices)});
- }
-
// Initialize backward stats vector.
{
- const auto replica_count = inputs.size();
+ const auto replica_count = replicas_.size();
backward_stats_.resize(replica_count);
- const auto variable_count = inputs[0].size();
+ const auto variable_count = replicas_[0].size();
std::for_each(
backward_stats_.begin(),
backward_stats_.end(),
// - By an autograd thread after executing a gradient accumulator function.
// - By the `Reducer::prepare_for_backward` function if the variable doesn't
// show up in the autograd graph (and it wouldn't be called by autograd).
-void Reducer::mark_variable_ready(size_t replica_index, size_t variable_index) {
+void Reducer::mark_variable_ready(
+ size_t replica_index,
+ size_t variable_index,
+ bool called_from_autograd) {
// Ignore if we don't expect to be called.
// This may be the case if the user wants to accumulate gradients
// for a number of iterations before reducing them.
if (!expect_autograd_hooks_) {
return;
}
- AT_ASSERTM(replica_index < variables_.size(), "Out of range replica index.");
+ AT_ASSERTM(replica_index < replicas_.size(), "Out of range replica index.");
AT_ASSERTM(
- variable_index < bucket_indices_.size(), "Out of range variable index.");
+ variable_index < variable_locators_.size(),
+ "Out of range variable index.");
backward_stats_[replica_index][variable_index] =
current_time_in_nanos() - backward_stats_base_;
- const auto& bucket_index = bucket_indices_[variable_index];
+ const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& replica = bucket.replicas[replica_index];
auto& variable = replica.variables[bucket_index.intra_bucket_index];
auto bucket_view = replica.contents.narrow(0, offset, length);
auto& grad = variable.grad();
if (grad.defined()) {
+ // Assert that the grad tensor and the bucket don't share storage.
+ // If they did, we could avoid the copy altogether.
+ // The reason for not doing this is that existing code calls
+ // `detach_` from `zero_grad`, which is incompatible with views.
+ AT_ASSERT(!grad.is_alias_of(bucket_view));
AT_ASSERT(grad.type() == variable.type());
- AT_ASSERT(grad.get_device() == variable.get_device());
+ AT_ASSERT(grad.device() == variable.device());
AT_ASSERT(grad.numel() == length);
bucket_view.copy_(grad.view({-1}), /* non_blocking */ true);
} else {
// Autograd callbacks can only be registered while the engine is running.
// Register this reducer's final callback once per backward pass.
- if (!has_queued_final_callback_) {
+ if (!has_queued_final_callback_ && called_from_autograd) {
has_queued_final_callback_ = true;
torch::autograd::Engine::get_default_engine().queue_callback(
[=] { this->finalize_backward(); });
}
}
-void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
+void Reducer::initialize_buckets(
+ std::vector<std::vector<size_t>> bucket_indices) {
std::lock_guard<std::mutex> lock(mutex_);
// This shouldn't be called if we're expecting autograd hooks to fire.
// Clear current bucket assignment.
buckets_.clear();
- bucket_indices_.clear();
+ variable_locators_.clear();
// Ensure we have a bucket index for every variable.
- bucket_indices_.resize(variables_[0].size());
+ variable_locators_.resize(replicas_[0].size());
// Iterate over buckets.
- const auto bucket_count = indices.size();
- const auto replica_count = variables_.size();
+ const auto bucket_count = bucket_indices.size();
+ const auto replica_count = replicas_.size();
buckets_.reserve(bucket_count);
for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
Bucket bucket;
// TODO(@pietern): Validate indices.
// Must be non-empty, unique, and unique across buckets.
- AT_ASSERTM(indices[bucket_index].size() > 0, "Empty bucket specified.");
+ AT_ASSERTM(
+ bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");
// Iterate over model replicas.
for (size_t replica_index = 0; replica_index < replica_count;
size_t offset = 0;
// Iterate over bucket variables.
- for (const auto variable_index : indices[bucket_index]) {
+ for (const auto variable_index : bucket_indices[bucket_index]) {
AT_ASSERTM(
- variable_index < variables_[replica_index].size(),
+ variable_index < replicas_[replica_index].size(),
"Out of range variable index specified.");
- const auto& variable = variables_[replica_index][variable_index];
+ const auto& variable = replicas_[replica_index][variable_index];
if (!options.has_device()) {
options = options.device(variable.device());
} else {
}
// Allocate bucket contents tensor.
+ // This must be a Variable because as of Apr 2019 there is still
+ // a distinction between the Tensor and Variable types, and it
+ // is not recommended (or sometimes even possible) to mix and match.
replica.contents = torch::autograd::make_variable_consuming(
at::empty({static_cast<long>(offset)}, options));
// Map participating variables to this bucket.
// This is identical across replicas so we only need to do this once.
size_t intra_bucket_index = 0;
- for (const auto variable_index : indices[bucket_index]) {
+ for (const auto variable_index : bucket_indices[bucket_index]) {
AT_ASSERTM(
- variable_index < bucket_indices_.size(),
+ variable_index < variable_locators_.size(),
"Out of range variable index specified.");
- bucket_indices_[variable_index] = BucketIndex{
+ variable_locators_[variable_index] = VariableLocator{
.bucket_index = bucket_index,
.intra_bucket_index = intra_bucket_index++,
};
auto bucket_view =
replica.contents.narrow(0, offset, length).view(variable.sizes());
auto& grad = variable.grad();
- if (grad.defined()) {
- grad.copy_(bucket_view);
- } else {
- grad = torch::autograd::make_variable(bucket_view);
+ if (!grad.defined()) {
+ grad = at::empty(bucket_view.sizes(), bucket_view.options());
}
+ grad.copy_(bucket_view);
}
}
}
class Reducer {
public:
- // The constructor takes a vector<Variable> with model parameters for
- // every model replica, hence the vector<vector<>>.
+ // The constructor takes a list of variables for every model replica.
+ // The bucket assignment for this reducer is specified as a list of
+ // buckets, each of which is specified as a list of indices into the
+ // variables list for **a single replica** (i.e. `variables[0]`).
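+ // For example, the mixed precision model in the Python tests has parameters
+ // [fc1 (double), fc2 (float), fc3 (float)], for which a valid assignment is
+ // bucket_indices = {{0}, {1, 2}}, keeping each bucket single-dtype.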
explicit Reducer(
- std::vector<std::vector<torch::autograd::Variable>> variables,
+ std::vector<std::vector<torch::autograd::Variable>> replicas,
+ std::vector<std::vector<size_t>> bucket_indices,
std::shared_ptr<c10d::ProcessGroup> process_group);
// To (re-)initialize bucket assignment, pass a list of buckets, each
// of which is specified by a list of indices in the variables list.
// This function performs validation that the variables within a bucket
// all live on the same device and have the same dimensionality.
- void initialize_buckets(std::vector<std::vector<size_t>> indices);
+ void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
// This function is called when the forward function has produced an output,
// and the user wishes to reduce gradients in the backwards pass.
protected:
std::mutex mutex_;
- std::vector<std::vector<torch::autograd::Variable>> variables_;
+ std::vector<std::vector<torch::autograd::Variable>> replicas_;
std::shared_ptr<c10d::ProcessGroup> process_group_;
std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
bool has_queued_final_callback_;
size_t next_bucket_;
- void mark_variable_ready(size_t replica_index, size_t variable_index);
+ void mark_variable_ready(
+ size_t replica_index,
+ size_t variable_index,
+ bool called_from_autograd = false);
void mark_bucket_ready(size_t bucket_index);
std::vector<Bucket> buckets_;
- // A bucket index locates the position of a particular variable in the bucket
+ // A variable locator locates a particular variable in the bucket
// structure. The `bucket_index` field points to the bucket in the `buckets_`
// vector. The `intra_bucket_index` field points to the index of the variable
// in any of the vector fields in the bucket replica.
- struct BucketIndex {
+ struct VariableLocator {
// Index into the `buckets_` variable.
size_t bucket_index;
// Index of parameter in single bucket replica.
size_t intra_bucket_index;
};
- // Maps variable index to bucket indices. Bucketing across replicas is
- // identical so no need to include the replica index here.
- std::vector<BucketIndex> bucket_indices_;
+ // Map the index of a variable to its location in the bucket structure.
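+ // For example, with bucket_indices {{0}, {1, 2}}, the variable with index 2
+ // maps to {bucket_index: 1, intra_bucket_index: 1}.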
+ std::vector<VariableLocator> variable_locators_;
// We collect the relative timestamp of every gradient being ready
// when executing autograd. This can be used to derive a timeline of