});
}
+std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
+ return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_stats", [&] {
+ if (cuda::detail::canUse32BitIndexMath(self)) {
+ return batch_norm_stats_cuda_template<scalar_t, int32_t>(self, epsilon);
+ } else {
+ return batch_norm_stats_cuda_template<scalar_t, int64_t>(self, epsilon);
+ }
+ });
+}
+
+Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
+ const Tensor& mean, const Tensor& invstd, double epsilon) {
+ return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_elemt", [&] {
+ if (cuda::detail::canUse32BitIndexMath(self)) {
+ return batch_norm_elemt_cuda_template<scalar_t, int32_t>(self, weight, bias, mean, invstd, epsilon);
+ } else {
+ return batch_norm_elemt_cuda_template<scalar_t, int64_t>(self, weight, bias, mean, invstd, epsilon);
+ }
+ });
+}
+
+// accepting the input (self) here to determine the template data types, since running_mean/running_var are optional
+std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean,
+ const Tensor& running_var, double momentum, double epsilon, int64_t count) {
+ return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_gather_stats", [&] {
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ if (cuda::detail::canUse32BitIndexMath(self)) {
+ return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int32_t>(mean, invstd, running_mean, running_var, momentum, epsilon, static_cast<int32_t>(count));
+ } else {
+ return batch_norm_gather_stats_cuda_template<scalar_t, accscalar_t, int64_t>(mean, invstd, running_mean, running_var, momentum, epsilon, count);
+ }
+ });
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean,
+ const Tensor& invstd, bool input_g, bool weight_g, bool bias_g) {
+ return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_reduce", [&] {
+ if (cuda::detail::canUse32BitIndexMath(self)) {
+ return batch_norm_backward_reduce_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
+ } else {
+ return batch_norm_backward_reduce_cuda_template<scalar_t, int64_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
+ }
+ });
+}
+
+Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd,
+ const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) {
+ return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_elemt", [&] {
+ if (cuda::detail::canUse32BitIndexMath(self)) {
+ return batch_norm_backward_elemt_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
+ } else {
+ return batch_norm_backward_elemt_cuda_template<scalar_t, int64_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
+ }
+ });
+}
+
std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
return MAX_BLOCK_SIZE;
}
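+// returns the largest power of two that is less than or equal to n (assumes n > 0)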
+static int lastPow2(unsigned int n) {
+ n |= (n >> 1);
+ n |= (n >> 2);
+ n |= (n >> 4);
+ n |= (n >> 8);
+ n |= (n >> 16);
+ return n - (n >> 1);
+}
+
// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
return 31 - __clz(val);
if (tid < WARP_SIZE) {
n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0);
- avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : 0);
- var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid + 1] : 0);
+ avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
+ var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
}
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE);
// Save the mean, variance, and moving averages
if (tid == 0) {
- save_mean[plane] = avg;
- save_transformed_var[plane] = VarTransform<stat_accscalar_t>{}(var_n / N, epsilon);
+ if (save_mean.data() != NULL) {
+ save_mean[plane] = avg;
+ }
+ if (save_transformed_var.data() != NULL) {
+ save_transformed_var[plane] = VarTransform<stat_accscalar_t>{}(var_n / N, epsilon);
+ }
if (running_mean.data() != NULL) {
running_mean[plane] = static_cast<stat_scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
}
}
}
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_reduce_statistics_kernel(
+ const PackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
+ const PackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
+ PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
+ PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
+ PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
+ PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
+ const accscalar_t epsilon,
+ const accscalar_t momentum,
+ const index_t count) {
+
+ int feature_size = vec_mean.size(1);
+ int world_size = vec_mean.size(0);
+
+ int bid = blockIdx.x;
+ int tid = threadIdx.x;
+
+ // first the reductions each thread does separately
+ for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
+ accscalar_t avg = 0;
+ accscalar_t var_n = 0;
+ index_t n = 0;
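+ // combine this replica's stats into the running (avg, var_n, n) with the standard
+ // pairwise mean/variance merge: recover each replica's sum of squared deviations
+ // from its invstd, then
+ //   var_n += v + (avg - m)^2 * n * count / (n + count)
+ //   avg    = (n * avg + count * m) / (n + count)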
+ for (int j = 0; j < world_size; j++) {
+ accscalar_t m = vec_mean[j][i];
+ accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
+ v = (v * v - epsilon) * count;
+ accscalar_t factor = 1.0 / (n + count);
+ var_n += v + (avg - m) * (avg - m) * n * count * factor;
+ avg = n * factor * avg + count * factor * m;
+ n += count;
+ }
+ mean[i] = avg;
+ invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
+ if (running_mean.data() != NULL) {
+ running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
+ }
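+ // the running variance uses the unbiased (Bessel-corrected) estimate over the combined count n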
+ accscalar_t unbiasedVar = var_n / (n - 1);
+ if (running_var.data() != NULL) {
+ running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
+ }
+ }
+
+}
+
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_backward_reduce_kernel(
+ const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> input,
+ const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
+ PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean,
+ PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
+ PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy,
+ PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy_xmu,
+ PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
+ PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {
+
+ index_t plane = blockIdx.x;
+ index_t N = input.size(0) * input.size(2);
+
+ accscalar_t r_mean = mean[plane];
+ accscalar_t factor = invstd[plane];
+
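+ // reduce over the plane: res.v1 accumulates sum(grad_output), res.v2 accumulates sum(grad_output * (input - mean))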
+ GradOp<scalar_t, accscalar_t, PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
+ Float2<scalar_t, accscalar_t> res = reduce<Float2<scalar_t, accscalar_t>, GradOp<scalar_t, accscalar_t,
+ PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t>>>(g, grad_output, plane);
+
+ accscalar_t norm = accscalar_t(1) / N;
+ if (threadIdx.x == 0) {
+ if (grad_weight.size(0) > 0) {
+ grad_weight[plane] = static_cast<scalar_t>(res.v2 * factor);
+ }
+ if (grad_bias.size(0) > 0) {
+ grad_bias[plane] = static_cast<scalar_t>(res.v1);
+ }
+ if (mean_dy.size(0) > 0) {
+ mean_dy[plane] = static_cast<accscalar_t>(res.v1 * norm);
+ }
+ if (mean_dy_xmu.size(0) > 0) {
+ mean_dy_xmu[plane] = static_cast<accscalar_t>(res.v2 * norm);
+ }
+ }
+}
+
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_backward_elemt_kernel(
+ const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> input,
+ const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
+ const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean,
+ const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
+ const PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> weight,
+ const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy,
+ const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> mean_dy_xmu,
+ PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_input) {
+
+ index_t plane = blockIdx.x;
+
+ if (plane >= input.size(1)) {
+ return;
+ }
+
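+ // per-plane factors below implement
+ //   grad_input = (grad_output - mean_dy - (input - mean) * invstd^2 * mean_dy_xmu) * invstd * weight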
+ accscalar_t m_c = mean[plane];
+ accscalar_t m_dy_c = mean_dy[plane];
+ accscalar_t factor_1_c = invstd[plane];
+ accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : static_cast<accscalar_t>(1);
+ factor_2_c *= factor_1_c;
+ factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[plane];
+
+ index_t bs = input.size(0);
+ index_t fs = input.size(2);
+
+ index_t bstep = blockDim.y * gridDim.y;
+ for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
+ auto g_i = grad_input[batch][plane];
+ auto g_o = grad_output[batch][plane];
+ auto i = input[batch][plane];
+ for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
+ g_i[feature] = static_cast<scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
+ }
+ }
+}
+
template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(const Tensor& t) {
if (! t.defined()) {
// The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
// weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
- // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
+ // and good occupancy. Quite likely, we could go with even more blocks than 1024.
// The various planes are independent, so we use blocks for them.
int tf = std::max<int>(getNumThreads(input.size(2)/4),
std::min<int>(getNumThreads(input.size(2)), 64));
return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_stats_cuda_template(const Tensor& input_, double epsilon) {
+
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ int64_t n_input = input_.size(1);
+ Tensor dummy_mean_;
+ Tensor dummy_var_;
+ Tensor mean_;
+ Tensor invstd_;
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+
+ auto bs = input_reshaped.size(0);
+ auto features = input_reshaped.size(2);
+ auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+ auto input_options = input_.options();
+ dummy_mean_ = at::empty({0}, input_options);
+ dummy_var_ = at::empty({0}, input_options);
+ // promote only mean_/invstd_ precision
+ if (input_.type().scalarType() == at::ScalarType::Half) {
+ input_options = input_options.dtype(ScalarType::Float);
+ }
+ mean_ = at::empty({n_input}, input_options);
+ invstd_ = at::empty({n_input}, input_options);
+ auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(mean_);
+ auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
+ auto dummy_mean = dummy_mean_.packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
+ auto dummy_invstd = dummy_var_.packed_accessor<scalar_t, 1, RestrictPtrTraits, index_t>();
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ dim3 blocks(input.size(1));
+ int tf = getNumThreads(input.size(2));
+ dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+ batch_norm_collect_statistics_kernel<InvStd, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+ (input, epsilon, 0.0, dummy_mean, dummy_invstd, mean, invstd);
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(mean_, invstd_);
+}
+
+template<typename scalar_t, typename index_t>
+Tensor batch_norm_elemt_cuda_template(const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
+ const Tensor& mean_, const Tensor& invstd_,
+ double epsilon) {
+
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ int64_t n_input = input_.size(1);
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+ auto output_reshaped = at::empty_like(input_reshaped);
+
+ auto bs = input_reshaped.size(0);
+ auto features = input_reshaped.size(2);
+ auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+ auto input_options = input_.options();
+ if (input_.type().scalarType() == at::ScalarType::Half) {
+ input_options = input_options.dtype(ScalarType::Float);
+ }
+ auto output = output_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+ auto weight = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
+ auto bias = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
+ auto mean = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(mean_);
+ auto invstd = packed_accessor_or_dummy<accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_);
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+ // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+ // and good occupancy. Quite likely, we could go with even more blocks than 1024.
+ // The various planes are independent, so we use blocks for them.
+ int tf = std::max<int>(getNumThreads(input.size(2)/4),
+ std::min<int>(getNumThreads(input.size(2)), 64));
+ int tb = std::max<int>(64/tf, 1);
+ dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
+ (input.size(0)+tb-1)/tb)));
+ dim3 threads_trans(tf, tb);
+ batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+ (input, output, mean, invstd, weight, bias, epsilon);
+ THCudaCheck(cudaGetLastError());
+ return output_reshaped.view(input_.sizes());
+}
+
+template<typename scalar_t, typename accscalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
+ const Tensor& running_mean_, const Tensor& running_var_,
+ double momentum, double epsilon, index_t count) {
+
+ Tensor save_mean_;
+ Tensor save_invstd_;
+
+ auto features = mean_.size(1);
+ auto input_options = mean_.options();
+ if (mean_.type().scalarType() == at::ScalarType::Half) {
+ input_options = input_options.dtype(ScalarType::Float);
+ }
+ save_mean_ = at::empty({features}, input_options);
+ save_invstd_ = at::empty({features}, input_options);
+
+ auto mean = packed_accessor_or_dummy<accscalar_t, 2, RestrictPtrTraits, index_t>(mean_);
+ auto invstd = packed_accessor_or_dummy<accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_);
+ auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+ auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+ auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+ auto save_invstd = save_invstd_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ int block = getNumThreads(features);
+ int grid = std::max<int>(1, features/block);
+ batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
+ (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, count);
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(save_mean_, save_invstd_);
+}
+
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+ const Tensor& mean_, const Tensor& invstd_,
+ const bool input_g, const bool weight_g, const bool bias_g) {
+
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ int64_t n_input = input_.size(1);
+ Tensor mean_dy_;
+ Tensor mean_dy_xmu_;
+ Tensor grad_weight_;
+ Tensor grad_bias_;
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+ auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+
+ if (input_g) {
+ mean_dy_ = at::empty_like(mean_);
+ mean_dy_xmu_ = at::empty_like(mean_);
+ }
+ auto grad_options = grad_out_.options();
+ if (weight_g) {
+ grad_weight_ = at::empty({n_input}, grad_options);
+ }
+ if (bias_g) {
+ grad_bias_ = at::empty({n_input}, grad_options);
+ }
+
+ auto input = input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+ auto grad_output = grad_output_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+ auto grad_weight = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_);
+ auto grad_bias = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_);
+ auto mean = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_);
+ auto invstd = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_);
+ auto mean_dy = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_);
+ auto mean_dy_xmu = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_xmu_);
+
+ auto batch_size = input_reshaped.size(0);
+ auto feature_size = input_reshaped.size(2);
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/32);
+ int block_x = std::min<int>(getNumThreads(feature_size), MAX_BLOCK_SIZE/block_y);
+ const dim3 block(block_x, block_y);
+ const dim3 grid(n_input);
+
+ batch_norm_backward_reduce_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
+ (input, grad_output, mean, invstd, mean_dy, mean_dy_xmu, grad_weight, grad_bias);
+ THCudaCheck(cudaGetLastError());
+
+ return std::make_tuple(mean_dy_, mean_dy_xmu_, grad_weight_, grad_bias_);
+}
+
+template<typename scalar_t, typename index_t>
+Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+ const Tensor& mean_, const Tensor& invstd_,
+ const Tensor& weight_, const Tensor& mean_dy_, const Tensor& mean_dy_xmu_) {
+
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ int64_t n_input = input_.size(1);
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+ auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+ auto grad_input_reshaped = at::empty_like(input_reshaped);
+
+ auto bs = input_reshaped.size(0);
+ auto features = input_reshaped.size(2);
+
+ auto input = input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+ auto grad_input = grad_input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+ auto grad_output = grad_output_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+ auto mean = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_);
+ auto invstd = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_);
+ auto weight = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(weight_);
+ auto mean_dy = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_);
+ auto mean_dy_xmu = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(mean_dy_xmu_);
+
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+ // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+ // and good occupancy. Quite likely, we could go with even more blocks than 1024.
+ // The various planes are independent, so we use blocks for them.
+ int tf = std::max<int>(getNumThreads(input.size(2)/4),
+ std::min<int>(getNumThreads(input.size(2)), 64));
+ int tb = std::max<int>(64/tf, 1);
+ dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
+ (input.size(0)+tb-1)/tb)));
+ dim3 threads_trans(tf, tb);
+ batch_norm_backward_elemt_kernel<scalar_t, accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+ (input, grad_output, mean, invstd, weight, mean_dy, mean_dy_xmu, grad_input);
+ THCudaCheck(cudaGetLastError());
+ return grad_input_reshaped.view(input_.sizes());
+}
+
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
(input, 0., momentum, running_mean, running_var, save_mean, save_var);
THCudaCheck(cudaGetLastError());
return std::make_tuple(save_mean_, save_var_);
-
}
} } // namespace at::native
CPU: batch_norm_cpu
CUDA: batch_norm_cuda
+- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
+ dispatch:
+ CUDA: batch_norm_stats_cuda
+
+- func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
+ dispatch:
+ CUDA: batch_norm_elemt_cuda
+
+- func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)
+ dispatch:
+ CUDA: batch_norm_gather_stats_cuda
+
- func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
matches_jit_signature: True
dispatch:
CPU: batch_norm_backward_cpu
CUDA: batch_norm_backward_cuda
+- func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)
+ dispatch:
+ CUDA: batch_norm_backward_reduce_cuda
+
+- func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor mean_dy, Tensor mean_dy_xmu) -> Tensor
+ dispatch:
+ CUDA: batch_norm_backward_elemt_cuda
+
- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)
matches_jit_signature: True
dispatch:
return F.softmax(x, dim=1)
+class BatchNormNet(nn.Module):
+
+ def __init__(self):
+ super(BatchNormNet, self).__init__()
+ self.fc1 = nn.Linear(2, 40, bias=False)
+ self.bn = nn.BatchNorm1d(4)
+ self.fc2 = nn.Linear(40, 4, bias=False)
+
+ def forward(self, x):
+ x = torch.reshape(self.fc1(x), (-1, 4, 10))
+ x = self.bn(x)
+ x = torch.reshape(x, (-1, 40))
+ x = self.fc2(x)
+ return F.softmax(x, dim=1)
+
+
DDP_NET = Net()
+BN_NET = BatchNormNet()
def get_timeout(test_id):
gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+ def _test_DistributedDataParallel_SyncBatchNorm(self, gpu_subset, rank, output_device=None):
+ # Run a simple end-to-end DDP model, using the result of the single-GPU
+ # model as the baseline
+
+ # cpu training setup
+ model = BN_NET
+
+ # single gpu training setup
+ model_gpu = copy.deepcopy(model)
+ model_gpu.cuda(gpu_subset[0])
+
+ # DDP training setup
+ model_DDP = nn.utils.convert_sync_batchnorm(copy.deepcopy(model))
+ model_DDP.cuda(gpu_subset[0])
+ model_DDP = nn.parallel.DistributedDataParallel(
+ model_DDP, device_ids=gpu_subset
+ )
+
+ # test serializable/unserializable
+ if INIT_METHOD.startswith("file://"):
+ _, filename = tempfile.mkstemp(prefix=FOLDER)
+ torch.save(model_DDP, filename)
+ model_DDP = torch.load(filename)
+
+ # dummy data initialization
+ local_bs = len(gpu_subset)
+ global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+ # check two model parameters over 5 iterations
+ self._test_DDP_5iter(
+ model_gpu,
+ model_DDP,
+ input_cpu.cuda(gpu_subset[0]),
+ target.cuda(gpu_subset[0]),
+ loss,
+ local_bs,
+ rank,
+ global_bs,
+ True
+ )
+ self._barrier()
+
+ @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
+ "Only Nccl & Gloo backend support DistributedDataParallel")
+ @skip_if_no_cuda_distributed
+ @skip_if_no_gpu
+ def test_DistributedDataParallel_SyncBatchNorm(self):
+ group, group_id, rank = self._init_global_test()
+ rank_to_GPU = self._init_multigpu_helper()
+ gpus = list(rank_to_GPU[rank])
+ self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank)
+
+ # test output_device
+ self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
+ # test device_ids
+ gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
+ self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
if BACKEND == "gloo" or BACKEND == "nccl":
WORLD_SIZE = os.environ["WORLD_SIZE"]
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
-from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
+from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
- 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm',
+ 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
--- /dev/null
+import torch
+from torch.autograd.function import Function
+
+
+class SyncBatchNorm(Function):
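+ # forward: all_gather the per-replica (mean, invstd) and combine them into global
+ # statistics before the element-wise normalization; backward: all_reduce the reduced
+ # gradient stats (mean_dy, mean_dy_xmu) before the element-wise backward kernel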
+
+ @staticmethod
+ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+ input = input.contiguous()
+
+ # calculate mean/invstd for input.
+ mean, invstd = torch.batch_norm_stats(input, eps)
+
+ mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
+ invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
+ mean_l = list(mean_all.unbind(0))
+ invstd_l = list(invstd_all.unbind(0))
+ # using all_gather instead of all reduce so we can calculate mean/var in one go
+ mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
+ invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)
+
+ # wait on the async communication to finish
+ mean_all_reduce.wait()
+ invstd_all_reduce.wait()
+
+ # calculate global mean & invstd
+ mean, invstd = torch.batch_norm_gather_stats(
+ input,
+ mean_all,
+ invstd_all,
+ running_mean,
+ running_var,
+ momentum,
+ eps,
+ int(input.numel() / input.size(1))
+ )
+
+ self.save_for_backward(input, weight, mean, invstd)
+ self.process_group = process_group
+ self.world_size = world_size
+
+ # apply element-wise normalization
+ out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+ return out
+
+ @staticmethod
+ def backward(self, grad_output):
+ grad_output = grad_output.contiguous()
+ saved_input, weight, mean, invstd = self.saved_tensors
+ grad_input = grad_weight = grad_bias = None
+ process_group = self.process_group
+ world_size = self.world_size
+
+ # calculate local stats as well as grad_weight / grad_bias
+ mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
+ grad_output,
+ saved_input,
+ mean,
+ invstd,
+ self.needs_input_grad[0],
+ self.needs_input_grad[1],
+ self.needs_input_grad[2]
+ )
+
+ if self.needs_input_grad[0]:
+ # synchronizing stats used to calculate input gradient.
+ # TODO: move div_ into batch_norm_backward_elemt kernel
+ mean_dy_all_reduce = torch.distributed.all_reduce(
+ mean_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
+ mean_dy_xmu_all_reduce = torch.distributed.all_reduce(
+ mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
+
+ # wait on the async communication to finish
+ mean_dy_all_reduce.wait()
+ mean_dy_xmu_all_reduce.wait()
+
+ mean_dy.div_(world_size)
+ mean_dy_xmu.div_(world_size)
+ # backward pass for gradient calculation
+ grad_input = torch.batch_norm_backward_elemt(
+ grad_output,
+ saved_input,
+ mean,
+ invstd,
+ weight,
+ mean_dy,
+ mean_dy_xmu
+ )
+
+ # synchronization of grad_weight / grad_bias is not needed as distributed
+ # training handles the all-reduce.
+ if weight is None or not self.needs_input_grad[1]:
+ grad_weight = None
+
+ if weight is None or not self.needs_input_grad[2]:
+ grad_bias = None
+
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
from __future__ import division
import torch
+from ._functions import SyncBatchNorm as sync_batch_norm
from .module import Module
from torch.nn.parameter import Parameter
from .. import functional as F
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
+
+
+class SyncBatchNorm(_BatchNorm):
+ r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
+ with additional channel dimension) as described in the paper
+ `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension over all
+ mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
+ are learnable parameter vectors of size `C` (where `C` is the input size).
+ By default, the elements of :math:`\gamma` are sampled from
+ :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+
+ Also by default, during training this layer keeps running estimates of its
+ computed mean and variance, which are then used for normalization during
+ evaluation. The running estimates are kept with a default :attr:`momentum`
+ of 0.1.
+
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
+ keep running estimates, and batch statistics are instead used during
+ evaluation time as well.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
+ on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
+ or Spatio-temporal Batch Normalization.
+
+ Currently SyncBatchNorm only supports DistributedDataParallel with a single GPU per process. Use
+ torch.nn.utils.convert_sync_batchnorm() to convert BatchNorm layers to SyncBatchNorm before wrapping
+ the network with DDP.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, +)`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters. Default: ``True``
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``True``
+ process_group: synchronization of stats happens within each process group
+ individually. Default behavior is synchronization across the whole
+ world
+
+ Shape:
+ - Input: :math:`(N, C, +)`
+ - Output: :math:`(N, C, +)` (same shape as input)
+
+ Examples::
+
+ >>> # With Learnable Parameters
+ >>> m = nn.SyncBatchNorm(100)
+ >>> # creating process group (optional)
+ >>> # process_ids is a list of int identifying rank ids.
+ >>> process_group = torch.distributed.new_group(process_ids)
+ >>> # Without Learnable Parameters
+ >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
+ >>> input = torch.randn(20, 100, 35, 45, 10)
+ >>> output = m(input)
+
+ >>> # network is nn.BatchNorm layer
+ >>> sync_bn_network = torch.nn.utils.convert_sync_batchnorm(network, process_group)
+ >>> # only single gpu per process is currently supported
+ >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
+ >>> sync_bn_network,
+ >>> device_ids=[args.local_rank],
+ >>> output_device=args.local_rank)
+
+ .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+ https://arxiv.org/abs/1502.03167
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True, process_group=None):
+ super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
+ self.process_group = process_group
+ # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
+ # under supported condition (single GPU per process)
+ self.ddp_gpu_size = None
+
+ def _check_input_dim(self, input):
+ if input.dim() <= 2:
+ raise ValueError('expected at least 3D input (got {}D input)'
+ .format(input.dim()))
+
+ def _specify_ddp_gpu_num(self, gpu_size):
+ if gpu_size > 1:
+ raise ValueError('SyncBatchNorm is only supported for DDP with a single GPU per process')
+ self.ddp_gpu_size = gpu_size
+
+ def forward(self, input):
+ # currently only GPU input is supported
+ if not input.is_cuda:
+ raise ValueError('expected input tensor to be on GPU')
+
+ if not self.ddp_gpu_size:
+ raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
+
+ self._check_input_dim(input)
+
+ exponential_average_factor = 0.0
+
+ if self.training and self.track_running_stats:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ world_size = 1
+ process_group = torch.distributed.group.WORLD
+ if self.process_group:
+ process_group = self.process_group
+ world_size = torch.distributed.get_world_size(process_group)
+
+ # fallback to framework BN when synchronization is not necessary
+ if world_size == 1 or (not self.training and self.track_running_stats):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training or not self.track_running_stats,
+ exponential_average_factor, self.eps)
+ else:
+ return sync_batch_norm.apply(
+ input, self.weight, self.bias, self.running_mean, self.running_var,
+ self.eps, exponential_average_factor, process_group, world_size)
(2) bucketing the parameters for reductions
(3) resetting the bucketing states
(4) registering the grad hooks
+ (5) passing a handle of DDP to the SyncBatchNorm layers
"""
if len(self.device_ids) > 1:
# TODO: we don't need to replicate params in here. they're always going to
self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
self._register_grad_hooks()
+ # passing a handle to torch.nn.SyncBatchNorm layer
+ self._passing_sync_batchnorm_handle(self._module_copies)
+
def __getstate__(self):
self._check_default_group()
attrs = copy.copy(self.__dict__)
for tensor, buffer_data in zip(tensors, module_buffers_data):
buffer_data.set_(tensor)
+ def _passing_sync_batchnorm_handle(self, module_copies):
+ for dev_idx, module in enumerate(module_copies):
+ for layer in module.modules():
+ if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+ layer._specify_ddp_gpu_num(len(self.device_ids))
+
def _register_grad_hooks(self):
self._grad_accs = [] # need to keep them in scope
from .weight_norm import weight_norm, remove_weight_norm
from .convert_parameters import parameters_to_vector, vector_to_parameters
from .spectral_norm import spectral_norm, remove_spectral_norm
+from .sync_batch_norm import convert_sync_batchnorm
--- /dev/null
+import torch
+
+
+def convert_sync_batchnorm(module, process_group=None):
+ r"""Helper function to convert `torch.nn.BatchNormND` layer in the model to
+ `torch.nn.SyncBatchNorm` layer.
+
+ Args:
+ module (nn.Module): containing module
+ process_group (optional): process group to scope synchronization,
+ default is the whole world
+
+ Returns:
+ The original module with the converted `torch.nn.SyncBatchNorm` layers
+
+ Example::
+
+ >>> # Network with nn.BatchNorm layer
+ >>> module = torch.nn.Sequential(
+ >>> torch.nn.Linear(20, 100),
+ >>> torch.nn.BatchNorm1d(100)
+ >>> ).cuda()
+ >>> # creating process group (optional)
+ >>> # process_ids is a list of int identifying rank ids.
+ >>> process_group = torch.distributed.new_group(process_ids)
+ >>> sync_bn_module = convert_sync_batchnorm(module, process_group)
+
+ """
+ module_output = module
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+ module_output = torch.nn.SyncBatchNorm(module.num_features,
+ module.eps, module.momentum,
+ module.affine,
+ module.track_running_stats,
+ process_group)
+ if module.affine:
+ module_output.weight.data = module.weight.data.clone().detach()
+ module_output.bias.data = module.bias.data.clone().detach()
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
+ for name, child in module.named_children():
+ module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+ del module
+ return module_output