(#14267)
authorjiej <jiej@nvidia.com>
Wed, 6 Mar 2019 21:36:14 +0000 (13:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Mar 2019 21:39:11 +0000 (13:39 -0800)
commit39669316a6fc31c26f5ef1be5c0a0fe862d661cf
treeead4d553f56b913fcf3613d4517be7a68dddcc9c
parent0ed1b9fb980f5f069910cc011d6eedaf075b63b2
(#14267)

Summary:
- Summary:

Added synchronized batch normalization, allows synchronization of stats across mini-batches between processes within a process group.
Current implementation uses a mixture of extended ATen native functions (cpp cuda extension) + torch.nn.modules (c10d python API)

- User-facing api:

1. torch.nn.utils.convert_sync_batchnorm(modules, process_group=None)

2. torch.nn.SyncBatchNorm(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, ***process_group=None***)

- supported use case:
DistributedDataParallel with ***single-gpu multi-process***

a. User creates model containing `torch.nn.SyncBatchNorm` layers through one of the ways listed below:

  1. use layers directly:

     torch.nn.SyncBatchNorm(...)

     similar API as with torch.nn.BatchNormXd(...)
     with added argument `process_group` which is used to limit the scope of
     synchronization within each process group. Default value is None, which
     implies synchronization across all GPUs

  2. use torch.nn.utils.convert_sync_batchnorm(modules, process_group)

     recursively convert all `torch.nn.BatchNormXd` into `torch.nn.SyncBatchNorm`
     preserving values of parameters/buffers.
     the utility function also allows user to specify process_group value to all
     converted layers.

b. user wraps their model with
   `torch.distributed.parallel.DataParallelDistributed`, from this point, user
   should follow the general guidelines for DDP use guide

- Error checking

For use cases not supported, we error out:

1. Application launched without ddp:
   > import torch
   > sbn = torch.nn.SyncBatchNorm(10).cuda()
   > inp = torch.randn(5, 10, 3, 3).cuda()
   > sbn(inp) --> Error!
   > AttributeError: SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel

2. Application launched using DDP with multi-GPU per-process:
   > ddp_module = nn.parallel.DistributedDataParallel(module, device_ids=device_ids, output_device=args.local_rank)
   > ValueError: SyncBatchNorm is only supported for DDP with single GPU per process
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14267

Differential Revision: D14270035

Pulled By: ezyang

fbshipit-source-id: 4956d8fa565c32e9df5408d53719ff9f945f4d6d
aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/Normalization.cuh
aten/src/ATen/native/native_functions.yaml
test/test_distributed.py
torch/nn/modules/__init__.py
torch/nn/modules/_functions.py [new file with mode: 0644]
torch/nn/modules/batchnorm.py
torch/nn/parallel/distributed.py
torch/nn/utils/__init__.py
torch/nn/utils/sync_batch_norm.py [new file with mode: 0644]