(#14267)
authorjiej <jiej@nvidia.com>
Wed, 6 Mar 2019 21:36:14 +0000 (13:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Mar 2019 21:39:11 +0000 (13:39 -0800)
Summary:
- Summary:

Added synchronized batch normalization, allows synchronization of stats across mini-batches between processes within a process group.
Current implementation uses a mixture of extended ATen native functions (cpp cuda extension) + torch.nn.modules (c10d python API)

- User-facing api:

1. torch.nn.utils.convert_sync_batchnorm(modules, process_group=None)

2. torch.nn.SyncBatchNorm(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, ***process_group=None***)

- supported use case:
DistributedDataParallel with ***single-gpu multi-process***

a. User creates model containing `torch.nn.SyncBatchNorm` layers through one of the ways listed below:

  1. use layers directly:

     torch.nn.SyncBatchNorm(...)

     similar API as with torch.nn.BatchNormXd(...)
     with added argument `process_group` which is used to limit the scope of
     synchronization within each process group. Default value is None, which
     implies synchronization across all GPUs

  2. use torch.nn.utils.convert_sync_batchnorm(modules, process_group)

     recursively convert all `torch.nn.BatchNormXd` into `torch.nn.SyncBatchNorm`
     preserving values of parameters/buffers.
     the utility function also allows user to specify process_group value to all
     converted layers.

b. user wraps their model with
   `torch.distributed.parallel.DataParallelDistributed`, from this point, user
   should follow the general guidelines for DDP use guide

- Error checking

For use cases not supported, we error out:

1. Application launched without ddp:
   > import torch
   > sbn = torch.nn.SyncBatchNorm(10).cuda()
   > inp = torch.randn(5, 10, 3, 3).cuda()
   > sbn(inp) --> Error!
   > AttributeError: SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel

2. Application launched using DDP with multi-GPU per-process:
   > ddp_module = nn.parallel.DistributedDataParallel(module, device_ids=device_ids, output_device=args.local_rank)
   > ValueError: SyncBatchNorm is only supported for DDP with single GPU per process
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14267

Differential Revision: D14270035

Pulled By: ezyang

fbshipit-source-id: 4956d8fa565c32e9df5408d53719ff9f945f4d6d

aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/Normalization.cuh
aten/src/ATen/native/native_functions.yaml
test/test_distributed.py
torch/nn/modules/__init__.py
torch/nn/modules/_functions.py [new file with mode: 0644]
torch/nn/modules/batchnorm.py
torch/nn/parallel/distributed.py
torch/nn/utils/__init__.py
torch/nn/utils/sync_batch_norm.py [new file with mode: 0644]

index e1acf51..d77236e 100644 (file)
@@ -24,6 +24,63 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
     });
 }
 
+std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_stats", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_stats_cuda_template<scalar_t, int32_t>(self, epsilon);
+      } else {
+        return batch_norm_stats_cuda_template<scalar_t, int64_t>(self, epsilon);
+      }
+    });
+}
+
+Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
+                             const Tensor& mean, const Tensor& invstd, double epsilon) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_elemt", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_elemt_cuda_template<scalar_t, int32_t>(self, weight, bias, mean, invstd, epsilon);
+      } else {
+        return batch_norm_elemt_cuda_template<scalar_t, int64_t>(self, weight, bias, mean, invstd, epsilon);
+      }
+    });
+}
+
+// accepting input(self) here to determine template data types, since running_mean/running_var are optional
+std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean,
+                                                        const Tensor& running_var, double momentum, double epsilon, int64_t count) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_update_stats", [&] {
+      int world_size = mean.size(1);
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int32_t>(mean, invstd, running_mean, running_var, momentum, epsilon, static_cast<int32_t>(count));
+      } else {
+        return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int64_t>(mean, invstd, running_mean, running_var, momentum, epsilon, count);
+      }
+    });
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean,
+                                                                           const Tensor& invstd, bool input_g, bool weight_g, bool bias_g) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_reduce", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_backward_reduce_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
+      } else {
+        return batch_norm_backward_reduce_cuda_template<scalar_t, int64_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
+      }
+    });
+}
+
+Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd,
+                                      const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_elemt", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_backward_elemt_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
+      } else {
+        return batch_norm_backward_elemt_cuda_template<scalar_t, int64_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
+      }
+    });
+}
+
 std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
         const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
   return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
index 5ac9c85..35d42bb 100644 (file)
@@ -40,6 +40,15 @@ static int getNumThreads(int nElem) {
   return MAX_BLOCK_SIZE;
 }
 
+static int lastPow2(unsigned int n) {
+    n |= (n >>  1);
+    n |= (n >>  2);
+    n |= (n >>  4);
+    n |= (n >>  8);
+    n |= (n >> 16);
+    return n - (n >> 1);
+}
+
 // Returns the index of the most significant 1 bit in `val`.
 __device__ __forceinline__ int getMSB(int val) {
   return 31 - __clz(val);
@@ -283,8 +292,8 @@ __global__ void batch_norm_collect_statistics_kernel(
 
   if (tid < WARP_SIZE) {
     n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0);
-    avg = (tid < blockDim.x * blockDim.y  / WARP_SIZE ? shared_avg_var[2 * tid] : 0);
-    var_n = (tid < blockDim.x * blockDim.y  / WARP_SIZE ? shared_avg_var[2 * tid + 1] : 0);
+    avg = (tid < blockDim.x * blockDim.y  / WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
+    var_n = (tid < blockDim.x * blockDim.y  / WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
   }
   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
     stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE);
@@ -297,8 +306,12 @@ __global__ void batch_norm_collect_statistics_kernel(
 
   // Save the mean, variance, and moving averages
   if (tid == 0) {
-    save_mean[plane] = avg;
-    save_transformed_var[plane] = VarTransform<stat_accscalar_t>{}(var_n / N, epsilon);
+    if (save_mean.data() != NULL) {
+      save_mean[plane] = avg;
+    }
+    if (save_transformed_var.data() != NULL) {
+      save_transformed_var[plane] = VarTransform<stat_accscalar_t>{}(var_n / N, epsilon);
+    }
     if (running_mean.data() != NULL) {
       running_mean[plane] = static_cast<stat_scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
     }
@@ -381,6 +394,127 @@ __global__ void batch_norm_backward_kernel(
   }
 }
 
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_reduce_statistics_kernel(
+    const PackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
+    const PackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
+    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
+    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
+    PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
+    PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
+    const accscalar_t epsilon,
+    const accscalar_t momentum,
+    const index_t count) {
+
+  int feature_size = vec_mean.size(1);
+  int world_size = vec_mean.size(0);
+
+  int bid = blockIdx.x;
+  int tid = threadIdx.x;
+
+  // first the reductions each thread does separately
+  for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
+    accscalar_t avg = 0;
+    accscalar_t var_n = 0;
+    index_t n = 0;
+    for (int j = 0; j < world_size; j++) {
+      accscalar_t m = vec_mean[j][i];
+      accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
+      v = (v * v - epsilon) * count;
+      accscalar_t factor = 1.0 / (n + count);
+      var_n += v + (avg - m) * (avg - m) * n * count * factor;
+      avg = n * factor * avg + count * factor * m;
+      n += count;
+    }
+    mean[i] = avg;
+    invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
+    if (running_mean.data() != NULL) {
+      running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
+    }
+    accscalar_t unbiasedVar = var_n / (n - 1);
+    if (running_var.data() != NULL) {
+      running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
+    }
+  }
+
+}
+
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_backward_reduce_kernel(
+    const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> input,
+    const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
+    PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean,
+    PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
+    PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy,
+    PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy_xmu,
+    PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
+    PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {
+
+  index_t plane = blockIdx.x;
+  index_t N = input.size(0) * input.size(2);
+
+  accscalar_t r_mean = mean[plane];
+  accscalar_t factor = invstd[plane];
+
+  GradOp<scalar_t, accscalar_t, PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
+  Float2<scalar_t, accscalar_t> res = reduce<Float2<scalar_t, accscalar_t>, GradOp<scalar_t, accscalar_t,
+                                                                                   PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t>>>(g, grad_output, plane);
+
+  accscalar_t norm = accscalar_t(1) / N;
+  if (threadIdx.x == 0) {
+    if (grad_weight.size(0) > 0) {
+      grad_weight[plane] = static_cast<scalar_t>(res.v2 * factor);
+    }
+    if (grad_bias.size(0) > 0) {
+      grad_bias[plane] = static_cast<scalar_t>(res.v1);
+    }
+    if (mean_dy.size(0) > 0) {
+      mean_dy[plane] = static_cast<accscalar_t>(res.v1 * norm);
+    }
+    if (mean_dy_xmu.size(0) > 0) {
+      mean_dy_xmu[plane] = static_cast<accscalar_t>(res.v2 * norm);
+    }
+  }
+}
+
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_backward_elemt_kernel(
+    const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> input,
+    const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
+    const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean,
+    const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
+    const PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> weight,
+    const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy,
+    const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy_xmu,
+    PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_input) {
+
+  index_t plane = blockIdx.x;
+
+  if (plane >= input.size(1)) {
+    return;
+  }
+
+  accscalar_t m_c = mean[plane];
+  accscalar_t m_dy_c = mean_dy[plane];
+  accscalar_t factor_1_c = invstd[plane];
+  accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : static_cast<accscalar_t>(1);
+  factor_2_c *= factor_1_c;
+  factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[plane];
+
+  index_t bs = input.size(0);
+  index_t fs = input.size(2);
+
+  index_t bstep  = blockDim.y * gridDim.y;
+  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
+    auto g_i = grad_input[batch][plane];
+    auto g_o = grad_output[batch][plane];
+    auto i = input[batch][plane];
+    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
+      g_i[feature] = static_cast<scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
+    }
+  }
+}
+
 template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
 static PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(const Tensor& t) {
   if (! t.defined()) {
@@ -435,7 +569,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda_template(const Tensor& input_
 
   // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
   // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
-  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
+  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
   // The various planes are independent, so we use blocks for them.
   int tf = std::max<int>(getNumThreads(input.size(2)/4),
                          std::min<int>(getNumThreads(input.size(2)), 64));
@@ -510,6 +644,208 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tenso
   return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
 }
 
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_stats_cuda_template(const Tensor& input_, double epsilon) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  int64_t n_input = input_.size(1);
+  Tensor dummy_mean_;
+  Tensor dummy_var_;
+  Tensor mean_;
+  Tensor invstd_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+
+  auto bs = input_reshaped.size(0);
+  auto features = input_reshaped.size(2);
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto input_options = input_.options();
+  dummy_mean_ = at::empty({0}, input_options);
+  dummy_var_ = at::empty({0}, input_options);
+  // promote only mean_/invstd_ precision
+  if (input_.type().scalarType() == at::ScalarType::Half) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  mean_ = at::empty({n_input}, input_options);
+  invstd_ = at::empty({n_input}, input_options);
+  auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(mean_);
+  auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
+  auto dummy_mean = dummy_mean_.packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
+  auto dummy_invstd = dummy_var_.packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 blocks(input.size(1));
+  int tf = getNumThreads(input.size(2));
+  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+  batch_norm_collect_statistics_kernel<InvStd, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    (input, epsilon, 0.0, dummy_mean, dummy_invstd, mean, invstd);
+  THCudaCheck(cudaGetLastError());
+  return std::make_tuple(mean_, invstd_);
+}
+
+template<typename scalar_t, typename index_t>
+Tensor batch_norm_elemt_cuda_template(const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
+                                                                  const Tensor& mean_, const Tensor& invstd_,
+                                                                  double epsilon) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  int64_t n_input = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto output_reshaped = at::empty_like(input_reshaped);
+
+  auto bs = input_reshaped.size(0);
+  auto features = input_reshaped.size(2);
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto input_options = input_.options();
+  if (input_.type().scalarType() == at::ScalarType::Half) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  auto output = output_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto weight = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
+  auto bias = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
+  auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(mean_);
+  auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max<int>(getNumThreads(input.size(2)/4),
+                         std::min<int>(getNumThreads(input.size(2)), 64));
+  int tb = std::max<int>(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  dim3 threads_trans(tf, tb);
+  batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+    (input, output, mean, invstd, weight, bias, epsilon);
+  THCudaCheck(cudaGetLastError());
+  return output_reshaped.view(input_.sizes());
+}
+
+template<typename scalar_t, typename accscalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
+                                                                 const Tensor& running_mean_, const Tensor& running_var_,
+                                                                 double momentum, double epsilon, index_t count) {
+
+  Tensor save_mean_;
+  Tensor save_invstd_;
+
+  auto features = mean_.size(1);
+  auto input_options = mean_.options();
+  if (mean_.type().scalarType() == at::ScalarType::Half) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  save_mean_ = at::empty({features}, input_options);
+  save_invstd_ = at::empty({features}, input_options);
+
+  auto mean = packed_accessor_or_dummy<accscalar_t, 2, RestrictPtrTraits, index_t>(mean_);
+  auto invstd = packed_accessor_or_dummy<accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_);
+  auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+  auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+  auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto save_invstd = save_invstd_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  int block = getNumThreads(features);
+  int grid = std::max<int>(1, features/block);
+  batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
+      (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, count);
+  THCudaCheck(cudaGetLastError());
+  return std::make_tuple(save_mean_, save_invstd_);
+}
+
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+                                                                                    const Tensor& mean_, const Tensor& invstd_,
+                                                                                    const bool input_g, const bool weight_g, const bool bias_g) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  int64_t n_input = input_.size(1);
+  Tensor mean_dy_;
+  Tensor mean_dy_xmu_;
+  Tensor grad_weight_;
+  Tensor grad_bias_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+
+  if (input_g) {
+    mean_dy_ = at::empty_like(mean_);
+    mean_dy_xmu_ = at::empty_like(mean_);
+  }
+  auto grad_options = grad_out_.options();
+  if (weight_g) {
+    grad_weight_ = at::empty({n_input}, grad_options);
+  }
+  if (bias_g) {
+    grad_bias_ = at::empty({n_input}, grad_options);
+  }
+
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto grad_output = grad_output_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto grad_weight = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_);
+  auto grad_bias = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_);
+  auto mean = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_);
+  auto invstd = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_);
+  auto mean_dy = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_);
+  auto mean_dy_xmu = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_xmu_);
+
+  auto batch_size = input_reshaped.size(0);
+  auto feature_size = input_reshaped.size(2);
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/32);
+  int block_x = std::min<int>(getNumThreads(feature_size), MAX_BLOCK_SIZE/block_y);
+  const dim3 block(block_x, block_y);
+  const dim3 grid(n_input);
+
+  batch_norm_backward_reduce_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
+    (input, grad_output, mean, invstd, mean_dy, mean_dy_xmu, grad_weight, grad_bias);
+  THCudaCheck(cudaGetLastError());
+
+  return std::make_tuple(mean_dy_, mean_dy_xmu_, grad_weight_, grad_bias_);
+}
+
+template<typename scalar_t, typename index_t>
+Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+                                               const Tensor& mean_, const Tensor& invstd_,
+                                               const Tensor& weight_, const Tensor& mean_dy_, const Tensor& mean_dy_xmu_) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  int64_t n_input = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+  auto grad_input_reshaped = at::empty_like(input_reshaped);
+
+  auto bs = input_reshaped.size(0);
+  auto features = input_reshaped.size(2);
+
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto grad_input = grad_input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto grad_output = grad_output_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto mean = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_);
+  auto invstd = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_);
+  auto weight = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(weight_);
+  auto mean_dy = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_);
+  auto mean_dy_xmu = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_xmu_);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max<int>(getNumThreads(input.size(2)/4),
+                         std::min<int>(getNumThreads(input.size(2)), 64));
+  int tb = std::max<int>(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  dim3 threads_trans(tf, tb);
+  batch_norm_backward_elemt_kernel<scalar_t, accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+    (input, grad_output, mean, invstd, weight, mean_dy, mean_dy_xmu, grad_input);
+  THCudaCheck(cudaGetLastError());
+  return grad_input_reshaped.view(input_.sizes());
+}
+
 template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
 std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
         const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
@@ -542,7 +878,6 @@ std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
     (input, 0., momentum, running_mean, running_var, save_mean, save_var);
   THCudaCheck(cudaGetLastError());
   return std::make_tuple(save_mean_, save_var_);
-
 }
 
 } } // namespace at::native
index 2e8bb06..8c370fe 100644 (file)
     CPU: batch_norm_cpu
     CUDA: batch_norm_cuda
 
+- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_stats_cuda
+
+- func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
+  dispatch:
+    CUDA: batch_norm_elemt_cuda
+
+- func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_gather_stats_cuda
+
 - func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   matches_jit_signature: True
   dispatch:
     CPU: batch_norm_backward_cpu
     CUDA: batch_norm_backward_cuda
 
+- func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_backward_reduce_cuda
+
+- func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor mean_dy, Tensor mean_dy_xmu) -> Tensor
+  dispatch:
+    CUDA: batch_norm_backward_elemt_cuda
+
 - func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)
   matches_jit_signature: True
   dispatch:
index 637336b..152499d 100644 (file)
@@ -62,7 +62,24 @@ class Net(nn.Module):
         return F.softmax(x, dim=1)
 
 
+class BatchNormNet(nn.Module):
+
+    def __init__(self):
+        super(BatchNormNet, self).__init__()
+        self.fc1 = nn.Linear(2, 40, bias=False)
+        self.bn = nn.BatchNorm1d(4)
+        self.fc2 = nn.Linear(40, 4, bias=False)
+
+    def forward(self, x):
+        x = torch.reshape(self.fc1(x), (-1, 4, 10))
+        x = self.bn(x)
+        x = torch.reshape(x, (-1, 40))
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
+
+
 DDP_NET = Net()
+BN_NET = BatchNormNet()
 
 
 def get_timeout(test_id):
@@ -1357,6 +1374,65 @@ class _DistTestBase(object):
         gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
         self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
 
+    def _test_DistributedDataParallel_SyncBatchNorm(self, gpu_subset, rank, output_device=None):
+        # Run a simple end to end DDP model, use result of single node model
+        # as baseline
+
+        # cpu training setup
+        model = BN_NET
+
+        # single gpu training setup
+        model_gpu = copy.deepcopy(model)
+        model_gpu.cuda(gpu_subset[0])
+
+        # DDP training setup
+        model_DDP = nn.utils.convert_sync_batchnorm(copy.deepcopy(model))
+        model_DDP.cuda(gpu_subset[0])
+        model_DDP = nn.parallel.DistributedDataParallel(
+            model_DDP, device_ids=gpu_subset
+        )
+
+        # test serializable/unserializable
+        if INIT_METHOD.startswith("file://"):
+            _, filename = tempfile.mkstemp(prefix=FOLDER)
+            torch.save(model_DDP, filename)
+            model_DDP = torch.load(filename)
+
+        # dummy data initialization
+        local_bs = len(gpu_subset)
+        global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+        # check two model parameters over 5 iterations
+        self._test_DDP_5iter(
+            model_gpu,
+            model_DDP,
+            input_cpu.cuda(gpu_subset[0]),
+            target.cuda(gpu_subset[0]),
+            loss,
+            local_bs,
+            rank,
+            global_bs,
+            True
+        )
+        self._barrier()
+
+    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
+                     "Only Nccl & Gloo backend support DistributedDataParallel")
+    @skip_if_no_cuda_distributed
+    @skip_if_no_gpu
+    def test_DistributedDataParallel_SyncBatchNorm(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        gpus = list(rank_to_GPU[rank])
+        self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank)
+
+        # test output_device
+        self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
+        # test device_ids
+        gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
+        self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]
 
index a14d6d9..71c1267 100644 (file)
@@ -13,7 +13,7 @@ from .container import Container, Sequential, ModuleList, ModuleDict, ParameterL
 from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
     AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
-from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
+from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm
 from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
 from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
 from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
@@ -40,7 +40,7 @@ __all__ = [
     'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
     'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
     'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
-    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm',
+    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
     'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
     'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
     'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py
new file mode 100644 (file)
index 0000000..ba2dfa7
--- /dev/null
@@ -0,0 +1,98 @@
+import torch
+from torch.autograd.function import Function
+
+
+class SyncBatchNorm(Function):
+
+    @staticmethod
+    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+        input = input.contiguous()
+
+        # calcualte mean/invstd for input.
+        mean, invstd = torch.batch_norm_stats(input, eps)
+
+        mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
+        invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
+        mean_l = list(mean_all.unbind(0))
+        invstd_l = list(invstd_all.unbind(0))
+        # using all_gather instead of all reduce so we can calculate mean/var in one go
+        mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
+        invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)
+
+        # wait on the async communication to finish
+        mean_all_reduce.wait()
+        invstd_all_reduce.wait()
+
+        # calcualte global mean & invstd
+        mean, invstd = torch.batch_norm_gather_stats(
+            input,
+            mean_all,
+            invstd_all,
+            running_mean,
+            running_var,
+            momentum,
+            eps,
+            int(input.numel() / input.size(1))
+        )
+
+        self.save_for_backward(input, weight, mean, invstd)
+        self.process_group = process_group
+        self.world_size = world_size
+
+        # apply element-wise normalization
+        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+        return out
+
+    @staticmethod
+    def backward(self, grad_output):
+        grad_output = grad_output.contiguous()
+        saved_input, weight, mean, invstd = self.saved_tensors
+        grad_input = grad_weight = grad_bias = None
+        process_group = self.process_group
+        world_size = self.world_size
+
+        # calculate local stats as well as grad_weight / grad_bias
+        mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
+            grad_output,
+            saved_input,
+            mean,
+            invstd,
+            self.needs_input_grad[0],
+            self.needs_input_grad[1],
+            self.needs_input_grad[2]
+        )
+
+        if self.needs_input_grad[0]:
+            # synchronizing stats used to calculate input gradient.
+            # TODO: move div_ into batch_norm_backward_elemt kernel
+            mean_dy_all_reduce = torch.distributed.all_reduce(
+                mean_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
+            mean_dy_xmu_all_reduce = torch.distributed.all_reduce(
+                mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
+
+            # wait on the async communication to finish
+            mean_dy_all_reduce.wait()
+            mean_dy_xmu_all_reduce.wait()
+
+            mean_dy.div_(world_size)
+            mean_dy_xmu.div_(world_size)
+            # backward pass for gradient calculation
+            grad_input = torch.batch_norm_backward_elemt(
+                grad_output,
+                saved_input,
+                mean,
+                invstd,
+                weight,
+                mean_dy,
+                mean_dy_xmu
+            )
+
+        # synchronizing of grad_weight / grad_bias is not needed as distributed
+        # training would handle all reduce.
+        if weight is None or not self.needs_input_grad[1]:
+            grad_weight = None
+
+        if weight is None or not self.needs_input_grad[2]:
+            grad_bias = None
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
index b6a692d..dc697e8 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import division
 
 import torch
+from ._functions import SyncBatchNorm as sync_batch_norm
 from .module import Module
 from torch.nn.parameter import Parameter
 from .. import functional as F
@@ -316,3 +317,144 @@ class BatchNorm3d(_BatchNorm):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
                              .format(input.dim()))
+
+
+class SyncBatchNorm(_BatchNorm):
+    r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
+    with additional channel dimension) as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over all
+    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
+    are learnable parameter vectors of size `C` (where `C` is the input size).
+    By default, the elements of :math:`\gamma` are sampled from
+    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+
+    Also by default, during training this layer keeps running estimates of its
+    computed mean and variance, which are then used for normalization during
+    evaluation. The running estimates are kept with a default :attr:`momentum`
+    of 0.1.
+
+    If :attr:`track_running_stats` is set to ``False``, this layer then does not
+    keep running estimates, and batch statistics are instead used during
+    evaluation time as well.
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Because the Batch Normalization is done over the `C` dimension, computing statistics
+    on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
+    or Spatio-temporal Batch Normalization.
+
+    Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
+    torch.nn.utils.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
+    Network with DDP.
+
+    Args:
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, +)`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics and always uses batch
+            statistics in both training and eval modes. Default: ``True``
+        process_group: synchronization of stats happen within each process group
+            individually. Default behavior is synchronization across the whole
+            world
+
+    Shape:
+        - Input: :math:`(N, C, +)`
+        - Output: :math:`(N, C, +)` (same shape as input)
+
+    Examples::
+
+        >>> # With Learnable Parameters
+        >>> m = nn.SyncBatchNorm(100)
+        >>> # creating process group (optional)
+        >>> # process_ids is a list of int identifying rank ids.
+        >>> process_group = torch.distributed.new_group(process_ids)
+        >>> # Without Learnable Parameters
+        >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
+        >>> input = torch.randn(20, 100, 35, 45, 10)
+        >>> output = m(input)
+
+        >>> # network is nn.BatchNorm layer
+        >>> sync_bn_network = torch.nn.utils.convert_sync_batchnorm(network, process_group)
+        >>> # only single gpu per process is currently supported
+        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
+        >>>                         sync_bn_network,
+        >>>                         device_ids=[args.local_rank],
+        >>>                         output_device=args.local_rank)
+
+    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+        https://arxiv.org/abs/1502.03167
+    """
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True, process_group=None):
+        super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
+        self.process_group = process_group
+        # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
+        # under supported condition (single GPU per process)
+        self.ddp_gpu_size = None
+
+    def _check_input_dim(self, input):
+        if input.dim() <= 2:
+            raise ValueError('expected at least 3D input (got {}D input)'
+                             .format(input.dim()))
+
+    def _specify_ddp_gpu_num(self, gpu_size):
+        if gpu_size > 1:
+            raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
+        self.ddp_gpu_size = gpu_size
+
+    def forward(self, input):
+        # currently only GPU input is supported
+        if not input.is_cuda:
+            raise ValueError('expected input tensor to be on GPU')
+
+        if not self.ddp_gpu_size:
+            raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
+
+        self._check_input_dim(input)
+
+        exponential_average_factor = 0.0
+
+        if self.training and self.track_running_stats:
+            self.num_batches_tracked += 1
+            if self.momentum is None:  # use cumulative moving average
+                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+            else:  # use exponential moving average
+                exponential_average_factor = self.momentum
+
+        world_size = 1
+        process_group = torch.distributed.group.WORLD
+        if self.process_group:
+            process_group = self.process_group
+        world_size = torch.distributed.get_world_size(process_group)
+
+        # fallback to framework BN when synchronization is not necessary
+        if world_size == 1 or (not self.training and self.track_running_stats):
+            return F.batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training or not self.track_running_stats,
+                exponential_average_factor, self.eps)
+        else:
+            return sync_batch_norm.apply(
+                input, self.weight, self.bias, self.running_mean, self.running_var,
+                self.eps, exponential_average_factor, process_group, world_size)
index 730a735..a33f321 100644 (file)
@@ -230,6 +230,7 @@ class DistributedDataParallel(Module):
         (2) bucketing the parameters for reductions
         (3) resetting the bucketing states
         (4) registering the grad hooks
+        (5) passing a handle of DDP to SyncBatchNorm Layer
         """
         if len(self.device_ids) > 1:
             # TODO: we don't need to replicate params in here. they're always going to
@@ -307,6 +308,9 @@ class DistributedDataParallel(Module):
         self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
         self._register_grad_hooks()
 
+        # passing a handle to torch.nn.SyncBatchNorm layer
+        self._passing_sync_batchnorm_handle(self._module_copies)
+
     def __getstate__(self):
         self._check_default_group()
         attrs = copy.copy(self.__dict__)
@@ -407,6 +411,12 @@ class DistributedDataParallel(Module):
                         for tensor, buffer_data in zip(tensors, module_buffers_data):
                             buffer_data.set_(tensor)
 
+    def _passing_sync_batchnorm_handle(self, module_copies):
+        for dev_idx, module in enumerate(module_copies):
+            for layer in module.modules():
+                if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+                    layer._specify_ddp_gpu_num(len(self.device_ids))
+
     def _register_grad_hooks(self):
         self._grad_accs = []  # need to keep them in scope
 
index 1af5034..472aaa8 100644 (file)
@@ -3,3 +3,4 @@ from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
 from .weight_norm import weight_norm, remove_weight_norm
 from .convert_parameters import parameters_to_vector, vector_to_parameters
 from .spectral_norm import spectral_norm, remove_spectral_norm
+from .sync_batch_norm import convert_sync_batchnorm
diff --git a/torch/nn/utils/sync_batch_norm.py b/torch/nn/utils/sync_batch_norm.py
new file mode 100644 (file)
index 0000000..ca034da
--- /dev/null
@@ -0,0 +1,45 @@
+import torch
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    r"""Helper function to convert `torch.nn.BatchNormND` layer in the model to
+    `torch.nn.SyncBatchNorm` layer.
+
+    Args:
+        module (nn.Module): containing module
+        process_group (optional): process group to scope synchronization,
+    default is the whole world
+
+    Returns:
+        The original module with the converted `torch.nn.SyncBatchNorm` layer
+
+    Example::
+
+        >>> # Network with nn.BatchNorm layer
+        >>> module = torch.nn.Sequential(
+        >>>            torch.nn.Linear(20, 100),
+        >>>            torch.nn.BatchNorm1d(100)
+        >>>          ).cuda()
+        >>> # creating process group (optional)
+        >>> # process_ids is a list of int identifying rank ids.
+        >>> process_group = torch.distributed.new_group(process_ids)
+        >>> sync_bn_module = convert_sync_batchnorm(module, process_group)
+
+    """
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        module_output = torch.nn.SyncBatchNorm(module.num_features,
+                                               module.eps, module.momentum,
+                                               module.affine,
+                                               module.track_running_stats,
+                                               process_group)
+        if module.affine:
+            module_output.weight.data = module.weight.data.clone().detach()
+            module_output.bias.data = module.bias.data.clone().detach()
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child))
+    del module
+    return module_output