"""
# Collect buffers for modules, filtering out buffers that should be ignored.
named_module_buffers = [
- [
- (buffer, buffer_name)
- for buffer_name, buffer in self.module.named_buffers()
- ]
+ (buffer, buffer_name)
+ for buffer_name, buffer in self.module.named_buffers()
]
self.modules_buffers = [
- [
- buffer
- for (buffer, buffer_name) in module_buffers
- if buffer_name not in self.parameters_to_ignore
- ]
- for module_buffers in named_module_buffers
+ buffer
+ for (buffer, buffer_name) in named_module_buffers
+ if buffer_name not in self.parameters_to_ignore
]
-
def _build_param_to_name_mapping(self, parameters):
param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
param_set = set(parameters[0])
if self.will_sync_module_buffers():
authoritative_rank = self._find_common_rank(self._distributed_rank, False)
self._distributed_broadcast_coalesced(
- self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank
+ self.modules_buffers, self.broadcast_bucket_size, authoritative_rank
)
# When running in join model, agrees upon a common rank and broadcast model
return (
self.require_forward_param_sync
and self.broadcast_buffers
- and len(self.modules_buffers[0]) > 0
+ and len(self.modules_buffers) > 0
)
def _find_common_rank(self, input_rank, rank_cond):
# reassigned.
self._assign_modules_buffers()
self._distributed_broadcast_coalesced(
- self.modules_buffers[0],
+ self.modules_buffers,
self.broadcast_bucket_size,
authoritative_rank,
)