Move convert_sync_batchnorm to SyncBatchNorm (#18787)
author Arunava <learningdroidarunava@gmail.com>
Sun, 7 Apr 2019 07:07:24 +0000 (00:07 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 7 Apr 2019 07:13:02 +0000 (00:13 -0700)
Summary:
Closes #18382

Moves convert_sync_batchnorm from the free function in torch/nn/utils/sync_batch_norm.py to a classmethod on torch.nn.SyncBatchNorm, and updates the docstring and the distributed test to use the new call site.

Please let me know if any changes are required.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18787

Differential Revision: D14821147

Pulled By: soumith

fbshipit-source-id: edd98eab1b3f4151c4ae5148146435ddb2ae678d

test/test_distributed.py
torch/nn/modules/batchnorm.py
torch/nn/utils/__init__.py
torch/nn/utils/sync_batch_norm.py [deleted file]
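For quick reference, a minimal before/after sketch of the call-site change (the toy model mirrors the docstring example below and is illustrative only):

import copy
import torch.nn as nn

# Toy model containing a BatchNorm layer (illustrative only).
model = nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100))

# Old API (removed by this PR): free function under torch.nn.utils
#   model_sync = nn.utils.convert_sync_batchnorm(copy.deepcopy(model))

# New API added by this PR: classmethod on nn.SyncBatchNorm
model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))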

diff --git a/test/test_distributed.py b/test/test_distributed.py
index 7143f37..bc92a79 100644
@@ -1479,7 +1479,7 @@ class _DistTestBase(object):
         model_gpu.cuda(gpu_subset[0])
 
         # DDP training setup
-        model_DDP = nn.utils.convert_sync_batchnorm(copy.deepcopy(model))
+        model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
         model_DDP.cuda(gpu_subset[0])
         model_DDP = nn.parallel.DistributedDataParallel(
             model_DDP, device_ids=gpu_subset
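Outside the test harness, the same pattern looks roughly like the sketch below. This is a hedged example, assuming one process per GPU launched with the usual RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment variables and the NCCL backend; the toy model is illustrative:

import torch
import torch.distributed as dist
import torch.nn as nn

# Assumes the process was launched (e.g. via torch.distributed.launch)
# with one GPU per process and the rendezvous env variables set.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# Convert BatchNorm layers to SyncBatchNorm, then wrap the model with DDP.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda(local_rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])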
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index dc697e8..2d2034b 100644
@@ -356,7 +356,7 @@ class SyncBatchNorm(_BatchNorm):
     or Spatio-temporal Batch Normalization.
 
     Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
-    torch.nn.utils.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
+    torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layers to SyncBatchNorm before wrapping
     Network with DDP.
 
     Args:
@@ -458,3 +458,47 @@ class SyncBatchNorm(_BatchNorm):
             return sync_batch_norm.apply(
                 input, self.weight, self.bias, self.running_mean, self.running_var,
                 self.eps, exponential_average_factor, process_group, world_size)
+
+    @classmethod
+    def convert_sync_batchnorm(cls, module, process_group=None):
+        r"""Helper function to convert `torch.nn.BatchNormND` layer in the model to
+        `torch.nn.SyncBatchNorm` layer.
+
+        Args:
+            module (nn.Module): containing module
+            process_group (optional): process group to scope synchronization,
+                default is the whole world
+
+        Returns:
+            The original module with its `torch.nn.BatchNormND` layers converted to `torch.nn.SyncBatchNorm` layers
+
+        Example::
+
+            >>> # Network with nn.BatchNorm layer
+            >>> module = torch.nn.Sequential(
+            >>>            torch.nn.Linear(20, 100),
+            >>>            torch.nn.BatchNorm1d(100)
+            >>>          ).cuda()
+            >>> # creating process group (optional)
+            >>> # process_ids is a list of int identifying rank ids.
+            >>> process_group = torch.distributed.new_group(process_ids)
+            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
+
+        """
+        module_output = module
+        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+            module_output = torch.nn.SyncBatchNorm(module.num_features,
+                                                   module.eps, module.momentum,
+                                                   module.affine,
+                                                   module.track_running_stats,
+                                                   process_group)
+            if module.affine:
+                module_output.weight.data = module.weight.data.clone().detach()
+                module_output.bias.data = module.bias.data.clone().detach()
+            module_output.running_mean = module.running_mean
+            module_output.running_var = module.running_var
+            module_output.num_batches_tracked = module.num_batches_tracked
+        for name, child in module.named_children():
+            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
+        del module
+        return module_output
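To illustrate what the new classmethod does (a sketch, not part of the patch): the BatchNorm child is replaced by a SyncBatchNorm module, affine parameters and running statistics carry their values over, and container modules are rewritten in place, which is why the test above deep-copies the model first:

import copy
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
original_bn = copy.deepcopy(model[1])  # snapshot for comparison

sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# The BatchNorm2d child was swapped for a SyncBatchNorm module.
assert isinstance(sync_model[1], nn.SyncBatchNorm)
# Affine parameters and running statistics keep the same values.
assert torch.equal(sync_model[1].weight, original_bn.weight)
assert torch.equal(sync_model[1].running_mean, original_bn.running_mean)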
diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py
index 2398766..24d77e8 100644
@@ -3,4 +3,3 @@ from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_  # noqa
 from .weight_norm import weight_norm, remove_weight_norm  # noqa: F401
 from .convert_parameters import parameters_to_vector, vector_to_parameters  # noqa: F401
 from .spectral_norm import spectral_norm, remove_spectral_norm  # noqa: F401
-from .sync_batch_norm import convert_sync_batchnorm  # noqa: F401
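Because the free function is removed from torch.nn.utils, code importing it from the old location will now raise ImportError. A hedged migration shim (not part of this patch) could look like:

try:
    # Old location, removed by this PR.
    from torch.nn.utils import convert_sync_batchnorm
except ImportError:
    # New location: classmethod on the SyncBatchNorm class.
    from torch.nn import SyncBatchNorm
    convert_sync_batchnorm = SyncBatchNorm.convert_sync_batchnorm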
diff --git a/torch/nn/utils/sync_batch_norm.py b/torch/nn/utils/sync_batch_norm.py
deleted file mode 100644
index ca034da..0000000
+++ /dev/null
@@ -1,45 +0,0 @@
-import torch
-
-
-def convert_sync_batchnorm(module, process_group=None):
-    r"""Helper function to convert `torch.nn.BatchNormND` layer in the model to
-    `torch.nn.SyncBatchNorm` layer.
-
-    Args:
-        module (nn.Module): containing module
-        process_group (optional): process group to scope synchronization,
-    default is the whole world
-
-    Returns:
-        The original module with the converted `torch.nn.SyncBatchNorm` layer
-
-    Example::
-
-        >>> # Network with nn.BatchNorm layer
-        >>> module = torch.nn.Sequential(
-        >>>            torch.nn.Linear(20, 100),
-        >>>            torch.nn.BatchNorm1d(100)
-        >>>          ).cuda()
-        >>> # creating process group (optional)
-        >>> # process_ids is a list of int identifying rank ids.
-        >>> process_group = torch.distributed.new_group(process_ids)
-        >>> sync_bn_module = convert_sync_batchnorm(module, process_group)
-
-    """
-    module_output = module
-    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        module_output = torch.nn.SyncBatchNorm(module.num_features,
-                                               module.eps, module.momentum,
-                                               module.affine,
-                                               module.track_running_stats,
-                                               process_group)
-        if module.affine:
-            module_output.weight.data = module.weight.data.clone().detach()
-            module_output.bias.data = module.bias.data.clone().detach()
-        module_output.running_mean = module.running_mean
-        module_output.running_var = module.running_var
-        module_output.num_batches_tracked = module.num_batches_tracked
-    for name, child in module.named_children():
-        module_output.add_module(name, convert_sync_batchnorm(child))
-    del module
-    return module_output