From: Arunava
Date: Sun, 7 Apr 2019 07:07:24 +0000 (-0700)
Subject: convert_sync_batch_norm to SyncBatchNorm (#18787)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~353
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=79533ef097bf3ff2f1d85782ae1f279fe15e8686;p=platform%2Fupstream%2Fpytorch.git

convert_sync_batch_norm to SyncBatchNorm (#18787)

Summary:
Closes #18382

Please let me know if any changes are required.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18787

Differential Revision: D14821147

Pulled By: soumith

fbshipit-source-id: edd98eab1b3f4151c4ae5148146435ddb2ae678d
---
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 7143f37..bc92a79 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -1479,7 +1479,7 @@ class _DistTestBase(object):
         model_gpu.cuda(gpu_subset[0])
 
         # DDP training setup
-        model_DDP = nn.utils.convert_sync_batchnorm(copy.deepcopy(model))
+        model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
         model_DDP.cuda(gpu_subset[0])
         model_DDP = nn.parallel.DistributedDataParallel(
             model_DDP, device_ids=gpu_subset
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index dc697e8..2d2034b 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -356,7 +356,7 @@ class SyncBatchNorm(_BatchNorm):
     or Spatio-temporal Batch Normalization.
 
     Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
-    torch.nn.utils.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
+    torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layers to SyncBatchNorm before wrapping
     Network with DDP.
 
     Args:
@@ -458,3 +458,47 @@ class SyncBatchNorm(_BatchNorm):
         return sync_batch_norm.apply(
             input, self.weight, self.bias, self.running_mean, self.running_var,
             self.eps, exponential_average_factor, process_group, world_size)
+
+    @classmethod
+    def convert_sync_batchnorm(cls, module, process_group=None):
+        r"""Helper function to convert `torch.nn.BatchNormND` layers in the model to
+        `torch.nn.SyncBatchNorm` layers.
+
+        Args:
+            module (nn.Module): containing module
+            process_group (optional): process group to scope synchronization,
+            default is the whole world
+
+        Returns:
+            The original module with the converted `torch.nn.SyncBatchNorm` layers
+
+        Example::
+
+            >>> # Network with nn.BatchNorm layer
+            >>> module = torch.nn.Sequential(
+            >>>            torch.nn.Linear(20, 100),
+            >>>            torch.nn.BatchNorm1d(100)
+            >>>          ).cuda()
+            >>> # creating process group (optional)
+            >>> # process_ids is a list of int identifying rank ids.
+            >>> process_group = torch.distributed.new_group(process_ids)
+            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
+
+        """
+        module_output = module
+        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+            module_output = torch.nn.SyncBatchNorm(module.num_features,
+                                                   module.eps, module.momentum,
+                                                   module.affine,
+                                                   module.track_running_stats,
+                                                   process_group)
+            if module.affine:
+                module_output.weight.data = module.weight.data.clone().detach()
+                module_output.bias.data = module.bias.data.clone().detach()
+            module_output.running_mean = module.running_mean
+            module_output.running_var = module.running_var
+            module_output.num_batches_tracked = module.num_batches_tracked
+        for name, child in module.named_children():
+            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
+        del module
+        return module_output
diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py
index 2398766..24d77e8 100644
--- a/torch/nn/utils/__init__.py
+++ b/torch/nn/utils/__init__.py
@@ -3,4 +3,3 @@ from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_  # noqa
 from .weight_norm import weight_norm, remove_weight_norm  # noqa: F401
 from .convert_parameters import parameters_to_vector, vector_to_parameters  # noqa: F401
 from .spectral_norm import spectral_norm, remove_spectral_norm  # noqa: F401
-from .sync_batch_norm import convert_sync_batchnorm  # noqa: F401
diff --git a/torch/nn/utils/sync_batch_norm.py b/torch/nn/utils/sync_batch_norm.py
deleted file mode 100644
index ca034da..0000000
--- a/torch/nn/utils/sync_batch_norm.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import torch
-
-
-def convert_sync_batchnorm(module, process_group=None):
-    r"""Helper function to convert `torch.nn.BatchNormND` layer in the model to
-    `torch.nn.SyncBatchNorm` layer.
-
-    Args:
-        module (nn.Module): containing module
-        process_group (optional): process group to scope synchronization,
-        default is the whole world
-
-    Returns:
-        The original module with the converted `torch.nn.SyncBatchNorm` layer
-
-    Example::
-
-        >>> # Network with nn.BatchNorm layer
-        >>> module = torch.nn.Sequential(
-        >>>            torch.nn.Linear(20, 100),
-        >>>            torch.nn.BatchNorm1d(100)
-        >>>          ).cuda()
-        >>> # creating process group (optional)
-        >>> # process_ids is a list of int identifying rank ids.
-        >>> process_group = torch.distributed.new_group(process_ids)
-        >>> sync_bn_module = convert_sync_batchnorm(module, process_group)
-
-    """
-    module_output = module
-    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        module_output = torch.nn.SyncBatchNorm(module.num_features,
-                                               module.eps, module.momentum,
-                                               module.affine,
-                                               module.track_running_stats,
-                                               process_group)
-        if module.affine:
-            module_output.weight.data = module.weight.data.clone().detach()
-            module_output.bias.data = module.bias.data.clone().detach()
-        module_output.running_mean = module.running_mean
-        module_output.running_var = module.running_var
-        module_output.num_batches_tracked = module.num_batches_tracked
-    for name, child in module.named_children():
-        module_output.add_module(name, convert_sync_batchnorm(child))
-    del module
-    return module_output
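
Usage sketch of the new call site, assuming a process group has already been initialized with torch.distributed.init_process_group() and that `local_rank` and `build_sync_bn_ddp_model` are hypothetical names used only for illustration::

    import torch
    import torch.nn as nn

    def build_sync_bn_ddp_model(local_rank, process_group=None):
        # An ordinary model containing regular BatchNorm layers.
        model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        # New call site: the helper is now a classmethod on SyncBatchNorm
        # instead of torch.nn.utils.convert_sync_batchnorm. It recursively
        # replaces every _BatchNorm child with SyncBatchNorm and copies the
        # affine parameters and running statistics.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)

        # SyncBatchNorm currently supports DistributedDataParallel with a
        # single GPU per process, so convert before wrapping with DDP.
        model = model.cuda(local_rank)
        return nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

The process_group argument is optional; when omitted, statistics are synchronized across the whole world.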