From: Andrew Gu Date: Fri, 13 Aug 2021 15:19:23 +0000 (-0700) Subject: Simplify data structures, add uniform approximation, fix mem leak (#63162) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~1046 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bd81c9178a82f967e14aa0b4cc3e97dac10d8f85;p=platform%2Fupstream%2Fpytorch.git Simplify data structures, add uniform approximation, fix mem leak (#63162) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63162 Test Plan: Imported from OSS Reviewed By: mrshenli Differential Revision: D30284617 Pulled By: andwgu fbshipit-source-id: 9bd9e5f89abcc0d3dac56b85d55cc88e843baa9f --- diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index bee4a4f..c60c7de 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -916,7 +916,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): device, hook_constructor, gradient_as_bucket_view, - static_graph + static_graph, + **kwargs, ): SGD_LR = 0.01 SGD_MOMENTUM = 0.9 @@ -970,7 +971,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): ) ddp_model_overlap.register_comm_hook( None, - hook_constructor(allreduce_hook, ddp_model_overlap, zero_optim) + hook_constructor(allreduce_hook, ddp_model_overlap, zero_optim, **kwargs) ) # Set up the DDP model with local optimizer @@ -1095,6 +1096,62 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): # TODO: Add `test_ddp_with_zero_step_interleaved_parity_cpu()` once the # Gloo synchronization issue causing hangs is fixed. + @common_distributed.skip_if_win32() + @common_distributed.requires_nccl() + @common_distributed.skip_if_no_gpu + @common_distributed.skip_if_rocm + def test_ddp_with_zero_step_uniform_parity_gpu(self): + r""" + Check that overlapping DDP with ZeRO using + ``hook_with_zero_step()`` with ``shard_buckets=True`` + achieves parity with DDP using a local optimizer when running on GPU. + + NOTE: The test is skipped if using Windows since functional optimizers + are not currently supported. + """ + self.dist_init(self.rank, self.world_size, dist.Backend.NCCL) + for gradient_as_bucket_view, static_graph in itertools.product( + [True, False], + [True, False] + ): + self._test_ddp_zero_overlap( + torch.device(self.rank), + hook_with_zero_step, + gradient_as_bucket_view, + static_graph, + shard_buckets=True, + ) + # TODO: Add `test_ddp_with_zero_step_uniform_parity_cpu()` once the Gloo + # synchronization issue causing hangs is fixed. + + @common_distributed.skip_if_win32() + @common_distributed.requires_nccl() + @common_distributed.skip_if_no_gpu + @common_distributed.skip_if_rocm + def test_ddp_with_zero_step_interleaved_uniform_parity_gpu(self): + r""" + Check that overlapping DDP with ZeRO using + ``hook_with_zero_step()`` with ``shard_buckets=True`` + achieves parity with DDP using a local optimizer when running on GPU. + + NOTE: The test is skipped if using Windows since functional optimizers + are not currently supported. 
+ """ + self.dist_init(self.rank, self.world_size, dist.Backend.NCCL) + for gradient_as_bucket_view, static_graph in itertools.product( + [True, False], + [True, False] + ): + self._test_ddp_zero_overlap( + torch.device(self.rank), + hook_with_zero_step_interleaved, + gradient_as_bucket_view, + static_graph, + shard_buckets=True, + ) + # TODO: Add `test_ddp_with_zero_step_interleaved_uniform_parity_cpu()` once + # the Gloo synchronization issue causing hangs is fixed. + if __name__ == "__main__": # ! unittest should not be used here, else the tests are not properly registered common_utils.run_tests() diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index c696ff2..b9ed357 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -1,9 +1,13 @@ +import weakref from typing import Any, Callable, List, Optional import torch import torch.distributed as dist from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus +from torch.distributed.optim.zero_redundancy_optimizer import ( + _get_global_rank, + _OverlapStatus, +) from torch.nn.parallel.distributed import DistributedDataParallel # Functional optimizers require passing a list of gradients to their `step()` @@ -44,9 +48,13 @@ def _perform_local_step( [_NO_PARAM_UPDATE for _ in range(num_local_optim_params)] assert bucket_index in overlap_info.offsets, \ f"Bucket index {bucket_index} was not assigned to rank {rank}" - offset = overlap_info.offsets[bucket_index] - for i, grad in enumerate(bucket.gradients()): - gradients[offset + i] = grad + gradients_offset = overlap_info.offsets[bucket_index] + bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] + bucket_offset = bucket_assignment.offset + length = len(bucket_assignment.parameters) + bucket_gradients = bucket.gradients()[bucket_offset:bucket_offset + length] + for i, grad in enumerate(bucket_gradients): + gradients[gradients_offset + i] = grad zero._local_step(gradients) @@ -54,7 +62,6 @@ def _perform_local_step( def _broadcast_bucket( bucket_index: int, zero: ZeroRedundancyOptimizer, - assigned_rank: int, ): r""" Broadcasts a bucket's parameters. @@ -64,58 +71,63 @@ def _broadcast_bucket( parameters to broadcast. zero (ZeroRedundancyOptimizer): the calling process's :class:`ZeroRedundancyOptimizer` instance. - assigned_rank (int): the rank assigned to the bucket; it has the - updated parameters and serves as the source for the broadcast. 
""" overlap_info = zero._overlap_info - device = overlap_info.params_per_bucket[bucket_index][0].device - device_index = zero._device_to_device_index[device] - assert bucket_index in zero._buckets[device_index][assigned_rank] - overlap_info.broadcast_handles.append( - dist.broadcast( - zero._buckets[device_index][assigned_rank][bucket_index], - src=assigned_rank, - async_op=True - ) - ) - -def _collect_ddp_bucket_info( + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \ + "`assigned_ranks_per_bucket` is not fully constructed" + # Sort to ensure the same ordering across ranks + assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) + assert len(assigned_ranks) > 0, f"Bucket {bucket_index} should be " \ + "assigned to at least one rank" + for assigned_rank in assigned_ranks: + bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] + if bucket_index in bucket_assignments: + overlap_info.broadcast_handles.append( + dist.broadcast( + bucket_assignments[bucket_index].tensor, + src=_get_global_rank(zero.process_group, assigned_rank), + group=zero.process_group, + async_op=True, + ) + ) + + +def _save_ddp_bucket_info( bucket: dist.GradBucket, zero: ZeroRedundancyOptimizer, - rank: int, - assigned_rank: int, ): r""" - Collects :class:`DistributedDataParallel` gradient bucket information for - the :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when - overlapping. + Saves :class:`DistributedDataParallel` gradient bucket information for the + :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when overlapping. + In particular, this function is meant to be called upon seeing each + gradient bucket, meaning it does not save or compute any global + information. Arguments: bucket (dist.GradBucket): the current gradient bucket. zero (ZeroRedundancyOptimizer): the calling process's :class:`ZeroRedundancyOptimizer` instance. - rank (int): the calling process's rank. - assigned_rank (int): the rank assigned to update the parameters - corresponding to ``bucket``. """ overlap_info = zero._overlap_info - bucket_index = bucket.index() bucket_params = bucket.parameters() - assert len(bucket_params) > 0, "Bucket {bucket_index} is empty" - params_per_rank = overlap_info.params_per_rank - params_per_bucket = overlap_info.params_per_bucket + assert len(bucket_params) > 0, "Empty bucket" - # Collect relevant information - if assigned_rank == rank: - overlap_info.offsets[bucket_index] = len(params_per_rank[assigned_rank]) - params_per_rank[assigned_rank].extend(bucket_params) - params_per_bucket.append(bucket_params) + # Save the parameters in the bucket + overlap_info.params_per_bucket.append(bucket_params) + if overlap_info.shard_buckets: + # Additionally save the bucket size for the assignment heuristic to use + bucket_size = 0 + for param in bucket_params: + bucket_size += param.numel() + assert overlap_info.total_size is not None + overlap_info.total_size += bucket_size def hook_with_zero_step( hook: Callable[[Any, dist.GradBucket], torch.futures.Future], ddp: DistributedDataParallel, zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r""" Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer` @@ -139,6 +151,12 @@ def hook_with_zero_step( instance to use. zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` instance to use. 
+ shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). Returns: The modified hook. @@ -161,15 +179,20 @@ def hook_with_zero_step( "ZeroRedundancyOptimizer must be constructed with " "`overlap_with_ddp=True` to use this hook properly" ) + ddp_ref = weakref.ref(ddp) # NOTE: Gloo may hang with this overlapping approach, so we require # NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 - if dist.get_backend() != dist.Backend.NCCL: + if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL: # type: ignore[union-attr] raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " "NCCL backend to avoid hangs" ) + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + def hook_with_zero_fn( state: Any, bucket: dist.GradBucket, @@ -193,26 +216,27 @@ def hook_with_zero_step( bucket_index = bucket.index() # Proceed as normal until the DDP buckets have been rebuilt - if not ddp._has_rebuilt_buckets: + if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr] assert overlap_info.status == _OverlapStatus.UNINITIALIZED return fut if overlap_info.status == _OverlapStatus.UNINITIALIZED: overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS - rank = zero.global_rank - assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index) # Once DDP buckets have been rebuilt but ZeRO has not been # properly initialized yet, collect the information needed if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS: - _collect_ddp_bucket_info(bucket, zero, rank, assigned_rank) + _save_ddp_bucket_info(bucket, zero) return fut assert overlap_info.status == _OverlapStatus.INITIALIZED + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \ + "`assigned_ranks_per_bucket` is not fully constructed" + assigned_to_bucket = rank in overlap_info.assigned_ranks_per_bucket[bucket_index] # Save the bucket reference and all-reduce future for the final bucket - if assigned_rank == rank: + if assigned_to_bucket: overlap_info.bucket_index_to_bucket[bucket_index] = bucket overlap_info.bucket_index_to_future[bucket_index] = fut @@ -238,8 +262,8 @@ def hook_with_zero_step( # all-reduce future since that would add synchronization that delays # all optimizer computation to wait for that last all-reduce for bucket_index in range(num_buckets): - assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index) - if assigned_rank == rank: + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] + if rank in assigned_ranks: # Wait on the bucket's all-reduce future to ensure correct # gradients assert bucket_index in overlap_info.bucket_index_to_future, \ @@ -252,11 +276,11 @@ def hook_with_zero_step( curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index] _perform_local_step(curr_bucket, zero, rank) - _broadcast_bucket(bucket_index, zero, assigned_rank) + _broadcast_bucket(bucket_index, zero) # Ensure that all parameter updates are finished before the # next forward pass - overlap_info.wait_for_broadcasts(num_buckets, rank) + overlap_info.wait_for_broadcasts() overlap_info.clear_per_iter_info() return fut @@ -268,6 +292,7 @@ def hook_with_zero_step_interleaved( 
hook: Callable[[Any, dist.GradBucket], torch.futures.Future], ddp: DistributedDataParallel, zero: ZeroRedundancyOptimizer, + shard_buckets: bool = False, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r""" Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer` @@ -292,6 +317,12 @@ def hook_with_zero_step_interleaved( instance to use. zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer` instance to use. + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity; if + ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank). Returns: The modified hook. @@ -314,15 +345,20 @@ def hook_with_zero_step_interleaved( "ZeroRedundancyOptimizer must be constructed with " "`overlap_with_ddp=True` to use this hook properly" ) + ddp_ref = weakref.ref(ddp) # NOTE: Gloo may hang with this overlapping approach, so we require # NCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 - if dist.get_backend() != dist.Backend.NCCL: + if dist.get_backend(ddp_ref().process_group) != dist.Backend.NCCL: # type: ignore[union-attr] raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " "NCCL backend to avoid hangs" ) + if shard_buckets: + zero._overlap_info.shard_buckets = True + zero._overlap_info.total_size = 0 + def hook_with_zero_interleaved_fn( state, bucket: dist.GradBucket, @@ -339,7 +375,7 @@ def hook_with_zero_step_interleaved( fut = hook(state, bucket) # Proceed as normal until the DDP buckets have been rebuilt - if not ddp._has_rebuilt_buckets: + if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr] assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED return fut @@ -353,11 +389,10 @@ def hook_with_zero_step_interleaved( A :class:`torch.Tensor` representing the contents of the gradient bucket. 
""" - assert ddp._has_rebuilt_buckets + assert ddp_ref()._has_rebuilt_buckets # type: ignore[union-attr] bucket_index = bucket.index() rank = zero.global_rank - assigned_rank = zero._ddp_bucket_index_to_rank(bucket_index) overlap_info = zero._overlap_info if overlap_info.status == _OverlapStatus.UNINITIALIZED: overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS @@ -365,20 +400,21 @@ def hook_with_zero_step_interleaved( # Once DDP buckets have been rebuilt but ZeRO has not been # properly initialized yet, collect the information needed if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS: - _collect_ddp_bucket_info(bucket, zero, rank, assigned_rank) + _save_ddp_bucket_info(bucket, zero) return bucket.buffer() + assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index] overlap_info.bucket_indices_seen.append(bucket_index) - if assigned_rank == rank: + if rank in assigned_ranks: _perform_local_step(bucket, zero, rank) - _broadcast_bucket(bucket_index, zero, assigned_rank) + _broadcast_bucket(bucket_index, zero) num_buckets = len(overlap_info.params_per_bucket) if len(overlap_info.bucket_indices_seen) == num_buckets: # Ensure that all parameter updates are finished before the # next forward pass - overlap_info.wait_for_broadcasts(num_buckets, rank) + overlap_info.wait_for_broadcasts() overlap_info.clear_per_iter_info() return bucket.buffer() diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index e1bad82..bba71e4 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -9,7 +9,7 @@ import enum import io import logging from itertools import chain -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type import torch import torch.distributed as dist @@ -121,17 +121,39 @@ class _ZeROJoinHook(JoinHook): self.zero.step() -class _DDPBucket(NamedTuple): +class _DDPBucketAssignment(): r""" - This contains the model parameters corresponding to a - :class:`DistributedDataParallel` gradient bucket. + This represents a :class:`DistributedDataParallel` bucket assignment, + meaning a (possibly non-strict) subset of the parameters corresponding to + a DDP bucket assigned to a rank to update. + Attributes: bucket_index (int): index of the bucket determined by the DDP gradient bucket all-reduce order. - params (List[torch.Tensor]): model parameters in the bucket. + parameters (List[torch.Tensor]): model parameters in the bucket + assigned to this rank. + offset (int): offset into the :class:`GradBucket` 's :meth:`parameters` + giving the index of the first element in the passed-in + ``parameters``; this equivalently indexes into the + :class:`GradBucket` 's :meth:`gradients`. + device (torch.device): device on which the parameters are stored. + tensor (torch.Tensor): flattened tensor giving the data of the + parameter subset assigned to the rank. 
""" - bucket_index: int - params: List[torch.Tensor] + def __init__( + self, + bucket_index: int, + parameters: List[torch.Tensor], + offset: int, + ): + self.bucket_index = bucket_index + self.parameters = parameters + self.offset = offset + if len(self.parameters) == 0: + raise ValueError("Empty bucket assignment") + # DDP guarantees all parameters in the bucket have the same device + self.device: torch.device = self.parameters[0].device + self.tensor: Optional[torch.Tensor] = None class _OverlapStatus(enum.IntEnum): @@ -158,6 +180,18 @@ class _OverlapInfo(): This contains the information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`. + Arguments: + world_size (int): world size of the process group being used. + + Attributes: + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity following + a threshold given by the total parameter size divided by the world + size; if ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank); + this should be set to the value passed into the hook constructor. status (_OverlapStatus): current status; see :class:`_OverlapStatus` for more information. params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]`` @@ -170,6 +204,13 @@ class _OverlapInfo(): parameter in that bucket, where ``rank`` is this process's own rank; the keys of this :class:`dict` are the bucket indices assigned to this rank. + num_bucket_assignments (int): total number of bucket assignments across + all ranks; this is equal to the number of + :class:`DistributedDataParallel` gradient buckets if + ``shard_buckets=False`` and possibly greater otherwise. + total_size (int, optional): total size of all buckets (i.e. sum of + ``param.numel()`` for all ``param`` across all buckets) if + ``shard_buckets=True``; otherwise, ``None``. broadcast_handles (List[Work]): :class:`list` of async work handles for the parameter broadcasts. bucket_index_to_future (Dict[int, torch.futures.Future]): @@ -180,14 +221,18 @@ class _OverlapInfo(): bucket_indices_seen (List[int]): :class:`list` of the bucket indices seen on this iteration. """ - def __init__(self) -> None: + def __init__(self, world_size) -> None: self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED + self.shard_buckets: bool = False # Modified per bucket reconstruction self.params_per_bucket: List[List[torch.Tensor]] = [] self.params_per_rank: List[List[torch.Tensor]] = \ - [[] for _ in range(dist.get_world_size())] + [[] for _ in range(world_size)] self.offsets: Dict[int, int] = {} + self.assigned_ranks_per_bucket: List[Set[int]] = [] + self.num_bucket_assignments: int = 0 + self.total_size: Optional[int] = None # Modified per iteration self.broadcast_handles: List[Any] = [] @@ -196,19 +241,15 @@ class _OverlapInfo(): self.bucket_index_to_future: Dict[int, torch.futures.Future] = {} self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {} - def wait_for_broadcasts(self, num_buckets, rank) -> None: + def wait_for_broadcasts(self) -> None: r""" Waits for all parameter broadcasts. This should be called once all broadcasts have been scheduled, meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` in preparation for the next iteration. 
- - Arguments: - num_buckets (int): total number of buckets. - rank (int): the calling process's rank. """ - assert len(self.broadcast_handles) == num_buckets, \ - f"Missing at least one broadcast handle on rank {rank}" + assert len(self.broadcast_handles) == self.num_bucket_assignments, \ + f"Missing at least one broadcast handle on rank {dist.get_rank()}" _ = list(map(lambda x: x.wait(), self.broadcast_handles)) self.broadcast_handles.clear() @@ -339,8 +380,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._partition_parameters_cache: List[List[Dict]] = [] self._index_to_param_cache: List[torch.Tensor] = [] self._device_to_params_per_rank_cache: Dict[torch.device, List[List[torch.Tensor]]] = {} - self._device_to_buckets_cache: Dict[torch.device, List[List[_DDPBucket]]] = {} - self._device_to_device_index: Dict[torch.device, int] = {} + self._bucket_assignments_per_rank_cache: List[Dict[int, _DDPBucketAssignment]] = [] self._is_trainable_mask = self._get_is_trainable_mask() # Default device for collective communication and buckets @@ -360,7 +400,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): if not overlap_with_ddp: self._init_local_optimizer() else: - self._overlap_info: _OverlapInfo = _OverlapInfo() + self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size) if parameters_as_bucket_view: logging.warning( "`parameters_as_bucket_view=True` will be ignored since " @@ -368,16 +408,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): "strategy will be used" ) - # `self._buckets` is used if `parameters_as_bucket_view=True` or - # `overlap_with_ddp=True`, in which case parameter data is flattened - # into buckets (i.e. contiguous tensors) - # If `overlap_with_ddp=True`, the bucketing requires an additional - # dimension to match the DDP gradient bucketing + # `self._buckets` is used if `parameters_as_bucket_view=True`, in + # which case parameter data is flattened into contiguous bucket tensors self.parameters_as_bucket_view = parameters_as_bucket_view - self._buckets: Union[ - List[List[torch.Tensor]], - List[List[Dict[int, torch.Tensor]]] - ] = [] # type: ignore[assignment] + self._buckets: List[List[torch.Tensor]] = [] self._build_param_buckets() # Optional consolidated optimizer state, only populated if this rank @@ -395,7 +429,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._index_to_param_cache.clear() self._param_to_index_cache.clear() self._device_to_params_per_rank_cache.clear() - self._device_to_buckets_cache.clear() + self._bucket_assignments_per_rank_cache.clear() def add_param_group(self, param_group: dict) -> None: r""" @@ -610,7 +644,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): params_sorted = sorted(param_group["params"], key=lambda t: t.numel(), reverse=True) for param in params_sorted: # Greedily add the parameter to rank with smallest size so far - rank = sizes.index(min(sizes)) + rank = self._get_min_index(sizes) param_group_params_per_rank[rank].append(param) sizes[rank] += param.numel() # Apply the constructed partition of the parameter group @@ -762,44 +796,160 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._device_to_params_per_rank_cache[device][rank].append(param) return self._device_to_params_per_rank_cache - @property - def _device_to_buckets( - self - ) -> Dict[torch.device, List[List[_DDPBucket]]]: + def _get_min_index( + self, + values: List[int], + disallowed_indices: Optional[Set[int]] = None, + ) -> int: r""" - :class:`dict` mapping each device to a :class:`list` of :class:`list` - of 
:class:`_DDPBucket` s. + Returns ``values.index(min(values))``, except only uses one pass. It + also excludes any indices in ``disallowed_indices`` if provided. - ``_device_to_buckets[d][r][i]`` gives the ``i``th bucket - assigned to rank ``r`` stored on device ``d``, where each bucket - contains a list of the model parameters associated with the - corresponding logical :class:`DistributedDataParallel` gradient bucket. + Arguments: + values: (List[int]): :class:`list` of values. + disallowed_indices (Optional[Set[int]]): indices that are + disallowed from being the returned min index. + """ + min_index = -1 + min_value = float("inf") + for i, value in enumerate(values): + if disallowed_indices and i in disallowed_indices: + continue + if value < min_value: + min_value = value + min_index = i + assert min_index >= 0, "All indices are disallowed" + return min_index + + def _assign_bucket_subset_to_rank( + self, + bucket_index: int, + bucket_params: List[torch.Tensor], + bucket_offset: int, + assigned_rank: int, + assigned_ranks_per_bucket: List[Set[int]], + ) -> None: + r""" + Assigns the model parameters given by ``bucket_params``, representing a + (possibly non-strict) subset of the parameters corresponding to a + :class:`DistributedDataParallel` bucket, to the rank with the least + size assigned so far and collects relevant information. - This is used for constructing the parameter buckets if - ``overlap_with_ddp=True``. + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + gradient bucket. + bucket_params (List[torch.Tensor]): subset of the parameters + corresponding to the bucket to assign. + bucket_offset (int): offset giving the index of the first element + in ``bucket_params`` in the bucket's full parameter list. + assigned_rank (int): rank to assign to. + assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of ranks + assigned to each bucket. """ - assert self._overlap_with_ddp, \ - "`_device_to_buckets()` should only be used if " \ - "`overlap_with_ddp=True`" - if len(self._device_to_buckets_cache) > 0: - return self._device_to_buckets_cache + overlap_info = self._overlap_info + if len(bucket_params) == 0: + raise ValueError( + "Empty bucket assignment" + ) + params_per_rank = overlap_info.params_per_rank + offsets = overlap_info.offsets + + self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = \ + _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + if self.global_rank == assigned_rank: + offsets[bucket_index] = len(params_per_rank[assigned_rank]) + params_per_rank[assigned_rank].extend(bucket_params) + assigned_ranks_per_bucket[bucket_index].add(assigned_rank) + self._overlap_info.num_bucket_assignments += 1 + + @property + def _bucket_assignments_per_rank( + self + ) -> List[Dict[int, _DDPBucketAssignment]]: + r""" + :class:`list` of length world size consisting of :class:`dict` s + mapping bucket indices to :class:`_DDPBucketAssignment` s for each + rank. 
+ """ + assert self._overlap_with_ddp, "`_bucket_assignments_per_rank` " \ + "only be used if `overlap_with_ddp=True`" + if len(self._bucket_assignments_per_rank_cache) > 0: + return self._bucket_assignments_per_rank_cache overlap_info = self._overlap_info - assert overlap_info.status == _OverlapStatus.INITIALIZED, \ - "Accessing `_device_to_buckets` before the necessary " \ - "information has been collected" + assert overlap_info.status == _OverlapStatus.INITIALIZED + self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)] params_per_bucket = overlap_info.params_per_bucket - for bucket_idx, bucket_params in enumerate(params_per_bucket): - assert len(bucket_params) > 0, "Empty bucket" - rank = self._ddp_bucket_index_to_rank(bucket_idx) - bucket = _DDPBucket(bucket_idx, bucket_params) - device = bucket_params[0].device # assume same device per bucket - if device not in self._device_to_buckets_cache: - self._device_to_buckets_cache[device] = [[] for _ in range(self.world_size)] - self._device_to_buckets_cache[device][rank].append(bucket) - return self._device_to_buckets_cache + if overlap_info.shard_buckets: + # Define the assignment threshold to approximate uniformity + assert overlap_info.total_size is not None, \ + "`total_size` was not computed" + threshold = overlap_info.total_size / self.world_size # type: ignore[operator] + size_per_rank = [0 for _ in range(self.world_size)] + + num_buckets = len(params_per_bucket) + overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)] + assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket + if not overlap_info.shard_buckets: + # Assign each DDP bucket entirely to a single rank + for bucket_index, bucket_params in enumerate(params_per_bucket): + assert len(bucket_params) > 0, "Empty bucket" + assigned_rank = self._get_assigned_rank(bucket_index) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params, + 0, + assigned_rank, + assigned_ranks_per_bucket, + ) + else: + # Assign each DDP bucket to possibly multiple ranks + # Specifically, sort the DDP buckets by increasing size, and for + # each bucket, iteratively assign the maximal unassigned subset + # with size less than `threshold` to the rank with the least total + # size so far -- each such assignment is represented by a + # `_DDPBucketAssignment` instance and only contains parameters from + # a single DDP bucket + params_per_bucket_enum = sorted( + enumerate(params_per_bucket), + key=lambda x: sum(p.numel() for p in x[1]) + ) + for bucket_index, bucket_params in params_per_bucket_enum: + assert len(bucket_params) > 0, "Empty bucket" + bucket_offset = 0 + assignment_size = 0 + for param_index, param in enumerate(bucket_params): + param_numel = param.numel() + if assignment_size + param_numel >= threshold and param_index > bucket_offset: + assigned_rank = self._get_min_index(size_per_rank, assigned_ranks_per_bucket[bucket_index]) + # Include up to but not including the parameter that + # exceeded the threshold + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:param_index], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + size_per_rank[assigned_rank] += assignment_size + bucket_offset = param_index + assignment_size = 0 + assignment_size += param_numel + # Assign the remainder of the bucket so that no assignment + # spans across two buckets + assigned_rank = self._get_min_index(size_per_rank, assigned_ranks_per_bucket[bucket_index]) + self._assign_bucket_subset_to_rank( + 
bucket_index, + bucket_params[bucket_offset:], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + size_per_rank[assigned_rank] += assignment_size + + return self._bucket_assignments_per_rank_cache def _local_step( self, @@ -1109,57 +1259,38 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _build_ddp_param_buckets(self) -> None: r""" - Builds parameter buckets if ``overlap_with_ddp`` so that for each - device that stores this rank's parameters, there is a :class:`list` of - buckets (represented as tensors) containing the parameters on that - device that are assigned to the rank in the parameter update - partition and grouped following the :class:`DistributedDataParallel` - gradient buckets. - - This method should only be called during the delayed initialization - when ``overlap_with_ddp=True``. - - .. warning:: - The current implementation assumes that all of the parameters in a - bucket are of the same dense type when allocating the bucket's - tensor. - - .. warning:: - If the model parameters are stored across more than one device, - then the storage partitioning must be the same across all - processes in order for parameter synchronization to work. + For each DDP bucket with parameters assigned to this rank, flattens the + data of those parameters into a single tensor and saves the tensor to + the ``tensor`` attribute in the corresponding + :class:`_DDPBucketAssignment` instance stored in + ``self._bucket_assignments_per_rank``. + + :class:`DistributedDataParallel` guarantees that the parameters + corresponding to a gradient bucket have the same device and the same + dtype. """ - assert self._overlap_with_ddp, \ - "`_build_ddp_param_buckets()` should only be called when " \ - "`overlap_with_ddp=True`" + for bucket_assignments in self._bucket_assignments_per_rank: + for bucket_assignment in bucket_assignments.values(): + params = bucket_assignment.parameters + bucket_size = 0 + dtype = None + for param in params: + assert _is_trainable(param), "Model parameter " \ + "corresponding to a gradient in a DDP bucket should " \ + "require a gradient" + bucket_size += param.numel() + dtype = param.dtype # assumes all same dtype + assert bucket_size > 0, "Empty bucket" - num_devices = len(self._device_to_buckets) - self._buckets = [[{} for _ in range(self.world_size)] for _ in range(num_devices)] # type: ignore[assignment] - - for dev_idx, (device, ddp_buckets_per_rank) in enumerate(self._device_to_buckets.items()): - self._device_to_device_index[device] = dev_idx - for rank, ddp_buckets in enumerate(ddp_buckets_per_rank): - for ddp_bucket in ddp_buckets: - bucket_index = ddp_bucket.bucket_index # type: ignore[attr-defined] - params = ddp_bucket.params # type: ignore[attr-defined] - bucket_size = 0 - dtype = None - for param in params: - assert _is_trainable(param), \ - "Model parameter corresponding to a gradient in " \ - "a DDP bucket should require a gradient" - bucket_size += param.numel() - dtype = param.dtype # assumes all same dtype - assert bucket_size > 0 - bucket = torch.empty(bucket_size, dtype=dtype, device=device) - offset = 0 - # Construct the bucket (assuming all dense and same dtype) - for param in params: - offset_next = offset + param.numel() - bucket[offset:offset_next].copy_(param.data.flatten()) - param.data = bucket[offset:offset_next].view_as(param.data) - offset = offset_next - self._buckets[dev_idx][rank][bucket_index] = bucket + # Construct the bucket tensor (assuming all dense and same dtype) + tensor = torch.empty(bucket_size, dtype=dtype, 
device=bucket_assignment.device) + offset = 0 + for param in params: + offset_next = offset + param.numel() + tensor[offset:offset_next].copy_(param.data.flatten()) + param.data = tensor[offset:offset_next].view_as(param.data) + offset = offset_next + bucket_assignment.tensor = tensor def _verify_and_init_params(self, params: Any) -> None: r""" @@ -1245,6 +1376,22 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): "functional optimizer with more than one parameter group" params = param_groups[0]["params"] self.optim: Any = self._optim_constructor(params, **self._optim_defaults) + + # Log information about the DDP and ZeRO bucketing + if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF: + local_numel = sum(p.numel() for p in params) + num_assigned_buckets = len(self._bucket_assignments_per_rank[self.global_rank]) + logging.info( + f"rank {self.global_rank} with {local_numel} parameters " + f"across {num_assigned_buckets} buckets" + ) + if self.global_rank == 0: + logging.info( + f"{len(self._overlap_info.params_per_bucket)} DDP " + f"buckets and " + f"{self._overlap_info.num_bucket_assignments} bucket " + "assignments" + ) else: # NOTE: Passing `param_groups` into the local optimizer constructor # bypasses the empty parameter list check @@ -1274,22 +1421,20 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._build_ddp_param_buckets() self._init_local_optimizer() - def _ddp_bucket_index_to_rank(self, bucket_index: int) -> int: - r"""Assigns a rank to a given DDP gradient bucket index.""" - return bucket_index % self.world_size - - def _get_assigned_ddp_bucket_indices(self) -> List[int]: + def _get_assigned_rank(self, bucket_index: int) -> int: r""" - Returns a list of the DDP gradient bucket indices assigned to this rank - to update. + Returns the single rank assigned to a :class:`DistributedDataParallel` + gradient bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + bucket for which to get the assigned rank. """ - assert self._overlap_info.status == _OverlapStatus.INITIALIZED - num_buckets = len(self._overlap_info.params_per_bucket) - assigned_indices = [ - bucket_index for bucket_index in range(num_buckets) - if self._ddp_bucket_index_to_rank(bucket_index) == self.global_rank - ] - return assigned_indices + assert not self._overlap_info.shard_buckets, \ + "The bucket assignment requires global bucket information and " \ + "will be computed later; there should be no need to use this " \ + "method" + return bucket_index % self.world_size def _check_overlap_initialized(self): r""" diff --git a/torch/distributed/optim/zero_redundancy_optimizer.pyi b/torch/distributed/optim/zero_redundancy_optimizer.pyi index 7dc0812..bfea1e7 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.pyi +++ b/torch/distributed/optim/zero_redundancy_optimizer.pyi @@ -1,19 +1,32 @@ import enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Type, +) import torch from torch.distributed.algorithms.join import Joinable, JoinHook from torch.optim import Optimizer +def _get_global_rank(group: Any, rank: int) -> int: ... + class _ZeROJoinHook(JoinHook): zero: Any = ... def __init__(self, zero: Any) -> None: ... def main_hook(self) -> None: ... 
-class _DDPBucket(NamedTuple): +class _DDPBucketAssignment(): bucket_index: int - params: List[torch.Tensor] + parameters: List[torch.Tensor] + offset: int + device: torch.device + tensor: Optional[torch.Tensor] class _OverlapStatus(enum.IntEnum): UNINITIALIZED: int = ... @@ -29,8 +42,11 @@ class _OverlapInfo: bucket_index_to_future: Any = ... bucket_index_to_bucket: Any = ... bucket_indices_seen: Any = ... + assigned_ranks_per_bucket: List[Set[int]] = ... + total_size: int = ... + shard_buckets: bool = ... def __init__(self) -> None: ... - def wait_for_broadcasts(self, num_buckets, rank) -> None: ... + def wait_for_broadcasts(self) -> None: ... def clear_per_iter_info(self) -> None: ... class ZeroRedundancyOptimizer(Optimizer, Joinable): @@ -45,7 +61,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): _device_to_device_index: Dict[torch.device, int] = ... _overlap_with_ddp: bool = ... _overlap_info: _OverlapInfo = ... - _buckets: Union[List[List[torch.Tensor]], List[List[Dict[int, torch.Tensor]]]] = ... + _buckets: List[List[torch.Tensor]] = ... + _bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ... def __init__(self, params: Any, optimizer_class: Type[Optimizer], process_group: Optional[Any]=..., parameters_as_bucket_view: bool=..., overlap_with_ddp: bool=..., **defaults: Any) -> None: ... def add_param_group(self, param_group: dict) -> None: ... def consolidate_state_dict(self, to: int=...) -> None: ... @@ -53,7 +70,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... def state_dict(self) -> Dict[str, Any]: ... def _local_step(self, gradients: Optional[List[Optional[torch.Tensor]]] = None, closure: Optional[Callable[[], float]] = None, **kwargs: Any,) -> Optional[float]: ... - def _ddp_bucket_index_to_rank(self, bucket_index: int) -> int: ... + def _get_assigned_rank(self, bucket_index: int) -> int: ... def join_hook(self, **kwargs): ... def join_device(self) -> torch.device: ... def join_process_group(self) -> Any: ...
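
For reference, the registration pattern exercised by the new tests (via ``_test_ddp_zero_overlap``) looks roughly as follows outside the test harness. The model, hyperparameters, and device handling below are illustrative placeholders rather than part of this patch; depending on the version, the functional optimizer class may need to be passed directly instead of ``torch.optim.SGD``; and the process group is assumed to already be initialized with the NCCL backend, which these hooks require.

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes `dist.init_process_group("nccl", ...)` has already run on every rank
rank = dist.get_rank()
model = torch.nn.Linear(2000, 2000).to(rank)
ddp_model = DDP(model, device_ids=[rank])

# `overlap_with_ddp=True` defers local optimizer construction until the DDP
# bucket information has been collected by the hook registered below
zero_optim = ZeroRedundancyOptimizer(
    ddp_model.parameters(),
    optimizer_class=torch.optim.SGD,  # must resolve to a functional optimizer
    overlap_with_ddp=True,
    lr=0.01,
    momentum=0.9,
)

# `shard_buckets=True` enables the new uniform approximation: a single DDP
# gradient bucket may now be split across several ranks instead of being
# owned entirely by one rank
ddp_model.register_comm_hook(
    None,
    hook_with_zero_step(allreduce_hook, ddp_model, zero_optim, shard_buckets=True),
)

As the hook code above shows, the first training iterations pass through without a local optimizer step while the DDP buckets are rebuilt and the per-bucket information is saved; only after the overlap status reaches INITIALIZED do the per-bucket local steps and broadcasts run.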
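
The uniform approximation itself lives in ``_bucket_assignments_per_rank``: when ``shard_buckets=True``, DDP buckets are visited from smallest to largest total size, each bucket is carved into maximal runs of consecutive parameters whose total size stays under ``total_size / world_size``, and each run goes to the least-loaded rank that does not already own a piece of that bucket. The same greedy logic, restated on plain integer sizes as a self-contained sketch (an illustration with invented names, not the patch code):

from typing import List, Set, Tuple

def uniform_bucket_assignment(
    bucket_sizes: List[List[int]],  # bucket_sizes[i][j] = numel of parameter j in DDP bucket i
    world_size: int,
) -> List[List[Tuple[int, int, int]]]:
    """Returns, for each rank, a list of (bucket_index, start, end) parameter slices."""
    total_size = sum(sum(sizes) for sizes in bucket_sizes)
    threshold = total_size / world_size
    size_per_rank = [0] * world_size
    assigned_ranks_per_bucket: List[Set[int]] = [set() for _ in bucket_sizes]
    assignments: List[List[Tuple[int, int, int]]] = [[] for _ in range(world_size)]

    def assign(bucket_index: int, start: int, end: int, size: int) -> None:
        # Least-loaded rank that does not already own a slice of this bucket
        rank = min(
            (r for r in range(world_size) if r not in assigned_ranks_per_bucket[bucket_index]),
            key=lambda r: size_per_rank[r],
        )
        assignments[rank].append((bucket_index, start, end))
        assigned_ranks_per_bucket[bucket_index].add(rank)
        size_per_rank[rank] += size

    # Visit buckets from smallest to largest total size, as the patch does
    for bucket_index, sizes in sorted(enumerate(bucket_sizes), key=lambda x: sum(x[1])):
        offset = 0
        assignment_size = 0
        for i, numel in enumerate(sizes):
            if assignment_size + numel >= threshold and i > offset:
                # Close the current run just before it would exceed the threshold
                assign(bucket_index, offset, i, assignment_size)
                offset, assignment_size = i, 0
            assignment_size += numel
        # Assign the remainder so that no run ever spans two buckets
        assign(bucket_index, offset, len(sizes), assignment_size)
    return assignments

Because every run from a given bucket is assigned to a distinct rank, ``num_bucket_assignments`` equals the number of DDP buckets when ``shard_buckets=False`` and can only grow when sharding is enabled, which is exactly the count that ``wait_for_broadcasts()`` now checks the broadcast handles against.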
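
Finally, the memory leak noted in the title: both hook constructors previously closed over the ``DistributedDataParallel`` instance, so the hook registered on the module also held a strong reference back to the module, and the patch's switch to ``weakref.ref`` indicates that this reference is what kept the module alive. The closures now dereference the weak reference at call time (hence the ``ddp_ref()`` calls and the ``type: ignore`` annotations above). The general pattern, as a standalone sketch with illustrative names:

import weakref


class Trainer:
    """Stand-in for an object that stores hooks registered on itself."""

    def __init__(self) -> None:
        self.hooks = []

    def register_hook(self, hook) -> None:
        self.hooks.append(hook)


def make_hook(trainer: Trainer):
    trainer_ref = weakref.ref(trainer)  # weak: does not keep `trainer` alive

    def hook() -> int:
        trainer = trainer_ref()  # may be None once the Trainer is collected
        assert trainer is not None, "Trainer was already garbage collected"
        return len(trainer.hooks)

    return hook


trainer = Trainer()
trainer.register_hook(make_hook(trainer))
print(trainer.hooks[0]())  # prints 1; the closure holds no strong reference to `trainer`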