C++ handler for gradient reduction (#18251)
author     Pieter Noordhuis <pcnoordhuis@gmail.com>
Mon, 1 Apr 2019 21:27:03 +0000 (14:27 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 21:30:02 +0000 (14:30 -0700)
Summary:
This commit adds the `c10d::Reducer` class that hooks into autograd
and performs gradient bucketing and reduction. These are the core
parts of `nn.parallel.DistributedDataParallel` that up to now were
only usable for CUDA models.
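
A minimal sketch of how the new `Reducer` binding might be driven from Python
(the process group setup, the toy model, and the private `_get_default_group()`
helper below are illustrative assumptions; only `Reducer`, its constructor, and
`prepare_for_backward` come from this change):

    import torch
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import Reducer, _get_default_group

    dist.init_process_group(backend="gloo")
    model = torch.nn.Linear(10, 10)

    # One parameter list per model replica (a single replica here).
    reducer = Reducer([list(model.parameters())], _get_default_group())

    for _ in range(10):
        output = model(torch.randn(20, 10))
        loss = output.sum()
        # Tell the reducer which outputs participate in this backward pass so
        # that parameters absent from the autograd graph can be marked ready
        # immediately. Omit this call to accumulate gradients locally without
        # reducing them.
        reducer.prepare_for_backward([loss])
        loss.backward()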

This should enable the following:

* Distributed data parallelism for models defined using the C++ frontend.
* Overlap of gradient computation and reduction for non-CUDA models.
* Distributed data parallelism for models with some unused parameters.

This does not include any logic for computing bucket assignment, which can
be done separately: either by observing the autograd execution order (this
is what Apex does), by assigning buckets based on some maximum byte size,
or both.
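
As a rough sketch of the size-based variant (the greedy grouping and the
25 MiB cap below are assumptions for illustration, not part of this change),
a bucket assignment to pass to `initialize_buckets` could be computed as:

    def compute_bucket_indices(parameters, bucket_bytes_cap=25 * 1024 * 1024):
        # Greedily pack parameter indices into buckets of at most
        # `bucket_bytes_cap` bytes, grouping by (device, dtype) so that each
        # bucket can be flattened into a single contents tensor.
        finished = []
        open_buckets = {}  # (device, dtype) -> (indices, running byte size)
        for index, param in enumerate(parameters):
            key = (param.device, param.dtype)
            indices, size = open_buckets.get(key, ([], 0))
            nbytes = param.numel() * param.element_size()
            if indices and size + nbytes > bucket_bytes_cap:
                finished.append(indices)
                indices, size = [], 0
            indices.append(index)
            open_buckets[key] = (indices, size + nbytes)
        finished.extend(indices for indices, _ in open_buckets.values() if indices)
        return finished

    # Every parameter index must be covered by exactly one bucket:
    # reducer.initialize_buckets(compute_bucket_indices(list(model.parameters())))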

Also see #17757 and #13273.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18251

Reviewed By: mrshenli

Differential Revision: D14571899

Pulled By: pietern

fbshipit-source-id: 20f95eefd288dfe8cfffe0a28ca22fa7c9c3cd4c

tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/distributed/c10d/init.cpp
torch/csrc/distributed/c10d/reducer.cpp [new file with mode: 0644]
torch/csrc/distributed/c10d/reducer.h [new file with mode: 0644]
torch/distributed/distributed_c10d.py

diff --git a/tools/build_variables.py b/tools/build_variables.py
index 5718599..cacbfe9 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -152,6 +152,7 @@ def add_torch_libs():
         "torch/csrc/distributed/Module.cpp",
         "torch/csrc/distributed/c10d/init.cpp",
         "torch/csrc/distributed/c10d/ddp.cpp",
+        "torch/csrc/distributed/c10d/reducer.cpp",
     ] + [":generate-code=" + x for x in GENERATED_CPP])
     libtorch_python_sources = sets.to_list(sets.difference(
         sets.make(globbed_sources),
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 9c905ff..7da46b6 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -686,6 +686,7 @@ if (BUILD_PYTHON)
     list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED)
     if (NOT MSVC AND NOT APPLE)
       list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp)
+      list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/distributed/c10d/reducer.cpp)
       list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
       list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
       if (USE_CUDA)
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 25ca579..0e9ccb7 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -19,6 +19,7 @@
 
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/distributed/c10d/ddp.h>
+#include <torch/csrc/distributed/c10d/reducer.h>
 #include <torch/csrc/utils/object_ptr.h>
 #include <torch/csrc/utils/pybind.h>
 
@@ -41,6 +42,25 @@ PyObject* c10d_init(PyObject* _unused) {
 
   auto module = py::handle(c10d_module).cast<py::module>();
 
+  shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
+      .def(py::init<
+           std::vector<std::vector<torch::autograd::Variable>>,
+           std::shared_ptr<::c10d::ProcessGroup>>())
+      .def(
+          "initialize_buckets",
+          &::c10d::Reducer::initialize_buckets,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "prepare_for_backward",
+          &::c10d::Reducer::prepare_for_backward,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "prepare_for_backward",
+          [](::c10d::Reducer& reducer, const torch::autograd::Variable& output)
+              -> void { reducer.prepare_for_backward({output}); },
+          py::call_guard<py::gil_scoped_release>())
+      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats);
+
   py::enum_<::c10d::ReduceOp>(module, "ReduceOp", R"(
 An enum-like class of available reduce operations: ``SUM``, ``PRODUCT``,
 ``MIN``, and ``MAX``.
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
new file mode 100644
index 0000000..8128e05
--- /dev/null
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -0,0 +1,430 @@
+#include <torch/csrc/distributed/c10d/reducer.h>
+
+#include <functional>
+
+#include <c10/util/Exception.h>
+#include <torch/csrc/autograd/engine.h>
+#include <torch/csrc/autograd/function_hook.h>
+#include <torch/csrc/autograd/functions/accumulate_grad.h>
+#include <torch/csrc/autograd/profiler.h>
+#include <torch/csrc/utils/memory.h>
+
+namespace c10d {
+namespace {
+
+// Turns a no-argument lambda into a torch::autograd::FunctionPostHook.
+class LambdaPostHook : public torch::autograd::FunctionPostHook {
+  using variable_list = std::vector<torch::autograd::Variable>;
+
+ public:
+  /* implicit */ LambdaPostHook(std::function<void(void)> fn) : fn_(fn) {}
+
+  variable_list operator()(
+      const variable_list& outputs,
+      const variable_list& /* unused */) override {
+    fn_();
+    return outputs;
+  }
+
+ protected:
+  std::function<void(void)> fn_;
+};
+
+inline int64_t current_time_in_nanos() {
+  return torch::autograd::profiler::getTime();
+}
+
+} // namespace
+
+Reducer::Reducer(
+    std::vector<std::vector<torch::autograd::Variable>> inputs,
+    std::shared_ptr<c10d::ProcessGroup> process_group)
+    : process_group_(std::move(process_group)),
+      expect_autograd_hooks_(false),
+      has_queued_final_callback_(false),
+      next_bucket_(0),
+      backward_stats_base_(0) {
+  AT_ASSERTM(inputs.size() >= 1, "Expected at least one model replica.");
+  AT_ASSERTM(inputs[0].size() >= 1, "Expected at least one parameter.");
+
+  // Verify that all specified variables require gradients,
+  // and that they have the same size across replicas.
+  {
+    const auto replica_count = inputs.size();
+    variables_.resize(replica_count);
+    for (size_t replica_index = 0; replica_index < replica_count;
+         replica_index++) {
+      const auto variable_count = inputs[replica_index].size();
+      variables_[replica_index].resize(variable_count);
+      AT_ASSERTM(
+          variables_[replica_index].size() == variables_[0].size(),
+          "Model replicas must have an equal number of parameters.");
+      for (size_t variable_index = 0; variable_index < variable_count;
+           variable_index++) {
+        variables_[replica_index][variable_index] =
+            inputs[replica_index][variable_index];
+        AT_ASSERTM(
+            variables_[replica_index][variable_index].requires_grad(),
+            "Variables must require gradients (have `requires_grad` set).");
+        AT_ASSERTM(
+            variables_[replica_index][variable_index].sizes() ==
+                variables_[0][variable_index].sizes(),
+            "Variables across model replicas must have identical sizes.");
+        AT_ASSERTM(
+            variables_[replica_index][variable_index].dtype() ==
+                variables_[0][variable_index].dtype(),
+            "Variables across model replicas must have identical dtype.");
+      }
+    }
+  }
+
+  // All variables are expected to have their `grad_fn` set to the gradient
+// accumulation function (since they are leaves in the autograd graph).
+  // We store pointers to these functions such that we can check if they are
+  // used in an autograd pass. If they are not, we know their grad tensors
+  // can be marked as ready for reduction.
+  {
+    const auto replica_count = variables_.size();
+    grad_accumulators_.resize(replica_count);
+    for (size_t replica_index = 0; replica_index < replica_count;
+         replica_index++) {
+      const auto variable_count = variables_[replica_index].size();
+      grad_accumulators_[replica_index].resize(variable_count);
+      for (size_t variable_index = 0; variable_index < variable_count;
+           variable_index++) {
+        auto& variable = variables_[replica_index][variable_index];
+
+        // The gradient accumulator function is lazily initialized once.
+        // Therefore we can use its presence in the autograd graph as
+        // evidence that the parameter has participated in an iteration.
+        auto grad_accumulator = variable.grad_accumulator();
+
+        // Hook to execute after the gradient accumulator has executed.
+        grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>([=] {
+          std::lock_guard<std::mutex> lock(this->mutex_);
+          this->mark_variable_ready(replica_index, variable_index);
+        }));
+
+        // Map raw function pointer to replica index and parameter index.
+        // This is used later on when the autograd graph is traversed
+        // to check for parameters for which no gradient is computed.
+        func_[grad_accumulator.get()] =
+            std::make_tuple(replica_index, variable_index);
+
+        // The gradient accumulator is stored as weak_ptr in the autograd
+        // metadata of the variable, so we have to keep it alive here for
+        // the raw pointer to be valid.
+        grad_accumulators_[replica_index][variable_index] =
+            std::move(grad_accumulator);
+      }
+    }
+  }
+
+  // Initialize bucketing with naive approach where all variables
+  // are put in a single bucket. This is the equivalent of sequencing
+  // autograd and reduction. The user is expected to override this with
+  // something more clever, if applicable, by a call to `initialize_buckets`.
+  {
+    std::vector<size_t> variable_indices(variables_[0].size());
+    std::iota(std::begin(variable_indices), std::end(variable_indices), 0);
+    initialize_buckets(
+        std::vector<std::vector<size_t>>{std::move(variable_indices)});
+  }
+
+  // Initialize backward stats vector.
+  {
+    const auto replica_count = inputs.size();
+    backward_stats_.resize(replica_count);
+    const auto variable_count = inputs[0].size();
+    std::for_each(
+        backward_stats_.begin(),
+        backward_stats_.end(),
+        [=](std::vector<int64_t>& v) { v.resize(variable_count); });
+  }
+}
+
+// Called when the gradient for the specified variable is ready.
+// It can be called from two places:
+// - By an autograd thread after executing a gradient accumulator function.
+// - By the `Reducer::prepare_for_backward` function if the variable doesn't
+//   show up in the autograd graph (in which case autograd would never call it).
+void Reducer::mark_variable_ready(size_t replica_index, size_t variable_index) {
+  // Ignore if we don't expect to be called.
+  // This may be the case if the user wants to accumulate gradients
+  // for a number of iterations before reducing them.
+  if (!expect_autograd_hooks_) {
+    return;
+  }
+
+  AT_ASSERTM(replica_index < variables_.size(), "Out of range replica index.");
+  AT_ASSERTM(
+      variable_index < bucket_indices_.size(), "Out of range variable index.");
+  backward_stats_[replica_index][variable_index] =
+      current_time_in_nanos() - backward_stats_base_;
+
+  const auto& bucket_index = bucket_indices_[variable_index];
+  auto& bucket = buckets_[bucket_index.bucket_index];
+  auto& replica = bucket.replicas[replica_index];
+  auto& variable = replica.variables[bucket_index.intra_bucket_index];
+  const auto offset = replica.offsets[bucket_index.intra_bucket_index];
+  const auto length = replica.lengths[bucket_index.intra_bucket_index];
+
+  // Copy contents of gradient tensor to bucket tensor.
+  // If the gradient is not set, we assume it wasn't computed
+  // as part of the current backwards pass, and zero the part
+  // of the bucket it would otherwise hold.
+  auto bucket_view = replica.contents.narrow(0, offset, length);
+  auto& grad = variable.grad();
+  if (grad.defined()) {
+    AT_ASSERT(grad.type() == variable.type());
+    AT_ASSERT(grad.get_device() == variable.get_device());
+    AT_ASSERT(grad.numel() == length);
+    bucket_view.copy_(grad.view({-1}), /* non_blocking */ true);
+  } else {
+    bucket_view.zero_();
+  }
+
+  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
+  // When using CPU tensors we don't need to do this.
+  // // Record event so that we can wait for all of them.
+  // auto& event = replica.events[bucket_index.intra_bucket_index];
+  // event.record();
+
+  // Check if this was the final gradient for this bucket.
+  if (--replica.pending == 0) {
+    // Prescale bucket contents to turn the global sum into the global average.
+    replica.contents.div_(process_group_->getSize());
+    // Kick off reduction if all replicas for this bucket are ready.
+    if (--bucket.pending == 0) {
+      mark_bucket_ready(bucket_index.bucket_index);
+    }
+  }
+
+  // Autograd callbacks can only be registered while the engine is running.
+  // Register this reducer's final callback once per backward pass.
+  if (!has_queued_final_callback_) {
+    has_queued_final_callback_ = true;
+    torch::autograd::Engine::get_default_engine().queue_callback(
+        [=] { this->finalize_backward(); });
+  }
+}
+
+// Called when the bucket at the specified index is ready to be reduced.
+void Reducer::mark_bucket_ready(size_t bucket_index) {
+  AT_ASSERT(bucket_index >= next_bucket_);
+
+  // Buckets are reduced in sequence. Ignore this bucket if
+  // it's not its turn to be reduced.
+  if (bucket_index > next_bucket_) {
+    return;
+  }
+
+  // Keep going, until we either:
+  // - have kicked off reduction for all buckets, or
+  // - find a bucket that's not yet ready for reduction.
+  for (; next_bucket_ < buckets_.size() && buckets_[next_bucket_].pending == 0;
+       next_bucket_++) {
+    auto& bucket = buckets_[next_bucket_];
+    std::vector<at::Tensor> tensors;
+    tensors.reserve(bucket.replicas.size());
+    for (const auto& replica : bucket.replicas) {
+      // TODO(@pietern): Ensure proper synchronization with the CUDA events
+      // that recorded copies into this contents tensor. If these copies are
+      // executed on non-default streams, the current stream for the device
+      // that holds the contents tensor must wait on these events.
+      //
+      // As long as autograd uses the default stream for every device,
+      // these operations are implicitly sequenced, and we don't need to
+      // do any extra synchronization here.
+      //
+      tensors.push_back(replica.contents);
+    }
+    bucket.work = process_group_->allreduce(tensors);
+  }
+}
+
+void Reducer::initialize_buckets(std::vector<std::vector<size_t>> indices) {
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  // This shouldn't be called if we're expecting autograd hooks to fire.
+  AT_ASSERTM(
+      !expect_autograd_hooks_,
+      "`initialize_buckets` must NOT be called during autograd execution.");
+
+  // Clear current bucket assignment.
+  buckets_.clear();
+  bucket_indices_.clear();
+
+  // Ensure we have a bucket index for every variable.
+  bucket_indices_.resize(variables_[0].size());
+
+  // Iterate over buckets.
+  const auto bucket_count = indices.size();
+  const auto replica_count = variables_.size();
+  buckets_.reserve(bucket_count);
+  for (size_t bucket_index = 0; bucket_index < bucket_count; bucket_index++) {
+    Bucket bucket;
+
+    // TODO(@pietern): Validate indices.
+    // Must be non-empty, contain no duplicates, and be disjoint across buckets.
+    AT_ASSERTM(indices[bucket_index].size() > 0, "Empty bucket specified.");
+
+    // Iterate over model replicas.
+    for (size_t replica_index = 0; replica_index < replica_count;
+         replica_index++) {
+      at::TensorOptions options;
+      BucketReplica replica;
+      size_t offset = 0;
+
+      // Iterate over bucket variables.
+      for (const auto variable_index : indices[bucket_index]) {
+        AT_ASSERTM(
+            variable_index < variables_[replica_index].size(),
+            "Out of range variable index specified.");
+        const auto& variable = variables_[replica_index][variable_index];
+        if (!options.has_device()) {
+          options = options.device(variable.device());
+        } else {
+          AT_ASSERTM(
+              variable.device() == options.device(),
+              "All parameters in a bucket must be placed on the same device.");
+        }
+        if (!options.has_dtype()) {
+          options = options.dtype(variable.dtype());
+        } else {
+          AT_ASSERTM(
+              variable.dtype() == options.dtype(),
+              "All parameters in a bucket must have the same dtype.");
+        }
+        const auto length = variable.numel();
+        replica.variables.push_back(variable);
+        replica.offsets.push_back(offset);
+        replica.lengths.push_back(length);
+        offset += length;
+      }
+
+      // Allocate bucket contents tensor.
+      replica.contents = torch::autograd::make_variable_consuming(
+          at::empty({static_cast<long>(offset)}, options));
+
+      // Add bucket replica to enclosing bucket.
+      bucket.replicas.push_back(std::move(replica));
+    }
+
+    // Map participating variables to this bucket.
+    // This is identical across replicas so we only need to do this once.
+    size_t intra_bucket_index = 0;
+    for (const auto variable_index : indices[bucket_index]) {
+      AT_ASSERTM(
+          variable_index < bucket_indices_.size(),
+          "Out of range variable index specified.");
+      bucket_indices_[variable_index] = BucketIndex{
+          .bucket_index = bucket_index,
+          .intra_bucket_index = intra_bucket_index++,
+      };
+    }
+
+    buckets_.push_back(std::move(bucket));
+  }
+}
+
+// Traverse the autograd graph starting at the specified output.
+// All parameters for which we have a pointer to their gradient accumulation
+// functions and don't show up in this graph can be marked as ready
+// for reduction immediately. Without this, we would deadlock waiting on
+// gradients for those parameters, which will never be computed.
+//
+// Rough copy of torch::autograd::Engine::compute_dependencies.
+//
+void Reducer::prepare_for_backward(
+    const std::vector<torch::autograd::Variable>& outputs) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  std::unordered_set<torch::autograd::Function*> seen;
+  std::vector<torch::autograd::Function*> queue;
+
+  // Reset accounting.
+  has_queued_final_callback_ = false;
+  expect_autograd_hooks_ = true;
+  next_bucket_ = 0;
+  backward_stats_base_ = current_time_in_nanos();
+  for (auto& bucket : buckets_) {
+    for (auto& replica : bucket.replicas) {
+      replica.pending = replica.variables.size();
+    }
+    bucket.pending = bucket.replicas.size();
+  }
+
+  // Seed queue with the grad functions of all outputs.
+  for (const auto& output : outputs) {
+    auto grad_fn = output.grad_fn();
+    if (grad_fn) {
+      queue.push_back(grad_fn.get());
+    }
+  }
+
+  // Traverse the autograd graph starting at the specified output.
+  while (!queue.empty()) {
+    auto fn = queue.back();
+    queue.pop_back();
+    for (const auto& edge : fn->next_edges()) {
+      if (auto next_ptr = edge.function.get()) {
+        const bool was_inserted = seen.insert(next_ptr).second;
+        if (was_inserted) {
+          queue.push_back(next_ptr);
+        }
+      }
+    }
+  }
+
+  // Find accumulator functions that don't show up in this graph.
+  for (const auto& it : func_) {
+    // If the accumulator function is present in the graph, we know
+    // a gradient will be computed for the corresponding parameter.
+    if (seen.count(it.first) > 0) {
+      continue;
+    }
+
+    size_t replica_index;
+    size_t variable_index;
+    std::tie(replica_index, variable_index) = it.second;
+    mark_variable_ready(replica_index, variable_index);
+  }
+}
+
+void Reducer::finalize_backward() {
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  // No longer expect autograd hooks to fire after this function returns.
+  AT_ASSERT(expect_autograd_hooks_);
+  expect_autograd_hooks_ = false;
+
+  // Check that all buckets were completed and had their work kicked off.
+  AT_ASSERTM(
+      next_bucket_ == buckets_.size(),
+      "Expected all buckets to be ready at the end of the backward pass.");
+
+  // Wait for asynchronous reduction to complete and unflatten contents.
+  for (auto& bucket : buckets_) {
+    AT_ASSERT(bucket.work);
+    bucket.work->wait();
+    for (auto& replica : bucket.replicas) {
+      for (size_t intra_bucket_index = 0;
+           intra_bucket_index < replica.variables.size();
+           intra_bucket_index++) {
+        auto& variable = replica.variables[intra_bucket_index];
+        const auto offset = replica.offsets[intra_bucket_index];
+        const auto length = replica.lengths[intra_bucket_index];
+        auto bucket_view =
+            replica.contents.narrow(0, offset, length).view(variable.sizes());
+        auto& grad = variable.grad();
+        if (grad.defined()) {
+          grad.copy_(bucket_view);
+        } else {
+          grad = torch::autograd::make_variable(bucket_view);
+        }
+      }
+    }
+  }
+}
+
+} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
new file mode 100644
index 0000000..6606cf6
--- /dev/null
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -0,0 +1,137 @@
+#pragma once
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+#include <c10d/ProcessGroup.hpp>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/variable.h>
+
+namespace c10d {
+
+class Reducer {
+ public:
+  // The constructor takes a vector<Variable> with model parameters for
+  // every model replica, hence the vector<vector<>>.
+  explicit Reducer(
+      std::vector<std::vector<torch::autograd::Variable>> variables,
+      std::shared_ptr<c10d::ProcessGroup> process_group);
+
+  // To (re-)initialize bucket assignment, pass a list of buckets, each
+  // of which is specified by a list of indices in the variables list.
+  // This function performs validation that the variables within a bucket
+  // all live on the same device and have the same dtype.
+  void initialize_buckets(std::vector<std::vector<size_t>> indices);
+
+  // This function is called when the forward function has produced an output,
+  // and the user wishes to reduce gradients in the backwards pass.
+  // If they instead wish to accumulate gradients over a number of iterations
+  // before reducing them, this call can simply be omitted.
+  void prepare_for_backward(
+      const std::vector<torch::autograd::Variable>& outputs);
+
+  // Returns the relative time in nanoseconds when gradients were ready,
+  // with respect to the time `prepare_for_backward` was called. The outer
+  // vector is for model replicas and the inner vector is for parameters.
+  std::vector<std::vector<int64_t>> get_backward_stats() const {
+    return backward_stats_;
+  }
+
+ protected:
+  std::mutex mutex_;
+  std::vector<std::vector<torch::autograd::Variable>> variables_;
+  std::shared_ptr<c10d::ProcessGroup> process_group_;
+
+  std::vector<std::vector<std::shared_ptr<torch::autograd::Function>>>
+      grad_accumulators_;
+  std::unordered_map<torch::autograd::Function*, std::tuple<int, int>> func_;
+
+  bool expect_autograd_hooks_;
+  bool has_queued_final_callback_;
+  size_t next_bucket_;
+
+  void mark_variable_ready(size_t replica_index, size_t variable_index);
+
+  void mark_bucket_ready(size_t bucket_index);
+
+  void finalize_backward();
+
+  // A bucket replica represents [1..N] gradients to be reduced,
+  // with the same dtype, on the same device.
+  //
+  // Batching gradients together before reducing them can result in lower
+  // overhead and/or faster time to completion. Only gradients of the same type
+  // and on the same device can be batched. The tensor that represents the
+  // flattened gradient uses the same type and is placed on the same device.
+  // Buckets are filled as the gradients they hold are computed (triggered by
+  // autograd hooks). Buckets are reduced in a predetermined order that is
+  // identical across processes.
+  //
+  struct BucketReplica {
+    // Flattened (1 dimensional) contents of bucket.
+    at::Tensor contents;
+
+    // Variables that contribute to this bucket replica. Use refcounted value
+    // here so that we can easily unflatten the bucket contents into the
+    // participating variables after reduction has completed.
+    std::vector<torch::autograd::Variable> variables;
+
+    // Per-variable offset/length into the flat bucket contents tensor.
+    std::vector<size_t> offsets;
+    std::vector<size_t> lengths;
+
+    // Number of tensors to be added before this bucket is complete.
+    // This is reset to `variables.size()` every iteration.
+    size_t pending;
+
+    // TODO(@pietern)
+    // Memory copies from gradient tensors into the bucket are potentially
+    // done on different CUDA streams. We record an event for every copy
+    // so that we can synchronize with them prior to kicking off the reduction.
+    // std::vector<at::cuda::CUDAEvent> events;
+  };
+
+  // A bucket holds N bucket replicas (1 per model replica).
+  //
+  // Once every bucket replica in this struct is marked ready, the reduction
+  // for the bucket can be kicked off.
+  //
+  struct Bucket {
+    std::vector<BucketReplica> replicas;
+
+    // Number of replicas to be marked done before this bucket is ready.
+    size_t pending;
+
+    // Keep work handle around when this set of buckets is being reduced.
+    std::shared_ptr<c10d::ProcessGroup::Work> work;
+  };
+
+  std::vector<Bucket> buckets_;
+
+  // A bucket index locates the position of a particular variable in the bucket
+  // structure. The `bucket_index` field points to the bucket in the `buckets_`
+  // vector. The `intra_bucket_index` field points to the index of the variable
+  // in any of the vector fields in the bucket replica.
+  struct BucketIndex {
+    // Index into the `buckets_` variable.
+    size_t bucket_index;
+    // Index of parameter in single bucket replica.
+    size_t intra_bucket_index;
+  };
+
+  // Maps each variable index to its position in the bucket structure.
+  // Bucketing is identical across replicas, so no replica index is needed.
+  std::vector<BucketIndex> bucket_indices_;
+
+  // We collect the relative timestamp of every gradient being ready
+  // when executing autograd. This can be used to derive a timeline of
+  // the point in time buckets were ready, or ideal bucket assignment/ordering.
+  int64_t backward_stats_base_;
+  std::vector<std::vector<int64_t>> backward_stats_;
+};
+
+} // namespace c10d
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 1fea9d5..775485b 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -12,6 +12,7 @@ from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
 from . import ReduceOp
 from . import PrefixStore
 from . import ProcessGroupGloo
+from . import Reducer
 
 
 _MPI_AVAILABLE = True