Simplify data structures, add uniform approximation, fix mem leak (#63162)
authorAndrew Gu <andgu@fb.com>
Fri, 13 Aug 2021 15:19:23 +0000 (08:19 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 15:20:59 +0000 (08:20 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63162

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30284617

Pulled By: andwgu

fbshipit-source-id: 9bd9e5f89abcc0d3dac56b85d55cc88e843baa9f

test/distributed/optim/test_zero_redundancy_optimizer.py
torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
torch/distributed/optim/zero_redundancy_optimizer.py
torch/distributed/optim/zero_redundancy_optimizer.pyi

index bee4a4f..c60c7de 100644 (file)
@@ -916,7 +916,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         device,
         hook_constructor,
         gradient_as_bucket_view,
-        static_graph
+        static_graph,
+        **kwargs,
     ):
         SGD_LR = 0.01
         SGD_MOMENTUM = 0.9
@@ -970,7 +971,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
                 )
                 ddp_model_overlap.register_comm_hook(
                     None,
-                    hook_constructor(allreduce_hook, ddp_model_overlap, zero_optim)
+                    hook_constructor(allreduce_hook, ddp_model_overlap, zero_optim, **kwargs)
                 )
 
                 # Set up the DDP model with local optimizer
@@ -1095,6 +1096,62 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
     # TODO: Add `test_ddp_with_zero_step_interleaved_parity_cpu()` once the
     # Gloo synchronization issue causing hangs is fixed.
 
+    @common_distributed.skip_if_win32()
+    @common_distributed.requires_nccl()
+    @common_distributed.skip_if_no_gpu
+    @common_distributed.skip_if_rocm
+    def test_ddp_with_zero_step_uniform_parity_gpu(self):
+        r"""
+        Check that overlapping DDP with ZeRO using
+        ``hook_with_zero_step()`` with ``shard_buckets=True``
+        achieves parity with DDP using a local optimizer when running on GPU.
+
+        NOTE: The test is skipped if using Windows since functional optimizers
+        are not currently supported.
+        """
+        self.dist_init(self.rank, self.world_size, dist.Backend.NCCL)
+        for gradient_as_bucket_view, static_graph in itertools.product(
+            [True, False],
+            [True, False]
+        ):
+            self._test_ddp_zero_overlap(
+                torch.device(self.rank),
+                hook_with_zero_step,
+                gradient_as_bucket_view,
+                static_graph,
+                shard_buckets=True,
+            )
+    # TODO: Add `test_ddp_with_zero_step_uniform_parity_cpu()` once the Gloo
+    # synchronization issue causing hangs is fixed.
+
+    @common_distributed.skip_if_win32()
+    @common_distributed.requires_nccl()
+    @common_distributed.skip_if_no_gpu
+    @common_distributed.skip_if_rocm
+    def test_ddp_with_zero_step_interleaved_uniform_parity_gpu(self):
+        r"""
+        Check that overlapping DDP with ZeRO using
+        ``hook_with_zero_step()`` with ``shard_buckets=True``
+        achieves parity with DDP using a local optimizer when running on GPU.
+
+        NOTE: The test is skipped if using Windows since functional optimizers
+        are not currently supported.
+        """
+        self.dist_init(self.rank, self.world_size, dist.Backend.NCCL)
+        for gradient_as_bucket_view, static_graph in itertools.product(
+            [True, False],
+            [True, False]
+        ):
+            self._test_ddp_zero_overlap(
+                torch.device(self.rank),
+                hook_with_zero_step_interleaved,
+                gradient_as_bucket_view,
+                static_graph,
+                shard_buckets=True,
+            )
+    # TODO: Add `test_ddp_with_zero_step_interleaved_uniform_parity_cpu()` once
+    # the Gloo synchronization issue causing hangs is fixed.
+
 if __name__ == "__main__":
     # ! unittest should not be used here, else the tests are not properly registered
     common_utils.run_tests()
index c696ff2..b9ed357 100644 (file)
@@ -1,9 +1,13 @@
+import weakref
 from typing import Any, Callable, List, Optional
 
 import torch
 import torch.distributed as dist
 from torch.distributed.optim import ZeroRedundancyOptimizer
-from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus
+from torch.distributed.optim.zero_redundancy_optimizer import (
+    _get_global_rank,
+    _OverlapStatus,
+)
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 # Functional optimizers require passing a list of gradients to their `step()`
@@ -44,9 +48,13 @@ def _perform_local_step(
         [_NO_PARAM_UPDATE for _ in range(num_local_optim_params)]
     assert bucket_index in overlap_info.offsets, \
         f"Bucket index {bucket_index} was not assigned to rank {rank}"
-    offset = overlap_info.offsets[bucket_index]
-    for i, grad in enumerate(bucket.gradients()):
-        gradients[offset + i] = grad
+    gradients_offset = overlap_info.offsets[bucket_index]
+    bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
+    bucket_offset = bucket_assignment.offset
+    length = len(bucket_assignment.parameters)
+    bucket_gradients = bucket.gradients()[bucket_offset:bucket_offset + length]
+    for i, grad in enumerate(bucket_gradients):
+        gradients[gradients_offset + i] = grad
 
     zero._local_step(gradients)
 
@@ -54,7 +62,6 @@ def _perform_local_step(
 def _broadcast_bucket(
     bucket_index: int,
     zero: ZeroRedundancyOptimizer,
-    assigned_rank: int,
 ):
     r"""
     Broadcasts a bucket's parameters.
@@ -64,58 +71,63 @@ def _broadcast_bucket(
             parameters to broadcast.
         zero (ZeroRedundancyOptimizer): the calling process's
             :class:`ZeroRedundancyOptimizer` instance.
-        assigned_rank (int): the rank assigned to the bucket; it has the
-            updated parameters and serves as the source for the broadcast.
     """
     overlap_info = zero._overlap_info
-    device = overlap_info.params_per_bucket[bucket_index][0].device
-    device_index = zero._device_to_device_index[device]
-    assert bucket_index in zero._buckets[device_index][assigned_rank]
-    overlap_info.broadcast_handles.append(
-        dist.broadcast(
-            zero._buckets[device_index][assigned_rank][bucket_index],
-            src=assigned_rank,
-            async_op=True
-        )
-    )
-
-def _collect_ddp_bucket_info(
+    assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \
+        "`assigned_ranks_per_bucket` is not fully constructed"
+    # Sort to ensure the same ordering across ranks
+    assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
+    assert len(assigned_ranks) > 0, f"Bucket {bucket_index} should be " \
+        "assigned to at least one rank"
+    for assigned_rank in assigned_ranks:
+        bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
+        if bucket_index in bucket_assignments:
+            overlap_info.broadcast_handles.append(
+                dist.broadcast(
+                    bucket_assignments[bucket_index].tensor,
+                    src=_get_global_rank(zero.process_group, assigned_rank),
+                    group=zero.process_group,
+                    async_op=True,
+                )
+            )
+
+
+def _save_ddp_bucket_info(
     bucket: dist.GradBucket,
     zero: ZeroRedundancyOptimizer,
-    rank: int,
-    assigned_rank: int,
 ):
     r"""
-    Collects :class:`DistributedDataParallel` gradient bucket information for
-    the :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when
-    overlapping.
+    Saves :class:`DistributedDataParallel` gradient bucket information for the
+    :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when overlapping.
+    In particular, this function is meant to be called upon seeing each
+    gradient bucket, meaning it does not save or compute any global
+    information.
 
     Arguments:
         bucket (dist.GradBucket): the current gradient bucket.
         zero (ZeroRedundancyOptimizer): the calling process's
             :class:`ZeroRedundancyOptimizer` instance.
-        rank (int): the calling process's rank.
-        assigned_rank (int): the rank assigned to update the parameters
-            corresponding to ``bucket``.
     """
     overlap_info = zero._overlap_info
-    bucket_index = bucket.index()
     bucket_params = bucket.parameters()
-    assert len(bucket_params) > 0, "Bucket {bucket_index} is empty"
-    params_per_rank = overlap_info.params_per_rank
-    params_per_bucket = overlap_info.params_per_bucket
+    assert len(bucket_params) > 0, "Empty bucket"
 
-    # Collect relevant information
-    if assigned_rank == rank:
-        overlap_info.offsets[bucket_index] = len(params_per_rank[assigned_rank])
-    params_per_rank[assigned_rank].extend(bucket_params)
-    params_per_bucket.append(bucket_params)
+    # Save the parameters in the bucket
+    overlap_info.params_per_bucket.append(bucket_params)
+    if overlap_info.shard_buckets:
+        # Additionally save the bucket size for the assignment heuristic to use
+        bucket_size = 0
+        for param in bucket_params:
+            bucket_size += param.numel()
+        assert overlap_info.total_size is not None
+        overlap_info.total_size += bucket_size
 
 
 def hook_with_zero_step(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
     ddp: DistributedDataParallel,
     zero: ZeroRedundancyOptimizer,
+    shard_buckets: bool = False,
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
     r"""
     Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer`
@@ -139,6 +151,12 @@ def hook_with_zero_step(
             instance to use.
         zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
             instance to use.
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity; if
+            ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
 
     Returns:
         The modified hook.
@@ -161,15 +179,20 @@ def hook_with_zero_step(
             "ZeroRedundancyOptimizer must be constructed with "
             "`overlap_with_ddp=True` to use this hook properly"
         )
+    ddp_ref = weakref.ref(ddp)
 
     # NOTE: Gloo may hang with this overlapping approach, so we require
     # NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
-    if dist.get_backend() != dist.Backend.NCCL:
+    if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL:  # type: ignore[union-attr]
         raise RuntimeError(
             "Overlapping DDP with ZeRO using this approach currently requires "
             "NCCL backend to avoid hangs"
         )
 
+    if shard_buckets:
+        zero._overlap_info.shard_buckets = True
+        zero._overlap_info.total_size = 0
+
     def hook_with_zero_fn(
         state: Any,
         bucket: dist.GradBucket,
@@ -193,26 +216,27 @@ def hook_with_zero_step(
         bucket_index = bucket.index()
 
         # Proceed as normal until the DDP buckets have been rebuilt
-        if not ddp._has_rebuilt_buckets:
+        if not ddp_ref()._has_rebuilt_buckets:  # type: ignore[union-attr]
             assert overlap_info.status == _OverlapStatus.UNINITIALIZED
             return fut
 
         if overlap_info.status == _OverlapStatus.UNINITIALIZED:
             overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
-
         rank = zero.global_rank
-        assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index)
 
         # Once DDP buckets have been rebuilt but ZeRO has not been
         # properly initialized yet, collect the information needed
         if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
-            _collect_ddp_bucket_info(bucket, zero, rank, assigned_rank)
+            _save_ddp_bucket_info(bucket, zero)
             return fut
 
         assert overlap_info.status == _OverlapStatus.INITIALIZED
+        assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \
+            "`assigned_ranks_per_bucket` is not fully constructed"
+        assigned_to_bucket = rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
 
         # Save the bucket reference and all-reduce future for the final bucket
-        if assigned_rank == rank:
+        if assigned_to_bucket:
             overlap_info.bucket_index_to_bucket[bucket_index] = bucket
             overlap_info.bucket_index_to_future[bucket_index] = fut
 
@@ -238,8 +262,8 @@ def hook_with_zero_step(
         # all-reduce future since that would add synchronization that delays
         # all optimizer computation to wait for that last all-reduce
         for bucket_index in range(num_buckets):
-            assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index)
-            if assigned_rank == rank:
+            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
+            if rank in assigned_ranks:
                 # Wait on the bucket's all-reduce future to ensure correct
                 # gradients
                 assert bucket_index in overlap_info.bucket_index_to_future, \
@@ -252,11 +276,11 @@ def hook_with_zero_step(
                 curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
                 _perform_local_step(curr_bucket, zero, rank)
 
-            _broadcast_bucket(bucket_index, zero, assigned_rank)
+            _broadcast_bucket(bucket_index, zero)
 
         # Ensure that all parameter updates are finished before the
         # next forward pass
-        overlap_info.wait_for_broadcasts(num_buckets, rank)
+        overlap_info.wait_for_broadcasts()
         overlap_info.clear_per_iter_info()
 
         return fut
@@ -268,6 +292,7 @@ def hook_with_zero_step_interleaved(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
     ddp: DistributedDataParallel,
     zero: ZeroRedundancyOptimizer,
+    shard_buckets: bool = False,
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
     r"""
     Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer`
@@ -292,6 +317,12 @@ def hook_with_zero_step_interleaved(
             instance to use.
         zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
             instance to use.
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity; if
+            ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
 
     Returns:
         The modified hook.
@@ -314,15 +345,20 @@ def hook_with_zero_step_interleaved(
             "ZeroRedundancyOptimizer must be constructed with "
             "`overlap_with_ddp=True` to use this hook properly"
         )
+    ddp_ref = weakref.ref(ddp)
 
     # NOTE: Gloo may hang with this overlapping approach, so we require
     # NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
-    if dist.get_backend() != dist.Backend.NCCL:
+    if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL:  # type: ignore[union-attr]
         raise RuntimeError(
             "Overlapping DDP with ZeRO using this approach currently requires "
             "NCCL backend to avoid hangs"
         )
 
+    if shard_buckets:
+        zero._overlap_info.shard_buckets = True
+        zero._overlap_info.total_size = 0
+
     def hook_with_zero_interleaved_fn(
         state,
         bucket: dist.GradBucket,
@@ -339,7 +375,7 @@ def hook_with_zero_step_interleaved(
         fut = hook(state, bucket)
 
         # Proceed as normal until the DDP buckets have been rebuilt
-        if not ddp._has_rebuilt_buckets:
+        if not ddp_ref()._has_rebuilt_buckets:  # type: ignore[union-attr]
             assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
             return fut
 
@@ -353,11 +389,10 @@ def hook_with_zero_step_interleaved(
                 A :class:`torch.Tensor` representing the contents of the
                 gradient bucket.
             """
-            assert ddp._has_rebuilt_buckets
+            assert ddp_ref()._has_rebuilt_buckets  # type: ignore[union-attr]
 
             bucket_index = bucket.index()
             rank = zero.global_rank
-            assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index)
             overlap_info = zero._overlap_info
             if overlap_info.status == _OverlapStatus.UNINITIALIZED:
                 overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
@@ -365,20 +400,21 @@ def hook_with_zero_step_interleaved(
             # Once DDP buckets have been rebuilt but ZeRO has not been
             # properly initialized yet, collect the information needed
             if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
-                _collect_ddp_bucket_info(bucket, zero, rank, assigned_rank)
+                _save_ddp_bucket_info(bucket, zero)
                 return bucket.buffer()
 
+            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
             overlap_info.bucket_indices_seen.append(bucket_index)
-            if assigned_rank == rank:
+            if rank in assigned_ranks:
                 _perform_local_step(bucket, zero, rank)
 
-            _broadcast_bucket(bucket_index, zero, assigned_rank)
+            _broadcast_bucket(bucket_index, zero)
 
             num_buckets = len(overlap_info.params_per_bucket)
             if len(overlap_info.bucket_indices_seen) == num_buckets:
                 # Ensure that all parameter updates are finished before the
                 # next forward pass
-                overlap_info.wait_for_broadcasts(num_buckets, rank)
+                overlap_info.wait_for_broadcasts()
                 overlap_info.clear_per_iter_info()
 
             return bucket.buffer()
index e1bad82..bba71e4 100644 (file)
@@ -9,7 +9,7 @@ import enum
 import io
 import logging
 from itertools import chain
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Type
 
 import torch
 import torch.distributed as dist
@@ -121,17 +121,39 @@ class _ZeROJoinHook(JoinHook):
         self.zero.step()
 
 
-class _DDPBucket(NamedTuple):
+class _DDPBucketAssignment():
     r"""
-    This contains the model parameters corresponding to a
-    :class:`DistributedDataParallel` gradient bucket.
+    This represents a :class:`DistributedDataParallel` bucket assignment,
+    meaning a (possibly non-strict) subset of the parameters corresponding to
+    a DDP bucket assigned to a rank to update.
 
+    Attributes:
         bucket_index (int): index of the bucket determined by the DDP gradient
             bucket all-reduce order.
-        params (List[torch.Tensor]): model parameters in the bucket.
+        parameters (List[torch.Tensor]): model parameters in the bucket
+            assigned to this rank.
+        offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
+            giving the index of the first element in the passed-in
+            ``parameters``; this equivalently indexes into the
+            :class:`GradBucket` 's :meth:`gradients`.
+        device (torch.device): device on which the parameters are stored.
+        tensor (torch.Tensor): flattened tensor giving the data of the
+            parameter subset assigned to the rank.
     """
-    bucket_index: int
-    params: List[torch.Tensor]
+    def __init__(
+        self,
+        bucket_index: int,
+        parameters: List[torch.Tensor],
+        offset: int,
+    ):
+        self.bucket_index = bucket_index
+        self.parameters = parameters
+        self.offset = offset
+        if len(self.parameters) == 0:
+            raise ValueError("Empty bucket assignment")
+        # DDP guarantees all parameters in the bucket have the same device
+        self.device: torch.device = self.parameters[0].device
+        self.tensor: Optional[torch.Tensor] = None
 
 
 class _OverlapStatus(enum.IntEnum):
@@ -158,6 +180,18 @@ class _OverlapInfo():
     This contains the information needed by :class:`ZeroRedundancyOptimizer`
     to overlap with :class:`DistributedDataParallel`.
 
+    Arguments:
+        world_size (int): world size of the process group being used.
+
+    Attributes:
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity following
+            a threshold given by the total parameter size divided by the world
+            size; if ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
+            this should be set to the value passed into the hook constructor.
         status (_OverlapStatus): current status; see :class:`_OverlapStatus`
             for more information.
         params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
@@ -170,6 +204,13 @@ class _OverlapInfo():
             parameter in that bucket, where ``rank`` is this process's own
             rank; the keys of this :class:`dict` are the bucket indices
             assigned to this rank.
+        num_bucket_assignments (int): total number of bucket assignments across
+            all ranks; this is equal to the number of
+            :class:`DistributedDataParallel` gradient buckets if
+            ``shard_buckets=False`` and possibly greater otherwise.
+        total_size (int, optional): total size of all buckets (i.e. sum of
+            ``param.numel()`` for all ``param`` across all buckets) if
+            ``shard_buckets=True``; otherwise, ``None``.
         broadcast_handles (List[Work]): :class:`list` of async work handles for
             the parameter broadcasts.
         bucket_index_to_future (Dict[int, torch.futures.Future]):
@@ -180,14 +221,18 @@ class _OverlapInfo():
         bucket_indices_seen (List[int]): :class:`list` of the bucket indices
             seen on this iteration.
     """
-    def __init__(self) -> None:
+    def __init__(self, world_size) -> None:
         self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
+        self.shard_buckets: bool = False
 
         # Modified per bucket reconstruction
         self.params_per_bucket: List[List[torch.Tensor]] = []
         self.params_per_rank: List[List[torch.Tensor]] = \
-            [[] for _ in range(dist.get_world_size())]
+            [[] for _ in range(world_size)]
         self.offsets: Dict[int, int] = {}
+        self.assigned_ranks_per_bucket: List[Set[int]] = []
+        self.num_bucket_assignments: int = 0
+        self.total_size: Optional[int] = None
 
         # Modified per iteration
         self.broadcast_handles: List[Any] = []
@@ -196,19 +241,15 @@ class _OverlapInfo():
         self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
         self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
 
-    def wait_for_broadcasts(self, num_buckets, rank) -> None:
+    def wait_for_broadcasts(self) -> None:
         r"""
         Waits for all parameter broadcasts. This should be called once all
         broadcasts have been scheduled, meaning ``self.broadcast_handles`` is
         filled. This clears ``self.broadcast_handles`` in preparation for the
         next iteration.
-
-        Arguments:
-            num_buckets (int): total number of buckets.
-            rank (int): the calling process's rank.
         """
-        assert len(self.broadcast_handles) == num_buckets, \
-            f"Missing at least one broadcast handle on rank {rank}"
+        assert len(self.broadcast_handles) == self.num_bucket_assignments, \
+            f"Missing at least one broadcast handle on rank {dist.get_rank()}"
         _ = list(map(lambda x: x.wait(), self.broadcast_handles))
         self.broadcast_handles.clear()
 
@@ -339,8 +380,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         self._partition_parameters_cache: List[List[Dict]] = []
         self._index_to_param_cache: List[torch.Tensor] = []
         self._device_to_params_per_rank_cache: Dict[torch.device, List[List[torch.Tensor]]] = {}
-        self._device_to_buckets_cache: Dict[torch.device, List[List[_DDPBucket]]] = {}
-        self._device_to_device_index: Dict[torch.device, int] = {}
+        self._bucket_assignments_per_rank_cache: List[Dict[int, _DDPBucketAssignment]] = []
         self._is_trainable_mask = self._get_is_trainable_mask()
 
         # Default device for collective communication and buckets
@@ -360,7 +400,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         if not overlap_with_ddp:
             self._init_local_optimizer()
         else:
-            self._overlap_info: _OverlapInfo = _OverlapInfo()
+            self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
             if parameters_as_bucket_view:
                 logging.warning(
                     "`parameters_as_bucket_view=True` will be ignored since "
@@ -368,16 +408,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                     "strategy will be used"
                 )
 
-        # `self._buckets` is used if `parameters_as_bucket_view=True` or
-        # `overlap_with_ddp=True`, in which case parameter data is flattened
-        # into buckets (i.e. contiguous tensors)
-        # If `overlap_with_ddp=True`, the bucketing requires an additional
-        # dimension to match the DDP gradient bucketing
+        # `self._buckets` is used if `parameters_as_bucket_view=True`, in
+        # which case parameter data is flattened into contiguous bucket tensors
         self.parameters_as_bucket_view = parameters_as_bucket_view
-        self._buckets: Union[
-            List[List[torch.Tensor]],
-            List[List[Dict[int, torch.Tensor]]]
-        ] = []  # type: ignore[assignment]
+        self._buckets: List[List[torch.Tensor]] = []
         self._build_param_buckets()
 
         # Optional consolidated optimizer state, only populated if this rank
@@ -395,7 +429,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         self._index_to_param_cache.clear()
         self._param_to_index_cache.clear()
         self._device_to_params_per_rank_cache.clear()
-        self._device_to_buckets_cache.clear()
+        self._bucket_assignments_per_rank_cache.clear()
 
     def add_param_group(self, param_group: dict) -> None:
         r"""
@@ -610,7 +644,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                     params_sorted = sorted(param_group["params"], key=lambda t: t.numel(), reverse=True)
                     for param in params_sorted:
                         # Greedily add the parameter to rank with smallest size so far
-                        rank = sizes.index(min(sizes))
+                        rank = self._get_min_index(sizes)
                         param_group_params_per_rank[rank].append(param)
                         sizes[rank] += param.numel()
                     # Apply the constructed partition of the parameter group
@@ -762,44 +796,160 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                         self._device_to_params_per_rank_cache[device][rank].append(param)
         return self._device_to_params_per_rank_cache
 
-    @property
-    def _device_to_buckets(
-        self
-    ) -> Dict[torch.device, List[List[_DDPBucket]]]:
+    def _get_min_index(
+        self,
+        values: List[int],
+        disallowed_indices: Optional[Set[int]] = None,
+    ) -> int:
         r"""
-        :class:`dict` mapping each device to a :class:`list` of :class:`list`
-        of :class:`_DDPBucket` s.
+        Returns ``values.index(min(values))``, except only uses one pass. It
+        also excludes any indices in ``disallowed_indices`` if provided.
 
-        ``_device_to_buckets[d][r][i]`` gives the ``i``th bucket
-        assigned to rank ``r`` stored on device ``d``, where each bucket
-        contains a list of the model parameters associated with the
-        corresponding logical :class:`DistributedDataParallel` gradient bucket.
+        Arguments:
+            values: (List[int]): :class:`list` of values.
+            disallowed_indices (Optional[Set[int]]): indices that are
+                disallowed from being the returned min index.
+        """
+        min_index = -1
+        min_value = float("inf")
+        for i, value in enumerate(values):
+            if disallowed_indices and i in disallowed_indices:
+                continue
+            if value < min_value:
+                min_value = value
+                min_index = i
+        assert min_index >= 0, "All indices are disallowed"
+        return min_index
+
+    def _assign_bucket_subset_to_rank(
+        self,
+        bucket_index: int,
+        bucket_params: List[torch.Tensor],
+        bucket_offset: int,
+        assigned_rank: int,
+        assigned_ranks_per_bucket: List[Set[int]],
+    ) -> None:
+        r"""
+        Assigns the model parameters given by ``bucket_params``, representing a
+        (possibly non-strict) subset of the parameters corresponding to a
+        :class:`DistributedDataParallel` bucket, to the rank with the least
+        size assigned so far and collects relevant information.
 
-        This is used for constructing the parameter buckets if
-        ``overlap_with_ddp=True``.
+        Arguments:
+            bucket_index (int): index of the :class:`DistributedDataParallel`
+                gradient bucket.
+            bucket_params (List[torch.Tensor]): subset of the parameters
+                corresponding to the bucket to assign.
+            bucket_offset (int): offset giving the index of the first element
+                in ``bucket_params`` in the bucket's full parameter list.
+            assigned_rank (int): rank to assign to.
+            assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of ranks
+                assigned to each bucket.
         """
-        assert self._overlap_with_ddp, \
-            "`_device_to_buckets()` should only be used if " \
-            "`overlap_with_ddp=True`"
-        if len(self._device_to_buckets_cache) > 0:
-            return self._device_to_buckets_cache
+        overlap_info = self._overlap_info
+        if len(bucket_params) == 0:
+            raise ValueError(
+                "Empty bucket assignment"
+            )
+        params_per_rank = overlap_info.params_per_rank
+        offsets = overlap_info.offsets
+
+        self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = \
+            _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
+        if self.global_rank == assigned_rank:
+            offsets[bucket_index] = len(params_per_rank[assigned_rank])
+        params_per_rank[assigned_rank].extend(bucket_params)
+        assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
+        self._overlap_info.num_bucket_assignments += 1
+
+    @property
+    def _bucket_assignments_per_rank(
+        self
+    ) -> List[Dict[int, _DDPBucketAssignment]]:
+        r"""
+        :class:`list` of length world size consisting of :class:`dict` s
+        mapping bucket indices to :class:`_DDPBucketAssignment` s for each
+        rank.
+        """
+        assert self._overlap_with_ddp, "`_bucket_assignments_per_rank` " \
+            "only be used if `overlap_with_ddp=True`"
+        if len(self._bucket_assignments_per_rank_cache) > 0:
+            return self._bucket_assignments_per_rank_cache
 
         overlap_info = self._overlap_info
-        assert overlap_info.status == _OverlapStatus.INITIALIZED, \
-            "Accessing `_device_to_buckets` before the necessary " \
-            "information has been collected"
+        assert overlap_info.status == _OverlapStatus.INITIALIZED
 
+        self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
         params_per_bucket = overlap_info.params_per_bucket
-        for bucket_idx, bucket_params in enumerate(params_per_bucket):
-            assert len(bucket_params) > 0, "Empty bucket"
-            rank = self._ddp_bucket_index_to_rank(bucket_idx)
-            bucket = _DDPBucket(bucket_idx, bucket_params)
-            device = bucket_params[0].device  # assume same device per bucket
-            if device not in self._device_to_buckets_cache:
-                self._device_to_buckets_cache[device] = [[] for _ in range(self.world_size)]
-            self._device_to_buckets_cache[device][rank].append(bucket)
 
-        return self._device_to_buckets_cache
+        if overlap_info.shard_buckets:
+            # Define the assignment threshold to approximate uniformity
+            assert overlap_info.total_size is not None, \
+                "`total_size` was not computed"
+            threshold = overlap_info.total_size / self.world_size  # type: ignore[operator]
+            size_per_rank = [0 for _ in range(self.world_size)]
+
+        num_buckets = len(params_per_bucket)
+        overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
+        assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
+        if not overlap_info.shard_buckets:
+            # Assign each DDP bucket entirely to a single rank
+            for bucket_index, bucket_params in enumerate(params_per_bucket):
+                assert len(bucket_params) > 0, "Empty bucket"
+                assigned_rank = self._get_assigned_rank(bucket_index)
+                self._assign_bucket_subset_to_rank(
+                    bucket_index,
+                    bucket_params,
+                    0,
+                    assigned_rank,
+                    assigned_ranks_per_bucket,
+                )
+        else:
+            # Assign each DDP bucket to possibly multiple ranks
+            # Specifically, sort the DDP buckets by increasing size, and for
+            # each bucket, iteratively assign the maximal unassigned subset
+            # with size less than `threshold` to the rank with the least total
+            # size so far -- each such assignment is represented by a
+            # `_DDPBucketAssignment` instance and only contains parameters from
+            # a single DDP bucket
+            params_per_bucket_enum = sorted(
+                enumerate(params_per_bucket),
+                key=lambda x: sum(p.numel() for p in x[1])
+            )
+            for bucket_index, bucket_params in params_per_bucket_enum:
+                assert len(bucket_params) > 0, "Empty bucket"
+                bucket_offset = 0
+                assignment_size = 0
+                for param_index, param in enumerate(bucket_params):
+                    param_numel = param.numel()
+                    if assignment_size + param_numel >= threshold and param_index > bucket_offset:
+                        assigned_rank = self._get_min_index(size_per_rank, assigned_ranks_per_bucket[bucket_index])
+                        # Include up to but not including the parameter that
+                        # exceeded the threshold
+                        self._assign_bucket_subset_to_rank(
+                            bucket_index,
+                            bucket_params[bucket_offset:param_index],
+                            bucket_offset,
+                            assigned_rank,
+                            assigned_ranks_per_bucket,
+                        )
+                        size_per_rank[assigned_rank] += assignment_size
+                        bucket_offset = param_index
+                        assignment_size = 0
+                    assignment_size += param_numel
+                # Assign the remainder of the bucket so that no assignment
+                # spans across two buckets
+                assigned_rank = self._get_min_index(size_per_rank, assigned_ranks_per_bucket[bucket_index])
+                self._assign_bucket_subset_to_rank(
+                    bucket_index,
+                    bucket_params[bucket_offset:],
+                    bucket_offset,
+                    assigned_rank,
+                    assigned_ranks_per_bucket,
+                )
+                size_per_rank[assigned_rank] += assignment_size
+
+        return self._bucket_assignments_per_rank_cache
 
     def _local_step(
         self,
@@ -1109,57 +1259,38 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
 
     def _build_ddp_param_buckets(self) -> None:
         r"""
-        Builds parameter buckets if ``overlap_with_ddp`` so that for each
-        device that stores this rank's parameters, there is a :class:`list` of
-        buckets (represented as tensors) containing the parameters on that
-        device that are assigned to the rank in the parameter update
-        partition and grouped following the :class:`DistributedDataParallel`
-        gradient buckets.
-
-        This method should only be called during the delayed initialization
-        when ``overlap_with_ddp=True``.
-
-        .. warning::
-            The current implementation assumes that all of the parameters in a
-            bucket are of the same dense type when allocating the bucket's
-            tensor.
-
-        .. warning::
-            If the model parameters are stored across more than one device,
-            then the storage partitioning must be the same across all
-            processes in order for parameter synchronization to work.
+        For each DDP bucket with parameters assigned to this rank, flattens the
+        data of those parameters into a single tensor and saves the tensor to
+        the ``tensor`` attribute in the corresponding
+        :class:`_DDPBucketAssignment` instance stored in
+        ``self._bucket_assignments_per_rank``.
+
+        :class:`DistributedDataParallel` guarantees that the parameters
+        corresponding to a gradient bucket have the same device and the same
+        dtype.
         """
-        assert self._overlap_with_ddp, \
-            "`_build_ddp_param_buckets()` should only be called when " \
-            "`overlap_with_ddp=True`"
+        for bucket_assignments in self._bucket_assignments_per_rank:
+            for bucket_assignment in bucket_assignments.values():
+                params = bucket_assignment.parameters
+                bucket_size = 0
+                dtype = None
+                for param in params:
+                    assert _is_trainable(param), "Model parameter " \
+                        "corresponding to a gradient in a DDP bucket should " \
+                        "require a gradient"
+                    bucket_size += param.numel()
+                    dtype = param.dtype  # assumes all same dtype
+                assert bucket_size > 0, "Empty bucket"
 
-        num_devices = len(self._device_to_buckets)
-        self._buckets = [[{} for _ in range(self.world_size)] for _ in range(num_devices)]  # type: ignore[assignment]
-
-        for dev_idx, (device, ddp_buckets_per_rank) in enumerate(self._device_to_buckets.items()):
-            self._device_to_device_index[device] = dev_idx
-            for rank, ddp_buckets in enumerate(ddp_buckets_per_rank):
-                for ddp_bucket in ddp_buckets:
-                    bucket_index = ddp_bucket.bucket_index  # type: ignore[attr-defined]
-                    params = ddp_bucket.params              # type: ignore[attr-defined]
-                    bucket_size = 0
-                    dtype = None
-                    for param in params:
-                        assert _is_trainable(param), \
-                            "Model parameter corresponding to a gradient in " \
-                            "a DDP bucket should require a gradient"
-                        bucket_size += param.numel()
-                        dtype = param.dtype  # assumes all same dtype
-                    assert bucket_size > 0
-                    bucket = torch.empty(bucket_size, dtype=dtype, device=device)
-                    offset = 0
-                    # Construct the bucket (assuming all dense and same dtype)
-                    for param in params:
-                        offset_next = offset + param.numel()
-                        bucket[offset:offset_next].copy_(param.data.flatten())
-                        param.data = bucket[offset:offset_next].view_as(param.data)
-                        offset = offset_next
-                    self._buckets[dev_idx][rank][bucket_index] = bucket
+                # Construct the bucket tensor (assuming all dense and same dtype)
+                tensor = torch.empty(bucket_size, dtype=dtype, device=bucket_assignment.device)
+                offset = 0
+                for param in params:
+                    offset_next = offset + param.numel()
+                    tensor[offset:offset_next].copy_(param.data.flatten())
+                    param.data = tensor[offset:offset_next].view_as(param.data)
+                    offset = offset_next
+                bucket_assignment.tensor = tensor
 
     def _verify_and_init_params(self, params: Any) -> None:
         r"""
@@ -1245,6 +1376,22 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                 "functional optimizer with more than one parameter group"
             params = param_groups[0]["params"]
             self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
+
+            # Log information about the DDP and ZeRO bucketing
+            if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF:
+                local_numel = sum(p.numel() for p in params)
+                num_assigned_buckets = len(self._bucket_assignments_per_rank[self.global_rank])
+                logging.info(
+                    f"rank {self.global_rank} with {local_numel} parameters "
+                    f"across {num_assigned_buckets} buckets"
+                )
+                if self.global_rank == 0:
+                    logging.info(
+                        f"{len(self._overlap_info.params_per_bucket)} DDP "
+                        f"buckets and "
+                        f"{self._overlap_info.num_bucket_assignments} bucket "
+                        "assignments"
+                    )
         else:
             # NOTE: Passing `param_groups` into the local optimizer constructor
             # bypasses the empty parameter list check
@@ -1274,22 +1421,20 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         self._build_ddp_param_buckets()
         self._init_local_optimizer()
 
-    def _ddp_bucket_index_to_rank(self, bucket_index: int) -> int:
-        r"""Assigns a rank to a given DDP gradient bucket index."""
-        return bucket_index % self.world_size
-
-    def _get_assigned_ddp_bucket_indices(self) -> List[int]:
+    def _get_assigned_rank(self, bucket_index: int) -> int:
         r"""
-        Returns a list of the DDP gradient bucket indices assigned to this rank
-        to update.
+        Returns the single rank assigned to a :class:`DistributedDataParallel`
+        gradient bucket.
+
+        Arguments:
+            bucket_index (int): index of the :class:`DistributedDataParallel`
+                bucket for which to get the assigned rank.
         """
-        assert self._overlap_info.status == _OverlapStatus.INITIALIZED
-        num_buckets = len(self._overlap_info.params_per_bucket)
-        assigned_indices = [
-            bucket_index for bucket_index in range(num_buckets)
-            if self._ddp_bucket_index_to_rank(bucket_index) == self.global_rank
-        ]
-        return assigned_indices
+        assert not self._overlap_info.shard_buckets, \
+            "The bucket assignment requires global bucket information and " \
+            "will be computed later; there should be no need to use this " \
+            "method"
+        return bucket_index % self.world_size
 
     def _check_overlap_initialized(self):
         r"""
index 7dc0812..bfea1e7 100644 (file)
@@ -1,19 +1,32 @@
 import enum
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Type,
+)
 
 import torch
 from torch.distributed.algorithms.join import Joinable, JoinHook
 from torch.optim import Optimizer
 
 
+def _get_global_rank(group: Any, rank: int) -> int: ...
+
 class _ZeROJoinHook(JoinHook):
     zero: Any = ...
     def __init__(self, zero: Any) -> None: ...
     def main_hook(self) -> None: ...
 
-class _DDPBucket(NamedTuple):
+class _DDPBucketAssignment():
     bucket_index: int
-    params: List[torch.Tensor]
+    parameters: List[torch.Tensor]
+    offset: int
+    device: torch.device
+    tensor: Optional[torch.Tensor]
 
 class _OverlapStatus(enum.IntEnum):
     UNINITIALIZED: int = ...
@@ -29,8 +42,11 @@ class _OverlapInfo:
     bucket_index_to_future: Any = ...
     bucket_index_to_bucket: Any = ...
     bucket_indices_seen: Any = ...
+    assigned_ranks_per_bucket: List[Set[int]] = ...
+    total_size: int = ...
+    shard_buckets: bool = ...
     def __init__(self) -> None: ...
-    def wait_for_broadcasts(self, num_buckets, rank) -> None: ...
+    def wait_for_broadcasts(self) -> None: ...
     def clear_per_iter_info(self) -> None: ...
 
 class ZeroRedundancyOptimizer(Optimizer, Joinable):
@@ -45,7 +61,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
     _device_to_device_index: Dict[torch.device, int] = ...
     _overlap_with_ddp: bool = ...
     _overlap_info: _OverlapInfo = ...
-    _buckets: Union[List[List[torch.Tensor]], List[List[Dict[int, torch.Tensor]]]] = ...
+    _buckets: List[List[torch.Tensor]] = ...
+    _bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ...
     def __init__(self, params: Any, optimizer_class: Type[Optimizer], process_group: Optional[Any]=..., parameters_as_bucket_view: bool=..., overlap_with_ddp: bool=..., **defaults: Any) -> None: ...
     def add_param_group(self, param_group: dict) -> None: ...
     def consolidate_state_dict(self, to: int=...) -> None: ...
@@ -53,7 +70,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
     def state_dict(self) -> Dict[str, Any]: ...
     def _local_step(self, gradients: Optional[List[Optional[torch.Tensor]]] = None, closure: Optional[Callable[[], float]] = None, **kwargs: Any,) -> Optional[float]: ...
-    def _ddp_bucket_index_to_rank(self, bucket_index: int) -> int: ...
+    def _get_assigned_rank(self, bucket_index: int) -> int: ...
     def join_hook(self, **kwargs): ...
     def join_device(self) -> torch.device: ...
     def join_process_group(self) -> Any: ...