Add tests for reducer class (#18845)
author Pieter Noordhuis <pietern@fb.com>
Fri, 5 Apr 2019 16:04:43 +0000 (09:04 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 16:07:29 +0000 (09:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18845

This adds a few CPU-only test cases for the reducer class.

Reviewed By: mrshenli

Differential Revision: D14768432

fbshipit-source-id: c008a52206826304e634a95bc14167ed94c97662
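
For context, the new tests exercise the reducer directly from Python instead of going through DistributedDataParallel. A minimal sketch of that pattern, mirroring the test_single_dtype_single_bucket case added below (single rank, single replica, one bucket holding every parameter):

    import torch
    from torch import nn
    import torch.distributed as c10d
    import torch.distributed as dist

    model = nn.Linear(2, 4, bias=False)
    store = c10d.FileStore("/dev/null", 1)              # single-rank store, as in the test fixture
    process_group = c10d.ProcessGroupGloo(store, 0, 1)  # rank 0 of world size 1

    parameters = list(model.parameters())
    buckets = [list(range(len(parameters)))]            # one bucket with every parameter index
    reducer = dist.Reducer([parameters], buckets, process_group)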

test/test_c10d.py
torch/csrc/distributed/c10d/init.cpp
torch/csrc/distributed/c10d/reducer.cpp
torch/csrc/distributed/c10d/reducer.h

diff --git a/test/test_c10d.py b/test/test_c10d.py
index 4e4a2e5..ed56ead 100644
@@ -2,12 +2,14 @@ import copy
 import math
 import multiprocessing
 import os
+import random
 import sys
 import tempfile
 import time
 import unittest
 from datetime import timedelta
 
+from itertools import groupby
 from functools import wraps
 from collections import namedtuple
 
@@ -16,6 +18,7 @@ import common_utils as common
 from torch import nn
 import torch.nn.functional as F
 import torch.distributed as c10d
+import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from common_utils import TestCase, load_tests, run_tests
@@ -1740,6 +1743,139 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         self.assertEqual(grads_batch[0], (torch.ones(10) * (self.world_size + 1) * len(devices) / 2.0).chunk(5))
 
 
+class ReducerModule(nn.Module):
+    def __init__(self):
+        super(ReducerModule, self).__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False)
+        self.fc2 = nn.Linear(10, 4, bias=False)
+        self.fc3 = nn.Linear(4, 4, bias=False)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, use_fc3=True):
+        x = self.relu(self.fc1(x)).float()
+        x = self.relu(self.fc2(x)).float()
+        if use_fc3:
+            x = self.fc3(x).float()
+        return F.softmax(x, dim=1)
+
+
+class ReducerTest(TestCase):
+    def setUp(self):
+        self.store = c10d.FileStore("/dev/null", 1)
+        self.process_group = c10d.ProcessGroupGloo(self.store, 0, 1)
+
+    def test_single_dtype_single_bucket(self):
+        model = ReducerModule()
+        parameters = list(model.parameters())
+        buckets = [list(range(len(parameters)))]
+        dist.Reducer([parameters], buckets, self.process_group)
+
+    def _create_mixed_precision_model(self):
+        model = ReducerModule()
+        model.float()
+        model.fc1.double()
+        return model
+
+    def test_multi_dtype_single_bucket(self):
+        model = self._create_mixed_precision_model()
+
+        # Raise if there are multiple types per bucket.
+        # In this case we create one bucket for all parameters.
+        with self.assertRaises(RuntimeError):
+            parameters = [list(model.parameters())]
+            buckets = [list(range(len(parameters[0])))]
+            dist.Reducer(parameters, buckets, self.process_group)
+
+    def test_multi_dtype_multi_bucket(self):
+        model = self._create_mixed_precision_model()
+        parameters = [list(model.parameters())]
+        group_by_type = groupby(
+            range(len(parameters[0])),
+            key=lambda i: parameters[0][i].type())
+        buckets = [list(indices) for _, indices in group_by_type]
+        dist.Reducer(parameters, buckets, self.process_group)
+
+    def _create_reducer_for_models(self, models):
+        parameters = [list(model.parameters()) for model in models]
+        group_by_type = groupby(
+            range(len(parameters[0])),
+            key=lambda i: parameters[0][i].type())
+        buckets = [list(indices) for _, indices in group_by_type]
+        return dist.Reducer(parameters, buckets, self.process_group)
+
+    def test_forward_backward_single_replica(self):
+        batch_size = 10
+        model = self._create_mixed_precision_model()
+        reducer = self._create_reducer_for_models([model])
+        loss = nn.CrossEntropyLoss()
+        input = torch.rand([batch_size, 2], dtype=torch.double)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+        output = loss(model(input), target)
+        reducer.prepare_for_backward(output)
+        output.backward()
+
+    def test_forward_backward_multi_replica(self):
+        batch_size = 10
+        num_replicas = 2
+        models = [self._create_mixed_precision_model() for _ in range(num_replicas)]
+        reducer = self._create_reducer_for_models(models)
+        loss = nn.CrossEntropyLoss()
+        input = torch.rand([batch_size, 2], dtype=torch.double).chunk(num_replicas)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+        outputs = [models[i](input[i]) for i in range(num_replicas)]
+        output = loss(torch.cat(outputs), target)
+        reducer.prepare_for_backward(output)
+        output.backward()
+
+        # The reducer will have reduced the gradients for all model replicas.
+        # Verify that they are equal across model replicas.
+        for parameters in zip(*[model.parameters() for model in models]):
+            for parameter in parameters:
+                self.assertEqual(parameters[0].grad, parameter.grad)
+
+    def test_forward_backward_unused_parameters(self):
+        batch_size = 10
+        model = self._create_mixed_precision_model()
+        reducer = self._create_reducer_for_models([model])
+        loss = nn.CrossEntropyLoss()
+        input = torch.rand([batch_size, 2], dtype=torch.double)
+        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+        output = loss(model(input, use_fc3=False), target)
+
+        # Check that the grad of fc3 is not set.
+        self.assertEqual(None, model.fc3.weight.grad)
+
+        # Compute and accumulate gradients.
+        reducer.prepare_for_backward(output)
+        output.backward()
+
+        # The reducer will have marked the grad of fc3 as ready, because
+        # it doesn't show up in the autograd graph of `output`.
+        # This should result in its contents being equal to zero.
+        self.assertEqual(torch.zeros(model.fc3.weight.size()), model.fc3.weight.grad)
+
+    def test_forward_backward_optimizer(self):
+        batch_size = 10
+        model = self._create_mixed_precision_model()
+        reducer = self._create_reducer_for_models([model])
+        loss = nn.CrossEntropyLoss()
+        optimizer = torch.optim.Adam(model.parameters())
+        for i in range(3):
+            input = torch.rand([batch_size, 2], dtype=torch.double)
+            target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
+
+            # The `zero_grad` function calls `detach_` and `zero_` on the grad
+            # tensors of model parameters. If we tried to set the grad tensors
+            # to a view of the reducer's bucket tensors, this would blow up.
+            optimizer.zero_grad()
+
+            # Unused parameter only in the first iteration.
+            output = loss(model(input, use_fc3=(i > 0)), target)
+            reducer.prepare_for_backward(output)
+            output.backward()
+            optimizer.step()
+
+
 if __name__ == '__main__':
     assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"
 
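For the mixed-precision model used above (fc1 in double, fc2 and fc3 in float), the groupby call in _create_reducer_for_models resolves to one bucket per run of identically-typed parameters. A worked sketch of the resulting indices, assuming parameters() yields fc1, fc2, fc3 in registration order:

    from itertools import groupby

    # CPU type strings for fc1.weight, fc2.weight, fc3.weight respectively.
    param_types = ["torch.DoubleTensor", "torch.FloatTensor", "torch.FloatTensor"]
    buckets = [list(g) for _, g in groupby(range(len(param_types)), key=lambda i: param_types[i])]
    assert buckets == [[0], [1, 2]]  # one double bucket, one float bucket

Note that itertools.groupby only merges consecutive equal keys, so this bucketing relies on same-dtype parameters being registered next to each other, which holds for ReducerModule because fc1 is the only double submodule.
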
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 0e9ccb7..ff36947 100644
@@ -45,6 +45,7 @@ PyObject* c10d_init(PyObject* _unused) {
   shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
       .def(py::init<
            std::vector<std::vector<torch::autograd::Variable>>,
+           std::vector<std::vector<size_t>>,
            std::shared_ptr<::c10d::ProcessGroup>>())
       .def(
           "initialize_buckets",
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 6d62f31..166090a 100644
@@ -38,62 +38,64 @@ inline int64_t current_time_in_nanos() {
 } // namespace
 
 Reducer::Reducer(
-    std::vector<std::vector<torch::autograd::Variable>> inputs,
+    std::vector<std::vector<torch::autograd::Variable>> replicas,
+    std::vector<std::vector<size_t>> bucket_indices,
     std::shared_ptr<c10d::ProcessGroup> process_group)
-    : process_group_(std::move(process_group)),
+    : replicas_(std::move(replicas)),
+      process_group_(std::move(process_group)),
       expect_autograd_hooks_(false),
       has_queued_final_callback_(false),
       next_bucket_(0),
       backward_stats_base_(0) {
-  AT_ASSERTM(inputs.size() >= 1, "Expected at least one model replica.");
-  AT_ASSERTM(inputs[0].size() >= 1, "Expected at least one parameter.");
+  AT_ASSERTM(replicas_.size() >= 1, "Expected at least one model replica.");
+  AT_ASSERTM(replicas_[0].size() >= 1, "Expected at least one parameter.");
 
   // Verify that all specified variables require gradients,
   // and that they have the same size across replicas.
   {
-    const auto replica_count = inputs.size();
-    variables_.resize(replica_count);
+    const auto replica_count = replicas_.size();
     for (size_t replica_index = 0; replica_index < replica_count;
          replica_index++) {
-      const auto variable_count = inputs[replica_index].size();
-      variables_[replica_index].resize(variable_count);
+      const auto variable_count = replicas_[replica_index].size();
       AT_ASSERTM(
-          variables_[replica_index].size() == variables_[0].size(),
+          replicas_[replica_index].size() == replicas_[0].size(),
           "Model replicas must have an equal number of parameters.");
       for (size_t variable_index = 0; variable_index < variable_count;
            variable_index++) {
-        variables_[replica_index][variable_index] =
-            inputs[replica_index][variable_index];
         AT_ASSERTM(
-            variables_[replica_index][variable_index].requires_grad(),
+            replicas_[replica_index][variable_index].requires_grad(),
             "Variables must require gradients (have `requires_grad` set).");
         AT_ASSERTM(
-            variables_[replica_index][variable_index].sizes() ==
-                variables_[0][variable_index].sizes(),
+            replicas_[replica_index][variable_index].sizes() ==
+                replicas_[0][variable_index].sizes(),
             "Variables across model replicas must have identical sizes.");
         AT_ASSERTM(
-            variables_[replica_index][variable_index].dtype() ==
-                variables_[0][variable_index].dtype(),
+            replicas_[replica_index][variable_index].dtype() ==
+                replicas_[0][variable_index].dtype(),
             "Variables across model replicas must have identical dtype.");
       }
     }
   }
 
+  // Initialize variable bucketing.
+  // This can be reinitialized later after capturing runtime information.
+  initialize_buckets(std::move(bucket_indices));
+
   // All variables are expected to have their `grad_fn` set to the gradient
   // accumulation function (since they are leafs in the autograd graph).
   // We store pointers to these functions such that we can check if they are
   // used in an autograd pass. If they are not, we know their grad tensors
   // can be marked as ready for reduction.
   {
-    const auto replica_count = variables_.size();
+    const auto replica_count = replicas_.size();
     grad_accumulators_.resize(replica_count);
     for (size_t replica_index = 0; replica_index < replica_count;
          replica_index++) {
-      const auto variable_count = variables_[replica_index].size();
+      const auto variable_count = replicas_[replica_index].size();
       grad_accumulators_[replica_index].resize(variable_count);
       for (size_t variable_index = 0; variable_index < variable_count;
            variable_index++) {
-        auto& variable = variables_[replica_index][variable_index];
+        auto& variable = replicas_[replica_index][variable_index];
 
         // The gradient accumulator function is lazily initialized once.
         // Therefore we can use its presence in the autograd graph as
@@ -103,7 +105,8 @@ Reducer::Reducer(
         // Hook to execute after the gradient accumulator has executed.
         grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
           std::lock_guard<std::mutex> lock(this->mutex_);
-          this->mark_variable_ready(replica_index, variable_index);
+          this->mark_variable_ready(
+              replica_index, variable_index, /* called_from_autograd= */ true);
         }));
 
         // Map raw function pointer to replica index and parameter index.
@@ -121,22 +124,11 @@ Reducer::Reducer(
     }
   }
 
-  // Initialize bucketing with naive approach where all variables
-  // are put in a single bucket. This is the equivalent of sequencing
-  // autograd and reduction. The user is expected to override this with
-  // something more clever, if applicable, by a call to `initialize_buckets`.
-  {
-    std::vector<size_t> variable_indices(variables_[0].size());
-    std::iota(std::begin(variable_indices), std::end(variable_indices), 0);
-    initialize_buckets(
-        std::vector<std::vector<size_t>>{std::move(variable_indices)});
-  }
-
   // Initialize backward stats vector.
   {
-    const auto replica_count = inputs.size();
+    const auto replica_count = replicas_.size();
     backward_stats_.resize(replica_count);
-    const auto variable_count = inputs[0].size();
+    const auto variable_count = replicas_[0].size();
     std::for_each(
         backward_stats_.begin(),
         backward_stats_.end(),
@@ -149,7 +141,10 @@ Reducer::Reducer(
 // - By an autograd thread after executing a gradient accumulator function.
 // - By the `Reducer::prepare_for_backward` function if the variable doesn't
 //   show up in the autograd graph (and it wouldn't be called by autograd).
-void Reducer::mark_variable_ready(size_t replica_index, size_t variable_index) {
+void Reducer::mark_variable_ready(
+    size_t replica_index,
+    size_t variable_index,
+    bool called_from_autograd) {
   // Ignore if we don't expect to be called.
   // This may be the case if the user wants to accumulate gradients
   // for number of iterations before reducing them.
@@ -157,13 +152,14 @@ void Reducer::mark_variable_ready(size_t replica_index, size_t variable_index) {
     return;
   }
 
-  AT_ASSERTM(replica_index < variables_.size(), "Out of range replica index.");
+  AT_ASSERTM(replica_index < replicas_.size(), "Out of range replica index.");
   AT_ASSERTM(
-      variable_index < bucket_indices_.size(), "Out of range variable index.");
+      variable_index < variable_locators_.size(),
+      "Out of range variable index.");
   backward_stats_[replica_index][variable_index] =
       current_time_in_nanos() - backward_stats_base_;
 
-  const auto& bucket_index = bucket_indices_[variable_index];
+  const auto& bucket_index = variable_locators_[variable_index];
   auto& bucket = buckets_[bucket_index.bucket_index];
   auto& replica = bucket.replicas[replica_index];
   auto& variable = replica.variables[bucket_index.intra_bucket_index];
@@ -177,8 +173,13 @@ void Reducer::mark_variable_ready(size_t replica_index, size_t variable_index) {
   auto bucket_view = replica.contents.narrow(0, offset, length);
   auto& grad = variable.grad();
   if (grad.defined()) {
+    // Assert that the grad tensor and the bucket don't share storage.
+    // If they did, we could avoid the copy altogether.
+    // The reason for not doing this is that existing code calls
+    // `detach_` from `zero_grad`, which is incompatible with views.
+    AT_ASSERT(!grad.is_alias_of(bucket_view));
     AT_ASSERT(grad.type() == variable.type());
-    AT_ASSERT(grad.get_device() == variable.get_device());
+    AT_ASSERT(grad.device() == variable.device());
     AT_ASSERT(grad.numel() == length);
     bucket_view.copy_(grad.view({-1}), /* non_blocking */ true);
   } else {
@@ -203,7 +204,7 @@ void Reducer::mark_variable_ready(size_t replica_index, size_t variable_index) {
 
   // Autograd callbacks can only be registered while the engine is running.
   // Register this reducer's final callback once per backward pass.
-  if (!has_queued_final_callback_) {
+  if (!has_queued_final_callback_ && called_from_autograd) {
     has_queued_final_callback_ = true;
     torch::autograd::Engine::get_default_engine().queue_callback(
         [=] { this->finalize_backward(); });
@@ -244,7 +245,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
   }
 }
 
-void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
+void Reducer::initialize_buckets(
+    std::vector<std::vector<size_t>> bucket_indices) {
   std::lock_guard<std::mutex> lock(mutex_);
 
   // This shouldn't be called if we're expecting autograd hooks to fire.
@@ -254,21 +256,22 @@ void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
 
   // Clear current bucket assignment.
   buckets_.clear();
-  bucket_indices_.clear();
+  variable_locators_.clear();
 
   // Ensure we have a bucket index for every variable.
-  bucket_indices_.resize(variables_[0].size());
+  variable_locators_.resize(replicas_[0].size());
 
   // Iterate over buckets.
-  const auto bucket_count = indices.size();
-  const auto replica_count = variables_.size();
+  const auto bucket_count = bucket_indices.size();
+  const auto replica_count = replicas_.size();
   buckets_.reserve(bucket_count);
   for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
     Bucket bucket;
 
     // TODO(@pietern): Validate indices.
     // Must be non-empty, unique, and unique across buckets.
-    AT_ASSERTM(indices[bucket_index].size() > 0, "Empty bucket specified.");
+    AT_ASSERTM(
+        bucket_indices[bucket_index].size() > 0, "Empty bucket specified.");
 
     // Iterate over model replicas.
     for (size_t replica_index = 0; replica_index < replica_count;
@@ -278,11 +281,11 @@ void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
       size_t offset = 0;
 
       // Iterate over bucket variables.
-      for (const auto variable_index : indices[bucket_index]) {
+      for (const auto variable_index : bucket_indices[bucket_index]) {
         AT_ASSERTM(
-            variable_index < variables_[replica_index].size(),
+            variable_index < replicas_[replica_index].size(),
             "Out of range variable index specified.");
-        const auto& variable = variables_[replica_index][variable_index];
+        const auto& variable = replicas_[replica_index][variable_index];
         if (!options.has_device()) {
           options = options.device(variable.device());
         } else {
@@ -305,6 +308,9 @@ void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
       }
 
       // Allocate bucket contents tensor.
+      // This must be a Variable because as of Apr 2019 there is still
+      // a distinction between the Tensor and Variable types, and it
+      // is not recommended (or sometimes even possible) to mix and match.
       replica.contents = torch::autograd::make_variable_consuming(
           at::empty({static_cast<long>(offset)}, options));
 
@@ -315,11 +321,11 @@ void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
     // Map participating variables to this bucket.
     // This is identical across replicas so we only need to do this once.
     size_t intra_bucket_index = 0;
-    for (const auto variable_index : indices[bucket_index]) {
+    for (const auto variable_index : bucket_indices[bucket_index]) {
       AT_ASSERTM(
-          variable_index < bucket_indices_.size(),
+          variable_index < variable_locators_.size(),
           "Out of range variable index specified.");
-      bucket_indices_[variable_index] = BucketIndex{
+      variable_locators_[variable_index] = VariableLocator{
           .bucket_index = bucket_index,
           .intra_bucket_index = intra_bucket_index++,
       };
@@ -418,11 +424,10 @@ void Reducer::finalize_backward() {
         auto bucket_view =
             replica.contents.narrow(0, offset, length).view(variable.sizes());
         auto& grad = variable.grad();
-        if (grad.defined()) {
-          grad.copy_(bucket_view);
-        } else {
-          grad = torch::autograd::make_variable(bucket_view);
+        if (!grad.defined()) {
+          grad = at::empty(bucket_view.sizes(), bucket_view.options());
         }
+        grad.copy_(bucket_view);
       }
     }
   }
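
The rename from BucketIndex / bucket_indices_ to VariableLocator / variable_locators_ (see the header diff below) reflects what the mapping actually stores: for every flat variable index, the bucket it lives in and its position within that bucket. A worked sketch of the mapping that initialize_buckets builds, for the bucket assignment [[0], [1, 2]]:

    bucket_indices = [[0], [1, 2]]
    variable_locators = {}
    for bucket_index, indices in enumerate(bucket_indices):
        for intra_bucket_index, variable_index in enumerate(indices):
            variable_locators[variable_index] = (bucket_index, intra_bucket_index)
    assert variable_locators == {0: (0, 0), 1: (1, 0), 2: (1, 1)}
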
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index 6606cf6..f1d5ba3 100644
@@ -15,17 +15,20 @@ namespace c10d {
 
 class Reducer {
  public:
-  // The constructor takes a vector<Variable> with model parameters for
-  // every model replica, hence the vector<vector<>>.
+  // The constructor takes a list of variables for every model replica.
+  // The bucket assignment for this reducer is specified as a list of
+  // buckets, each of which is specified as a list of indices into the
+  // variables list for **a single replica** (i.e. `variables[0]`).
   explicit Reducer(
-      std::vector<std::vector<torch::autograd::Variable>> variables,
+      std::vector<std::vector<torch::autograd::Variable>> replicas,
+      std::vector<std::vector<size_t>> bucket_indices,
       std::shared_ptr<c10d::ProcessGroup> process_group);
 
   // To (re-)initialize bucket assignment, pass a list of buckets, each
   // of which is specified by a list of indices in the variables list.
   // This function performs validation that the variables within a bucket
   // all live on the same device and have the same dimensionality.
-  void initialize_buckets(std::vector<std::vector<size_t>> indices);
+  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
 
   // This function is called when the forward function has produced an output,
   // and the user wishes to reduce gradients in the backwards pass.
@@ -43,7 +46,7 @@ class Reducer {
 
  protected:
   std::mutex mutex_;
-  std::vector<std::vector<torch::autograd::Variable>> variables_;
+  std::vector<std::vector<torch::autograd::Variable>> replicas_;
   std::shared_ptr<c10d::ProcessGroup> process_group_;
 
   std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
@@ -54,7 +57,10 @@ class Reducer {
   bool has_queued_final_callback_;
   size_t next_bucket_;
 
-  void mark_variable_ready(size_t replica_index, size_t variable_index);
+  void mark_variable_ready(
+      size_t replica_index,
+      size_t variable_index,
+      bool called_from_autograd = false);
 
   void mark_bucket_ready(size_t bucket_index);
 
@@ -112,20 +118,19 @@ class Reducer {
 
   std::vector<Bucket> buckets_;
 
-  // A bucket index locates the position of a particular variable in the bucket
+  // A variable locator locates a particular variable in the bucket
   // structure. The `bucket_index` field points to the bucket in the `buckets_`
   // vector. The `intra_bucket_index` field points to the index of the variable
   // in any of the vector fields in the bucket replica.
-  struct BucketIndex {
+  struct VariableLocator {
     // Index into the `buckets_` variable.
     size_t bucket_index;
     // Index of parameter in single bucket replica.
     size_t intra_bucket_index;
   };
 
-  // Maps variable index to bucket indices. Bucketing across replicas is
-  // identical so no need to include the replica index here.
-  std::vector<BucketIndex> bucket_indices_;
+  // Map the index of a variable to its location in the bucket structure.
+  std::vector<VariableLocator> variable_locators_;
 
   // We collect the relative timestamp of every gradient being ready
   // when executing autograd. This can be used to derive a timeline of