[wip] Move smallest bucket to end after rebuild buckets (#62279)
authorRohan Varma <rvarm1@fb.com>
Tue, 17 Aug 2021 22:01:21 +0000 (15:01 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 22:04:50 +0000 (15:04 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62279

Before rebuilding buckets, `kDefaultFirstBucketBytes` is misleading because we reverse the parameter indices when initializing the reducer, so it is actually the size of the last bucket.

Currently, rebuilding buckets sets this to be the first bucket size, but we are seeing whether keeping it as the last bucket can help perf.

This is currently experimental only, and we don't plan to land it unless experiments show a clear win.
ghstack-source-id: 135966897

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D29927931

fbshipit-source-id: 55b949986fa2c3bade6fcb4bf5b513461bf0f490

torch/csrc/distributed/c10d/reducer.cpp
torch/testing/_internal/distributed/distributed_test.py

index 17d7a8d..d91f191 100644 (file)
@@ -1667,6 +1667,18 @@ bool Reducer::rebuild_buckets() {
   bucket_size_limits.push_back(first_bucket_bytes_cap_);
   bucket_size_limits.push_back(bucket_bytes_cap_);
   std::vector<size_t> per_bucket_size_limits;
+  auto ddp_set_last_bucket_as_small =
+      (parse_env("DDP_SET_LAST_BUCKET_CAP").compare("1") == 0);
+
+  if (ddp_set_last_bucket_as_small) {
+    // Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last
+      // bucket. We cannot simply pass in {bucket_bytes_cap_, first_bucket_bytes_cap_}
+    // as the bucket order as we would immediately advance to the 2nd element
+    // after the first bucket, whereas we only want the last bucket to have
+    // a smaller size.
+    std::reverse(rebuilt_params_.begin(), rebuilt_params_.end());
+    std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end());
+  }
   std::tie(rebuilt_bucket_indices, per_bucket_size_limits) =
       compute_bucket_assignment_by_size(
           rebuilt_params_,
@@ -1674,6 +1686,13 @@ bool Reducer::rebuild_buckets() {
           expect_sparse_gradients_[0],
           rebuilt_param_indices_);
 
+  if (ddp_set_last_bucket_as_small) {
+    // Reverse again because buckets were rebuilt in the opposite of gradient
+    // ready order.
+    std::reverse(rebuilt_bucket_indices.begin(), rebuilt_bucket_indices.end());
+    std::reverse(per_bucket_size_limits.begin(), per_bucket_size_limits.end());
+  }
+
   if (ddp_debug_level_ != c10d::DistributedDebugLevel::OFF) {
     TORCH_INTERNAL_ASSERT(
         rebuilt_bucket_indices.size() == per_bucket_size_limits.size())
@@ -1694,6 +1713,7 @@ bool Reducer::rebuild_buckets() {
 
   initialize_buckets(
       std::move(rebuilt_bucket_indices), std::move(per_bucket_size_limits));
+
   return true;
 }
 
index d7bf0ca..54a22b0 100644 (file)
@@ -7867,6 +7867,7 @@ class DistributedTest:
             torch.cuda.set_device(self.rank)
             default_bucket_cap_mb = 25 * (1024 ** 2)
             first_bucket_bytes_mb = dist._DEFAULT_FIRST_BUCKET_BYTES
+            os.environ["DDP_SET_LAST_BUCKET_CAP"] = "1"
 
             class MyModel(nn.Module):
                 def __init__(self):
@@ -7884,32 +7885,24 @@ class DistributedTest:
                 device_ids=[self.rank]
             )
             inp = torch.randn(10, 2)
+            rebuilt_bucket_index = 2
             for i in range(6):
                 out = ddp(inp).sum()
                 out.backward()
                 logging_data = ddp._get_ddp_logging_data()
-                if i < 2:
-                    bucket_size_limits = [
-                        int(b) for b in logging_data["initial_bucket_size_limits"].split(", ")
-                    ]
-                    # first_bucket_bytes is actually the last because we reverse
-                    # parameter bucket order.
-                    self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
-                    for j, bucket_size in enumerate(bucket_size_limits):
-                        if j != len(bucket_size_limits) - 1:
-                            self.assertEqual(bucket_size, default_bucket_cap_mb)
-                else:
-                    bucket_size_limits = [
-                        int(b) for b in logging_data["rebuilt_bucket_size_limits"].split(", ")
-                    ]
-                    # TODO: rebuild buckets places first bucket at beginning, but
-                    # might be better to move it to end.
-                    self.assertEqual(
-                        bucket_size_limits[0], first_bucket_bytes_mb
-                    )
-                    for j, bucket_size in enumerate(bucket_size_limits):
-                        if j != 0:
-                            self.assertEqual(bucket_size, default_bucket_cap_mb)
+                bucket_size_limits = [
+                    int(b) for b in logging_data[
+                        "{}_bucket_size_limits".format(
+                            "initial" if i < rebuilt_bucket_index else "rebuilt"
+                        )
+                    ].split(", ")
+                ]
+                # first_bucket_bytes is actually the last because we reverse
+                # parameter bucket order under DDP_SET_LAST_BUCKET_CAP flag.
+                self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
+                for j, bucket_size in enumerate(bucket_size_limits):
+                    if j != len(bucket_size_limits) - 1:
+                        self.assertEqual(bucket_size, default_bucket_cap_mb)
 
         @skip_if_lt_x_gpu(2)
         @sandcastle_skip_if(