fix GradBucket.is_last() logic (#63768)
author    Nima Elyasi <nimaelyasi@fb.com>
          Wed, 1 Sep 2021 15:47:44 +0000 (08:47 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Wed, 1 Sep 2021 16:29:13 +0000 (09:29 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63768

Passed the number of buckets to the GradBucket constructor, so that .is_last() can check whether the bucket index equals num_buckets - 1.
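For context, a DDP comm hook can now reliably branch on the last bucket of an iteration. Below is a minimal sketch of an allreduce hook that performs per-iteration bookkeeping on the last bucket; the hook name and the bookkeeping branch are illustrative, while bucket.buffer(), bucket.is_last(), and the Work.get_future() pattern follow the standard comm-hook API (as used by the test added below):

    import torch
    import torch.distributed as dist

    def allreduce_with_last_bucket_hook(process_group, bucket):
        # Asynchronously allreduce the bucket's flattened gradient tensor.
        tensor = bucket.buffer()
        fut = dist.all_reduce(
            tensor, group=process_group, async_op=True
        ).get_future()

        def finish(fut):
            # With this fix, is_last() is True only for the final bucket
            # (index == bucket_count - 1), so per-iteration work can be
            # triggered here exactly once.
            if bucket.is_last():
                pass  # e.g., reset per-iteration state (illustrative)
            return fut.value()[0]

        return fut.then(finish)

Gradient averaging is omitted for brevity; the stock allreduce hook divides the bucket by the world size before reducing.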

Test Plan:
buck test mode/dev-nosan //caffe2/test/distributed/algorithms/ddp_comm_hooks:test_ddp_hooks

test output: https://www.internalfb.com/intern/testinfra/testconsole/testrun/8162774375985873/

Reviewed By: SciPioneer, mrshenli

Differential Revision: D30455913

fbshipit-source-id: 8c67ca69cbf191d6e189e09248407eb167bb24b6

test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
torch/csrc/distributed/c10d/comm.hpp
torch/csrc/distributed/c10d/reducer.cpp

diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
index 7b889fd..3d00712 100644
@@ -177,6 +177,36 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
 
         np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_is_last_hook(self):
+
+        store = dist.FileStore(self.file_name, self.world_size)
+        process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        def hook(flags, bucket):
+            flags.append(bucket.is_last())
+            fut = torch.futures.Future()
+            fut.set_result(bucket.buffer())
+            return fut
+
+        flags = []
+        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        model = nn.Sequential(
+            nn.Linear(2, 4000, bias=False),
+            *[nn.Linear(4000, 4000, bias=False) for _ in range(10)]
+        )
+        gpu_model = DistributedDataParallel(
+            model.to(device_id),
+            device_ids=[device_id],
+            process_group=process_group,
+        )
+        gpu_model.register_comm_hook(state=flags, hook=hook)
+        input = torch.randn(10, 2)
+        gpu_model(input).sum().backward()
+        self.assertTrue(flags[-1])
+        self.assertFalse(any(flags[:-1]))
+
 
 if __name__ == "__main__":
     assert (
diff --git a/torch/csrc/distributed/c10d/comm.hpp b/torch/csrc/distributed/c10d/comm.hpp
index 9b45795..4690c35 100644
@@ -18,12 +18,14 @@ class TORCH_API GradBucket {
  public:
   explicit GradBucket(
       size_t index,
+      size_t bucket_count,
       const at::Tensor& tensor,
       const std::vector<size_t>& offsets,
       const std::vector<size_t>& lengths,
       const std::vector<c10::IntArrayRef>& sizes_vec,
       const std::vector<at::Tensor>& parameters)
       : index_(index),
+        bucket_count_(bucket_count),
         buffer_(tensor),
         offsets_(offsets),
         lengths_(lengths),
@@ -63,11 +65,12 @@ class TORCH_API GradBucket {
 
  // Returns whether this bucket is the last bucket to allreduce in an iteration.
   bool isLast() const {
-    return index_ == 0;
+    return index_ == bucket_count_ - 1;
   }
 
  private:
   size_t index_;
+  size_t bucket_count_;
   at::Tensor buffer_;
 
   // Per-variable info in buffer_.
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index eafc70c..91db615 100644
@@ -472,6 +472,7 @@ std::vector<c10d::GradBucket> Reducer::get_grad_buckets(
     auto variables_for_bucket = get_variables_for_bucket(i, bucket);
     gradBuckets.emplace_back(
         i,
+        buckets_.size(),
         return_zero_tensors ? at::zeros_like(bucket.replicas[0].contents)
                             : bucket.replicas[0].contents,
         bucket.replicas[0].offsets,
@@ -888,6 +889,7 @@ void Reducer::all_reduce_bucket(Bucket& bucket) {
   auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket);
   GradBucket grad_bucket(
       next_bucket_,
+      buckets_.size(),
       tensors[0],
       // Since we only support single-process single-device
       // mode, there is always only one replica in the bucket.