+import weakref
from typing import Any, Callable, List, Optional
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
-from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus
+from torch.distributed.optim.zero_redundancy_optimizer import (
+ _get_global_rank,
+ _OverlapStatus,
+)
from torch.nn.parallel.distributed import DistributedDataParallel
# Functional optimizers require passing a list of gradients to their `step()`
[_NO_PARAM_UPDATE for _ in range(num_local_optim_params)]
assert bucket_index in overlap_info.offsets, \
f"Bucket index {bucket_index} was not assigned to rank {rank}"
- offset = overlap_info.offsets[bucket_index]
- for i, grad in enumerate(bucket.gradients()):
- gradients[offset + i] = grad
+ gradients_offset = overlap_info.offsets[bucket_index]
+ bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
+ bucket_offset = bucket_assignment.offset
+ length = len(bucket_assignment.parameters)
+ bucket_gradients = bucket.gradients()[bucket_offset:bucket_offset + length]
+ for i, grad in enumerate(bucket_gradients):
+ gradients[gradients_offset + i] = grad
zero._local_step(gradients)
def _broadcast_bucket(
bucket_index: int,
zero: ZeroRedundancyOptimizer,
- assigned_rank: int,
):
r"""
Broadcasts a bucket's parameters.
parameters to broadcast.
zero (ZeroRedundancyOptimizer): the calling process's
:class:`ZeroRedundancyOptimizer` instance.
- assigned_rank (int): the rank assigned to the bucket; it has the
- updated parameters and serves as the source for the broadcast.
"""
overlap_info = zero._overlap_info
- device = overlap_info.params_per_bucket[bucket_index][0].device
- device_index = zero._device_to_device_index[device]
- assert bucket_index in zero._buckets[device_index][assigned_rank]
- overlap_info.broadcast_handles.append(
- dist.broadcast(
- zero._buckets[device_index][assigned_rank][bucket_index],
- src=assigned_rank,
- async_op=True
- )
- )
-
-def _collect_ddp_bucket_info(
+ assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \
+ "`assigned_ranks_per_bucket` is not fully constructed"
+ # Sort to ensure the same ordering across ranks
+ assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
+ assert len(assigned_ranks) > 0, f"Bucket {bucket_index} should be " \
+ "assigned to at least one rank"
+ for assigned_rank in assigned_ranks:
+ bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
+ if bucket_index in bucket_assignments:
+ overlap_info.broadcast_handles.append(
+ dist.broadcast(
+ bucket_assignments[bucket_index].tensor,
+ src=_get_global_rank(zero.process_group, assigned_rank),
+ group=zero.process_group,
+ async_op=True,
+ )
+ )
+
+
+def _save_ddp_bucket_info(
bucket: dist.GradBucket,
zero: ZeroRedundancyOptimizer,
- rank: int,
- assigned_rank: int,
):
r"""
- Collects :class:`DistributedDataParallel` gradient bucket information for
- the :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when
- overlapping.
+ Saves :class:`DistributedDataParallel` gradient bucket information for the
+ :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when overlapping.
+ In particular, this function is meant to be called upon seeing each
+ gradient bucket, meaning it does not save or compute any global
+ information.
Arguments:
bucket (dist.GradBucket): the current gradient bucket.
zero (ZeroRedundancyOptimizer): the calling process's
:class:`ZeroRedundancyOptimizer` instance.
- rank (int): the calling process's rank.
- assigned_rank (int): the rank assigned to update the parameters
- corresponding to ``bucket``.
"""
overlap_info = zero._overlap_info
- bucket_index = bucket.index()
bucket_params = bucket.parameters()
- assert len(bucket_params) > 0, "Bucket {bucket_index} is empty"
- params_per_rank = overlap_info.params_per_rank
- params_per_bucket = overlap_info.params_per_bucket
+ assert len(bucket_params) > 0, "Empty bucket"
- # Collect relevant information
- if assigned_rank == rank:
- overlap_info.offsets[bucket_index] = len(params_per_rank[assigned_rank])
- params_per_rank[assigned_rank].extend(bucket_params)
- params_per_bucket.append(bucket_params)
+ # Save the parameters in the bucket
+ overlap_info.params_per_bucket.append(bucket_params)
+ if overlap_info.shard_buckets:
+ # Additionally save the bucket size for the assignment heuristic to use
+ bucket_size = 0
+ for param in bucket_params:
+ bucket_size += param.numel()
+ assert overlap_info.total_size is not None
+ overlap_info.total_size += bucket_size
def hook_with_zero_step(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
ddp: DistributedDataParallel,
zero: ZeroRedundancyOptimizer,
+ shard_buckets: bool = False,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
r"""
Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer`
instance to use.
zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
instance to use.
+ shard_buckets (bool): if ``True``, then the assignment of each
+ :class:`DistributedDataParallel` bucket is partitioned across
+ possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+ across possibly multiple ranks) to approximate uniformity; if
+ ``False``, then each bucket is wholly assigned to a single
+ :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
Returns:
The modified hook.
"ZeroRedundancyOptimizer must be constructed with "
"`overlap_with_ddp=True` to use this hook properly"
)
+ ddp_ref = weakref.ref(ddp)
# NOTE: Gloo may hang with this overlapping approach, so we require
# NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
- if dist.get_backend() != dist.Backend.NCCL:
+ if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL: # type: ignore[union-attr]
raise RuntimeError(
"Overlapping DDP with ZeRO using this approach currently requires "
"NCCL backend to avoid hangs"
)
+ if shard_buckets:
+ zero._overlap_info.shard_buckets = True
+ zero._overlap_info.total_size = 0
+
def hook_with_zero_fn(
state: Any,
bucket: dist.GradBucket,
bucket_index = bucket.index()
# Proceed as normal until the DDP buckets have been rebuilt
- if not ddp._has_rebuilt_buckets:
+ if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr]
assert overlap_info.status == _OverlapStatus.UNINITIALIZED
return fut
if overlap_info.status == _OverlapStatus.UNINITIALIZED:
overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
-
rank = zero.global_rank
- assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index)
# Once DDP buckets have been rebuilt but ZeRO has not been
# properly initialized yet, collect the information needed
if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
- _collect_ddp_bucket_info(bucket, zero, rank, assigned_rank)
+ _save_ddp_bucket_info(bucket, zero)
return fut
assert overlap_info.status == _OverlapStatus.INITIALIZED
+ assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \
+ "`assigned_ranks_per_bucket` is not fully constructed"
+ assigned_to_bucket = rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
# Save the bucket reference and all-reduce future for the final bucket
- if assigned_rank == rank:
+ if assigned_to_bucket:
overlap_info.bucket_index_to_bucket[bucket_index] = bucket
overlap_info.bucket_index_to_future[bucket_index] = fut
# all-reduce future since that would add synchronization that delays
# all optimizer computation to wait for that last all-reduce
for bucket_index in range(num_buckets):
- assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index)
- if assigned_rank == rank:
+ assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
+ if rank in assigned_ranks:
# Wait on the bucket's all-reduce future to ensure correct
# gradients
assert bucket_index in overlap_info.bucket_index_to_future, \
curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
_perform_local_step(curr_bucket, zero, rank)
- _broadcast_bucket(bucket_index, zero, assigned_rank)
+ _broadcast_bucket(bucket_index, zero)
# Ensure that all parameter updates are finished before the
# next forward pass
- overlap_info.wait_for_broadcasts(num_buckets, rank)
+ overlap_info.wait_for_broadcasts()
overlap_info.clear_per_iter_info()
return fut
hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
ddp: DistributedDataParallel,
zero: ZeroRedundancyOptimizer,
+ shard_buckets: bool = False,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
r"""
Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer`
instance to use.
zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
instance to use.
+ shard_buckets (bool): if ``True``, then the assignment of each
+ :class:`DistributedDataParallel` bucket is partitioned across
+ possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+ across possibly multiple ranks) to approximate uniformity; if
+ ``False``, then each bucket is wholly assigned to a single
+ :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
Returns:
The modified hook.
"ZeroRedundancyOptimizer must be constructed with "
"`overlap_with_ddp=True` to use this hook properly"
)
+ ddp_ref = weakref.ref(ddp)
# NOTE: Gloo may hang with this overlapping approach, so we require
# NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
- if dist.get_backend() != dist.Backend.NCCL:
+ if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL: # type: ignore[union-attr]
raise RuntimeError(
"Overlapping DDP with ZeRO using this approach currently requires "
"NCCL backend to avoid hangs"
)
+ if shard_buckets:
+ zero._overlap_info.shard_buckets = True
+ zero._overlap_info.total_size = 0
+
def hook_with_zero_interleaved_fn(
state,
bucket: dist.GradBucket,
fut = hook(state, bucket)
# Proceed as normal until the DDP buckets have been rebuilt
- if not ddp._has_rebuilt_buckets:
+ if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr]
assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
return fut
A :class:`torch.Tensor` representing the contents of the
gradient bucket.
"""
- assert ddp._has_rebuilt_buckets
+ assert ddp_ref()._has_rebuilt_buckets # type: ignore[union-attr]
bucket_index = bucket.index()
rank = zero.global_rank
- assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index)
overlap_info = zero._overlap_info
if overlap_info.status == _OverlapStatus.UNINITIALIZED:
overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
# Once DDP buckets have been rebuilt but ZeRO has not been
# properly initialized yet, collect the information needed
if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
- _collect_ddp_bucket_info(bucket, zero, rank, assigned_rank)
+ _save_ddp_bucket_info(bucket, zero)
return bucket.buffer()
+ assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
overlap_info.bucket_indices_seen.append(bucket_index)
- if assigned_rank == rank:
+ if rank in assigned_ranks:
_perform_local_step(bucket, zero, rank)
- _broadcast_bucket(bucket_index, zero, assigned_rank)
+ _broadcast_bucket(bucket_index, zero)
num_buckets = len(overlap_info.params_per_bucket)
if len(overlap_info.bucket_indices_seen) == num_buckets:
# Ensure that all parameter updates are finished before the
# next forward pass
- overlap_info.wait_for_broadcasts(num_buckets, rank)
+ overlap_info.wait_for_broadcasts()
overlap_info.clear_per_iter_info()
return bucket.buffer()
import io
import logging
from itertools import chain
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Type
import torch
import torch.distributed as dist
self.zero.step()
-class _DDPBucket(NamedTuple):
+class _DDPBucketAssignment():
r"""
- This contains the model parameters corresponding to a
- :class:`DistributedDataParallel` gradient bucket.
+ This represents a :class:`DistributedDataParallel` bucket assignment,
+ meaning a (possibly non-strict) subset of the parameters corresponding to
+ a DDP bucket assigned to a rank to update.
+ Attributes:
bucket_index (int): index of the bucket determined by the DDP gradient
bucket all-reduce order.
- params (List[torch.Tensor]): model parameters in the bucket.
+ parameters (List[torch.Tensor]): model parameters in the bucket
+ assigned to this rank.
+ offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
+ giving the index of the first element in the passed-in
+ ``parameters``; this equivalently indexes into the
+ :class:`GradBucket` 's :meth:`gradients`.
+ device (torch.device): device on which the parameters are stored.
+ tensor (torch.Tensor): flattened tensor giving the data of the
+ parameter subset assigned to the rank.
"""
- bucket_index: int
- params: List[torch.Tensor]
+ def __init__(
+ self,
+ bucket_index: int,
+ parameters: List[torch.Tensor],
+ offset: int,
+ ):
+ self.bucket_index = bucket_index
+ self.parameters = parameters
+ self.offset = offset
+ if len(self.parameters) == 0:
+ raise ValueError("Empty bucket assignment")
+ # DDP guarantees all parameters in the bucket have the same device
+ self.device: torch.device = self.parameters[0].device
+ self.tensor: Optional[torch.Tensor] = None
class _OverlapStatus(enum.IntEnum):
This contains the information needed by :class:`ZeroRedundancyOptimizer`
to overlap with :class:`DistributedDataParallel`.
+ Arguments:
+ world_size (int): world size of the process group being used.
+
+ Attributes:
+ shard_buckets (bool): if ``True``, then the assignment of each
+ :class:`DistributedDataParallel` bucket is partitioned across
+ possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+ across possibly multiple ranks) to approximate uniformity following
+ a threshold given by the total parameter size divided by the world
+ size; if ``False``, then each bucket is wholly assigned to a single
+ :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
+ this should be set to the value passed into the hook constructor.
status (_OverlapStatus): current status; see :class:`_OverlapStatus`
for more information.
params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
parameter in that bucket, where ``rank`` is this process's own
rank; the keys of this :class:`dict` are the bucket indices
assigned to this rank.
+ num_bucket_assignments (int): total number of bucket assignments across
+ all ranks; this is equal to the number of
+ :class:`DistributedDataParallel` gradient buckets if
+ ``shard_buckets=False`` and possibly greater otherwise.
+ total_size (int, optional): total size of all buckets (i.e. sum of
+ ``param.numel()`` for all ``param`` across all buckets) if
+ ``shard_buckets=True``; otherwise, ``None``.
broadcast_handles (List[Work]): :class:`list` of async work handles for
the parameter broadcasts.
bucket_index_to_future (Dict[int, torch.futures.Future]):
bucket_indices_seen (List[int]): :class:`list` of the bucket indices
seen on this iteration.
"""
- def __init__(self) -> None:
+ def __init__(self, world_size) -> None:
self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
+ self.shard_buckets: bool = False
# Modified per bucket reconstruction
self.params_per_bucket: List[List[torch.Tensor]] = []
self.params_per_rank: List[List[torch.Tensor]] = \
- [[] for _ in range(dist.get_world_size())]
+ [[] for _ in range(world_size)]
self.offsets: Dict[int, int] = {}
+ self.assigned_ranks_per_bucket: List[Set[int]] = []
+ self.num_bucket_assignments: int = 0
+ self.total_size: Optional[int] = None
# Modified per iteration
self.broadcast_handles: List[Any] = []
self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
- def wait_for_broadcasts(self, num_buckets, rank) -> None:
+ def wait_for_broadcasts(self) -> None:
r"""
Waits for all parameter broadcasts. This should be called once all
broadcasts have been scheduled, meaning ``self.broadcast_handles`` is
filled. This clears ``self.broadcast_handles`` in preparation for the
next iteration.
-
- Arguments:
- num_buckets (int): total number of buckets.
- rank (int): the calling process's rank.
"""
- assert len(self.broadcast_handles) == num_buckets, \
- f"Missing at least one broadcast handle on rank {rank}"
+ assert len(self.broadcast_handles) == self.num_bucket_assignments, \
+ f"Missing at least one broadcast handle on rank {dist.get_rank()}"
_ = list(map(lambda x: x.wait(), self.broadcast_handles))
self.broadcast_handles.clear()
self._partition_parameters_cache: List[List[Dict]] = []
self._index_to_param_cache: List[torch.Tensor] = []
self._device_to_params_per_rank_cache: Dict[torch.device, List[List[torch.Tensor]]] = {}
- self._device_to_buckets_cache: Dict[torch.device, List[List[_DDPBucket]]] = {}
- self._device_to_device_index: Dict[torch.device, int] = {}
+ self._bucket_assignments_per_rank_cache: List[Dict[int, _DDPBucketAssignment]] = []
self._is_trainable_mask = self._get_is_trainable_mask()
# Default device for collective communication and buckets
if not overlap_with_ddp:
self._init_local_optimizer()
else:
- self._overlap_info: _OverlapInfo = _OverlapInfo()
+ self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
if parameters_as_bucket_view:
logging.warning(
"`parameters_as_bucket_view=True` will be ignored since "
"strategy will be used"
)
- # `self._buckets` is used if `parameters_as_bucket_view=True` or
- # `overlap_with_ddp=True`, in which case parameter data is flattened
- # into buckets (i.e. contiguous tensors)
- # If `overlap_with_ddp=True`, the bucketing requires an additional
- # dimension to match the DDP gradient bucketing
+ # `self._buckets` is used if `parameters_as_bucket_view=True`, in
+ # which case parameter data is flattened into contiguous bucket tensors
self.parameters_as_bucket_view = parameters_as_bucket_view
- self._buckets: Union[
- List[List[torch.Tensor]],
- List[List[Dict[int, torch.Tensor]]]
- ] = [] # type: ignore[assignment]
+ self._buckets: List[List[torch.Tensor]] = []
self._build_param_buckets()
# Optional consolidated optimizer state, only populated if this rank
self._index_to_param_cache.clear()
self._param_to_index_cache.clear()
self._device_to_params_per_rank_cache.clear()
- self._device_to_buckets_cache.clear()
+ self._bucket_assignments_per_rank_cache.clear()
def add_param_group(self, param_group: dict) -> None:
r"""
params_sorted = sorted(param_group["params"], key=lambda t: t.numel(), reverse=True)
for param in params_sorted:
# Greedily add the parameter to rank with smallest size so far
- rank = sizes.index(min(sizes))
+ rank = self._get_min_index(sizes)
param_group_params_per_rank[rank].append(param)
sizes[rank] += param.numel()
# Apply the constructed partition of the parameter group
self._device_to_params_per_rank_cache[device][rank].append(param)
return self._device_to_params_per_rank_cache
- @property
- def _device_to_buckets(
- self
- ) -> Dict[torch.device, List[List[_DDPBucket]]]:
+ def _get_min_index(
+ self,
+ values: List[int],
+ disallowed_indices: Optional[Set[int]] = None,
+ ) -> int:
r"""
- :class:`dict` mapping each device to a :class:`list` of :class:`list`
- of :class:`_DDPBucket` s.
+ Returns ``values.index(min(values))``, except only uses one pass. It
+ also excludes any indices in ``disallowed_indices`` if provided.
- ``_device_to_buckets[d][r][i]`` gives the ``i``th bucket
- assigned to rank ``r`` stored on device ``d``, where each bucket
- contains a list of the model parameters associated with the
- corresponding logical :class:`DistributedDataParallel` gradient bucket.
+ Arguments:
+ values (List[int]): :class:`list` of values.
+ disallowed_indices (Optional[Set[int]]): indices that are
+ disallowed from being the returned min index.
+ """
+ min_index = -1
+ min_value = float("inf")
+ for i, value in enumerate(values):
+ if disallowed_indices and i in disallowed_indices:
+ continue
+ if value < min_value:
+ min_value = value
+ min_index = i
+ assert min_index >= 0, "All indices are disallowed"
+ return min_index
+
+ def _assign_bucket_subset_to_rank(
+ self,
+ bucket_index: int,
+ bucket_params: List[torch.Tensor],
+ bucket_offset: int,
+ assigned_rank: int,
+ assigned_ranks_per_bucket: List[Set[int]],
+ ) -> None:
+ r"""
+ Assigns the model parameters given by ``bucket_params``, representing a
+ (possibly non-strict) subset of the parameters corresponding to a
+ :class:`DistributedDataParallel` bucket, to the rank given by
+ ``assigned_rank`` and collects relevant information.
- This is used for constructing the parameter buckets if
- ``overlap_with_ddp=True``.
+ Arguments:
+ bucket_index (int): index of the :class:`DistributedDataParallel`
+ gradient bucket.
+ bucket_params (List[torch.Tensor]): subset of the parameters
+ corresponding to the bucket to assign.
+ bucket_offset (int): offset giving the index of the first element
+ in ``bucket_params`` in the bucket's full parameter list.
+ assigned_rank (int): rank to assign to.
+ assigned_ranks_per_bucket (List[Set[int]]): :class:`list` giving the
+ :class:`set` of ranks assigned to each bucket; updated in place.
"""
- assert self._overlap_with_ddp, \
- "`_device_to_buckets()` should only be used if " \
- "`overlap_with_ddp=True`"
- if len(self._device_to_buckets_cache) > 0:
- return self._device_to_buckets_cache
+ overlap_info = self._overlap_info
+ if len(bucket_params) == 0:
+ raise ValueError(
+ "Empty bucket assignment"
+ )
+ params_per_rank = overlap_info.params_per_rank
+ offsets = overlap_info.offsets
+
+ self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = \
+ _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
+ if self.global_rank == assigned_rank:
+ offsets[bucket_index] = len(params_per_rank[assigned_rank])
+ params_per_rank[assigned_rank].extend(bucket_params)
+ assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
+ self._overlap_info.num_bucket_assignments += 1
+
+ @property
+ def _bucket_assignments_per_rank(
+ self
+ ) -> List[Dict[int, _DDPBucketAssignment]]:
+ r"""
+ :class:`list` of length world size consisting of :class:`dict` s
+ mapping bucket indices to :class:`_DDPBucketAssignment` s for each
+ rank.
+ """
+ assert self._overlap_with_ddp, "`_bucket_assignments_per_rank` " \
+ "only be used if `overlap_with_ddp=True`"
+ if len(self._bucket_assignments_per_rank_cache) > 0:
+ return self._bucket_assignments_per_rank_cache
overlap_info = self._overlap_info
- assert overlap_info.status == _OverlapStatus.INITIALIZED, \
- "Accessing `_device_to_buckets` before the necessary " \
- "information has been collected"
+ assert overlap_info.status == _OverlapStatus.INITIALIZED
+ self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
params_per_bucket = overlap_info.params_per_bucket
- for bucket_idx, bucket_params in enumerate(params_per_bucket):
- assert len(bucket_params) > 0, "Empty bucket"
- rank = self._ddp_bucket_index_to_rank(bucket_idx)
- bucket = _DDPBucket(bucket_idx, bucket_params)
- device = bucket_params[0].device # assume same device per bucket
- if device not in self._device_to_buckets_cache:
- self._device_to_buckets_cache[device] = [[] for _ in range(self.world_size)]
- self._device_to_buckets_cache[device][rank].append(bucket)
- return self._device_to_buckets_cache
+ if overlap_info.shard_buckets:
+ # Define the assignment threshold to approximate uniformity
+ assert overlap_info.total_size is not None, \
+ "`total_size` was not computed"
+ threshold = overlap_info.total_size / self.world_size # type: ignore[operator]
+ size_per_rank = [0 for _ in range(self.world_size)]
+
+ num_buckets = len(params_per_bucket)
+ overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
+ assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
+ if not overlap_info.shard_buckets:
+ # Assign each DDP bucket entirely to a single rank
+ for bucket_index, bucket_params in enumerate(params_per_bucket):
+ assert len(bucket_params) > 0, "Empty bucket"
+ assigned_rank = self._get_assigned_rank(bucket_index)
+ self._assign_bucket_subset_to_rank(
+ bucket_index,
+ bucket_params,
+ 0,
+ assigned_rank,
+ assigned_ranks_per_bucket,
+ )
+ else:
+ # Assign each DDP bucket to possibly multiple ranks
+ # Specifically, sort the DDP buckets by increasing size, and for
+ # each bucket, iteratively assign the maximal unassigned subset
+ # with size less than `threshold` to the rank with the least total
+ # size so far -- each such assignment is represented by a
+ # `_DDPBucketAssignment` instance and only contains parameters from
+ # a single DDP bucket
+ params_per_bucket_enum = sorted(
+ enumerate(params_per_bucket),
+ key=lambda x: sum(p.numel() for p in x[1])
+ )
+ for bucket_index, bucket_params in params_per_bucket_enum:
+ assert len(bucket_params) > 0, "Empty bucket"
+ bucket_offset = 0
+ assignment_size = 0
+ for param_index, param in enumerate(bucket_params):
+ param_numel = param.numel()
+ if assignment_size + param_numel >= threshold and param_index > bucket_offset:
+ assigned_rank = self._get_min_index(size_per_rank, assigned_ranks_per_bucket[bucket_index])
+ # Include up to but not including the parameter that
+ # exceeded the threshold
+ self._assign_bucket_subset_to_rank(
+ bucket_index,
+ bucket_params[bucket_offset:param_index],
+ bucket_offset,
+ assigned_rank,
+ assigned_ranks_per_bucket,
+ )
+ size_per_rank[assigned_rank] += assignment_size
+ bucket_offset = param_index
+ assignment_size = 0
+ assignment_size += param_numel
+ # Assign the remainder of the bucket so that no assignment
+ # spans across two buckets
+ assigned_rank = self._get_min_index(size_per_rank, assigned_ranks_per_bucket[bucket_index])
+ self._assign_bucket_subset_to_rank(
+ bucket_index,
+ bucket_params[bucket_offset:],
+ bucket_offset,
+ assigned_rank,
+ assigned_ranks_per_bucket,
+ )
+ size_per_rank[assigned_rank] += assignment_size
+
+ return self._bucket_assignments_per_rank_cache
def _local_step(
self,
def _build_ddp_param_buckets(self) -> None:
r"""
- Builds parameter buckets if ``overlap_with_ddp`` so that for each
- device that stores this rank's parameters, there is a :class:`list` of
- buckets (represented as tensors) containing the parameters on that
- device that are assigned to the rank in the parameter update
- partition and grouped following the :class:`DistributedDataParallel`
- gradient buckets.
-
- This method should only be called during the delayed initialization
- when ``overlap_with_ddp=True``.
-
- .. warning::
- The current implementation assumes that all of the parameters in a
- bucket are of the same dense type when allocating the bucket's
- tensor.
-
- .. warning::
- If the model parameters are stored across more than one device,
- then the storage partitioning must be the same across all
- processes in order for parameter synchronization to work.
+ For each DDP bucket with parameters assigned to this rank, flattens the
+ data of those parameters into a single tensor and saves the tensor to
+ the ``tensor`` attribute in the corresponding
+ :class:`_DDPBucketAssignment` instance stored in
+ ``self._bucket_assignments_per_rank``.
+
+ :class:`DistributedDataParallel` guarantees that the parameters
+ corresponding to a gradient bucket have the same device and the same
+ dtype.
"""
- assert self._overlap_with_ddp, \
- "`_build_ddp_param_buckets()` should only be called when " \
- "`overlap_with_ddp=True`"
+ for bucket_assignments in self._bucket_assignments_per_rank:
+ for bucket_assignment in bucket_assignments.values():
+ params = bucket_assignment.parameters
+ bucket_size = 0
+ dtype = None
+ for param in params:
+ assert _is_trainable(param), "Model parameter " \
+ "corresponding to a gradient in a DDP bucket should " \
+ "require a gradient"
+ bucket_size += param.numel()
+ dtype = param.dtype # assumes all same dtype
+ assert bucket_size > 0, "Empty bucket"
- num_devices = len(self._device_to_buckets)
- self._buckets = [[{} for _ in range(self.world_size)] for _ in range(num_devices)] # type: ignore[assignment]
-
- for dev_idx, (device, ddp_buckets_per_rank) in enumerate(self._device_to_buckets.items()):
- self._device_to_device_index[device] = dev_idx
- for rank, ddp_buckets in enumerate(ddp_buckets_per_rank):
- for ddp_bucket in ddp_buckets:
- bucket_index = ddp_bucket.bucket_index # type: ignore[attr-defined]
- params = ddp_bucket.params # type: ignore[attr-defined]
- bucket_size = 0
- dtype = None
- for param in params:
- assert _is_trainable(param), \
- "Model parameter corresponding to a gradient in " \
- "a DDP bucket should require a gradient"
- bucket_size += param.numel()
- dtype = param.dtype # assumes all same dtype
- assert bucket_size > 0
- bucket = torch.empty(bucket_size, dtype=dtype, device=device)
- offset = 0
- # Construct the bucket (assuming all dense and same dtype)
- for param in params:
- offset_next = offset + param.numel()
- bucket[offset:offset_next].copy_(param.data.flatten())
- param.data = bucket[offset:offset_next].view_as(param.data)
- offset = offset_next
- self._buckets[dev_idx][rank][bucket_index] = bucket
+ # Construct the bucket tensor (assuming all dense and same dtype)
+ tensor = torch.empty(bucket_size, dtype=dtype, device=bucket_assignment.device)
+ offset = 0
+ for param in params:
+ offset_next = offset + param.numel()
+ tensor[offset:offset_next].copy_(param.data.flatten())
+ param.data = tensor[offset:offset_next].view_as(param.data)
+ offset = offset_next
+ bucket_assignment.tensor = tensor
def _verify_and_init_params(self, params: Any) -> None:
r"""
"functional optimizer with more than one parameter group"
params = param_groups[0]["params"]
self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
+
+ # Log information about the DDP and ZeRO bucketing
+ if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF:
+ local_numel = sum(p.numel() for p in params)
+ num_assigned_buckets = len(self._bucket_assignments_per_rank[self.global_rank])
+ logging.info(
+ f"rank {self.global_rank} with {local_numel} parameters "
+ f"across {num_assigned_buckets} buckets"
+ )
+ if self.global_rank == 0:
+ logging.info(
+ f"{len(self._overlap_info.params_per_bucket)} DDP "
+ f"buckets and "
+ f"{self._overlap_info.num_bucket_assignments} bucket "
+ "assignments"
+ )
else:
# NOTE: Passing `param_groups` into the local optimizer constructor
# bypasses the empty parameter list check
self._build_ddp_param_buckets()
self._init_local_optimizer()
- def _ddp_bucket_index_to_rank(self, bucket_index: int) -> int:
- r"""Assigns a rank to a given DDP gradient bucket index."""
- return bucket_index % self.world_size
-
- def _get_assigned_ddp_bucket_indices(self) -> List[int]:
+ def _get_assigned_rank(self, bucket_index: int) -> int:
r"""
- Returns a list of the DDP gradient bucket indices assigned to this rank
- to update.
+ Returns the single rank assigned to a :class:`DistributedDataParallel`
+ gradient bucket.
+
+ Arguments:
+ bucket_index (int): index of the :class:`DistributedDataParallel`
+ bucket for which to get the assigned rank.
"""
- assert self._overlap_info.status == _OverlapStatus.INITIALIZED
- num_buckets = len(self._overlap_info.params_per_bucket)
- assigned_indices = [
- bucket_index for bucket_index in range(num_buckets)
- if self._ddp_bucket_index_to_rank(bucket_index) == self.global_rank
- ]
- return assigned_indices
+ assert not self._overlap_info.shard_buckets, \
+ "The bucket assignment requires global bucket information and " \
+ "will be computed later; there should be no need to use this " \
+ "method"
+ return bucket_index % self.world_size
def _check_overlap_initialized(self):
r"""