From: Xiaomeng Yang
Date: Thu, 24 Jan 2019 10:50:35 +0000 (-0800)
Subject: Optimize SpatialBN on GPU (#16202)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1705
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=45c3cc9174f707883664e35a2d95d08af3ad3251;p=platform%2Fupstream%2Fpytorch.git

Optimize SpatialBN on GPU (#16202)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16202

Optimize SpatialBN on GPU

Reviewed By: houseroad

Differential Revision: D13747581

fbshipit-source-id: 48a885a240ef2a325235e8f89ebbe50e7c780c84
---

diff --git a/caffe2/operators/hip/spatial_batch_norm_op_miopen.hip b/caffe2/operators/hip/spatial_batch_norm_op_miopen.hip
index ae905da..d4ab900 100644
--- a/caffe2/operators/hip/spatial_batch_norm_op_miopen.hip
+++ b/caffe2/operators/hip/spatial_batch_norm_op_miopen.hip
@@ -18,7 +18,6 @@
 #include "caffe2/core/hip/context_gpu.h"
 #include "caffe2/core/hip/miopen_wrapper.h"
 #include "caffe2/operators/spatial_batch_norm_op.h"
-#include "caffe2/operators/hip/spatial_batch_norm_op_gpu_impl.cuh"
 #include "caffe2/utils/math.h"
 
 const double MIOPEN_BN_MIN_EPSILON = 1e-6;
diff --git a/caffe2/operators/spatial_batch_norm_gradient_op.cc b/caffe2/operators/spatial_batch_norm_gradient_op.cc
index 49d07f6..e7b27db 100644
--- a/caffe2/operators/spatial_batch_norm_gradient_op.cc
+++ b/caffe2/operators/spatial_batch_norm_gradient_op.cc
@@ -58,7 +58,8 @@ void SpatialBNGradientOp<CPUContext>::ComputeScaleBiasGradientsAndFusedParams(
     T* dbias,
     T* alpha,
     T* beta,
-    T* gamma) {
+    T* gamma,
+    T* /* scratch */) {
   ConstEigenVectorArrayMap<T> scale_arr(scale, C);
   ConstEigenVectorArrayMap<T> mean_arr(mean, C);
   ConstEigenVectorArrayMap<T> rstd_arr(rstd, C);
diff --git a/caffe2/operators/spatial_batch_norm_op.cc b/caffe2/operators/spatial_batch_norm_op.cc
index f1b3698..2e6c9fc 100644
--- a/caffe2/operators/spatial_batch_norm_op.cc
+++ b/caffe2/operators/spatial_batch_norm_op.cc
@@ -6,6 +6,70 @@
 
 namespace caffe2 {
 
+template <>
+template <>
+void SpatialBNOp<CPUContext>::ComputeFusedParam<float>(
+    const int C,
+    const float* scale,
+    const float* bias,
+    const float* mean,
+    const float* var,
+    float* alpha,
+    float* beta) {
+  EigenVectorArrayMap<float> alpha_arr(alpha, C);
+  EigenVectorArrayMap<float> beta_arr(beta, C);
+  alpha_arr = ConstEigenVectorArrayMap<float>(scale, C) *
+      (ConstEigenVectorArrayMap<float>(var, C) + static_cast<float>(epsilon_))
+          .rsqrt();
+  beta_arr = ConstEigenVectorArrayMap<float>(bias, C) -
+      alpha_arr * ConstEigenVectorArrayMap<float>(mean, C);
+}
+
+template <>
+template <>
+void SpatialBNOp<CPUContext>::ComputeBatchMoments<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* batch_mean_sum,
+    const float* batch_var_sum,
+    float* mean,
+    float* var) {
+  const float scale = 1.0f / static_cast<float>(num_batches_ * N * HxW);
+  EigenVectorArrayMap<float> mean_arr(mean, C);
+  EigenVectorArrayMap<float> var_arr(var, C);
+  mean_arr = ConstEigenVectorArrayMap<float>(batch_mean_sum, C) * scale;
+  var_arr = ConstEigenVectorArrayMap<float>(batch_var_sum, C) * scale -
+      mean_arr.square();
+}
+
+template <>
+template <>
+void SpatialBNOp<CPUContext>::ComputeRunningMomentsAndFusedParam<float>(
+    const int C,
+    const float* scale,
+    const float* bias,
+    const float* mean,
+    const float* var,
+    float* running_mean,
+    float* running_var,
+    float* rstd,
+    float* alpha,
+    float* beta) {
+  const float a = 1.0f - momentum_;
+  const float b = momentum_;
+  math::Axpby<float, float, CPUContext>(C, a, mean, b, running_mean, &context_);
+  math::Axpby<float, float, CPUContext>(C, a, var, b, running_var, &context_);
+  math::InvStd<float, CPUContext>(
+      C, static_cast<float>(epsilon_), var, rstd, &context_);
+  EigenVectorArrayMap<float> alpha_arr(alpha, C);
+  EigenVectorArrayMap<float>
beta_arr(beta, C); + alpha_arr = ConstEigenVectorArrayMap(scale, C) * + ConstEigenVectorArrayMap(rstd, C); + beta_arr = ConstEigenVectorArrayMap(bias, C) - + alpha_arr * ConstEigenVectorArrayMap(mean, C); +} + namespace { OpSchema::Cost CostInferenceForSpatialBN( diff --git a/caffe2/operators/spatial_batch_norm_op.cu b/caffe2/operators/spatial_batch_norm_op.cu index ecfc200..d7ec4df 100644 --- a/caffe2/operators/spatial_batch_norm_op.cu +++ b/caffe2/operators/spatial_batch_norm_op.cu @@ -1,9 +1,642 @@ #include "caffe2/operators/spatial_batch_norm_op.h" -#include "caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh" +#include + +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math.h" namespace caffe2 { +namespace { + +template +using BlockReduce = cub::BlockReduce; + +template +using BlockReduce2D = cub:: + BlockReduce; + +template +__global__ void ComputeFusedParamCUDAKernel( + const int C, + const T epsilon, + const T* scale, + const T* bias, + const T* mean, + const T* var, + T* alpha, + T* beta); + +template <> +__global__ void ComputeFusedParamCUDAKernel( + const int C, + const float epsilon, + const float* scale, + const float* bias, + const float* mean, + const float* var, + float* alpha, + float* beta) { + const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + const float scale_x_rstd = + __ldg(scale + c) * rsqrtf(__ldg(var + c) + epsilon); + alpha[c] = scale_x_rstd; + beta[c] = fmaf(-scale_x_rstd, __ldg(mean + c), __ldg(bias + c)); +#else + const float scale_x_rstd = scale[c] * rsqrtf(var[c] + epsilon); + alpha[c] = scale_x_rstd; + beta[c] = fmaf(-scale_x_rstd, mean[c], bias[c]); +#endif + } +} + +template +__global__ void ComputeBatchMomentsCUDAKernel( + const int C, + const T scale, + const T* batch_mean_sum, + const T* batch_var_sum, + T* mean, + T* var); + +template <> +__global__ void ComputeBatchMomentsCUDAKernel( + const int C, + const float scale, + const float* batch_mean_sum, + const float* batch_var_sum, + float* mean, + float* var) { + const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + const float mu = __ldg(batch_mean_sum + c) * scale; + mean[c] = mu; + var[c] = fmaf(__ldg(batch_var_sum + c), scale, -mu * mu); +#else + const float mu = batch_mean_sum[c] * scale; + mean[c] = mu; + var[c] = fmaf(batch_var_sum[c], scale, -mu * mu); +#endif + } +} + +template +__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel( + const int C, + const T momentum, + const T epsilon, + const T* scale, + const T* bias, + const T* mean, + const T* var, + T* running_mean, + T* running_var, + T* rstd, + T* alpha, + T* beta); + +template <> +__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel( + const int C, + const float momentum, + const float epsilon, + const float* scale, + const float* bias, + const float* mean, + const float* var, + float* running_mean, + float* running_var, + float* rstd, + float* alpha, + float* beta) { + const float a = 1.0f - momentum; + const float b = momentum; + const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + running_mean[c] = fmaf(a, __ldg(mean + c), b * __ldg(running_mean + c)); + running_var[c] = fmaf(a, __ldg(var + c), b * __ldg(running_var + c)); + const float rstd_val = rsqrtf(__ldg(var + c) + epsilon); + const float scale_x_rstd = __ldg(scale + c) * rstd_val; + 
rstd[c] = rstd_val; + alpha[c] = scale_x_rstd; + beta[c] = fmaf(-scale_x_rstd, __ldg(mean + c), __ldg(bias + c)); +#else + running_mean[c] = fmaf(a, mean[c], b * running_mean[c]); + running_var[c] = fmaf(a, var[c], b * running_var[c]); + const float rstd_val = rsqrtf(var[c] + epsilon); + const float scale_x_rstd = scale[c] * rstd_val; + rstd[c] = rstd_val; + alpha[c] = scale_x_rstd; + beta[c] = fmaf(-scale_x_rstd, mean[c], bias[c]); +#endif + } +} + +template +__global__ void ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel( + const int C, + const T inv_num_batches, + const T inv_nhw, + const T* scale, + const T* mean, + const T* rstd, + const T* dscale_sum, + const T* dbias_sum, + T* dscale, + T* dbias, + T* alpha, + T* beta, + T* gamma); + +template <> +__global__ void +ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel( + const int C, + const float inv_num_batches, + const float inv_nhw, + const float* scale, + const float* mean, + const float* rstd, + const float* dscale_sum, + const float* dbias_sum, + float* dscale, + float* dbias, + float* alpha, + float* beta, + float* gamma) { + const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + const float dscale_val = __ldg(dscale_sum + c) * inv_num_batches; + const float dbias_val = __ldg(dbias_sum + c) * inv_num_batches; + const float scale_x_rstd = __ldg(scale + c) * __ldg(rstd + c); + const float dscale_x_rstd = dscale_val * __ldg(rstd + c); + dscale[c] = dscale_val; + dbias[c] = dbias_val; + alpha[c] = scale_x_rstd; + beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw; + gamma[c] = scale_x_rstd * fmaf(__ldg(mean + c), dscale_x_rstd, -dbias_val) * + inv_nhw; +#else + const float dscale_val = dscale_sum[c] * inv_num_batches; + const float dbias_val = dbias_sum[c] * inv_num_batches; + const float scale_x_rstd = scale[c] * rstd[c]; + const float dscale_x_rstd = dscale_val * rstd[c]; + dscale[c] = dscale_val; + dbias[c] = dbias_val; + alpha[c] = scale_x_rstd; + beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw; + gamma[c] = + scale_x_rstd * fmaf(mean[c], dscale_x_rstd, -dbias_val) * inv_nhw; +#endif + } +} + +template +__global__ void ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel( + const int N, + const int C, + const int HxW, + const T* dY, + const T* X, + const T* scale, + const T* mean, + const T* rstd, + T* dscale, + T* dbias, + T* alpha, + T* beta, + T* gamma) { + __shared__ + typename BlockReduce2D::TempStorage ds_storage; + __shared__ + typename BlockReduce2D::TempStorage db_storage; + const T inv_nhw = T(1) / static_cast(N * HxW); + const int c = blockIdx.x; + T ds_val = 0; + T db_val = 0; + for (int i = threadIdx.x; i < N; i += blockDim.x) { + for (int j = threadIdx.y; j < HxW; j += blockDim.y) { + const int index = (i * C + c) * HxW + j; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + ds_val += __ldg(dY + index) * __ldg(X + index); + db_val += __ldg(dY + index); +#else + ds_val += dY[index] * X[index]; + db_val += dY[index]; +#endif + } + } + ds_val = BlockReduce2D(ds_storage).Sum(ds_val); + db_val = BlockReduce2D(db_storage).Sum(db_val); + if (threadIdx.x == 0 && threadIdx.y == 0) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + ds_val = (ds_val - db_val * __ldg(mean + c)) * __ldg(rstd + c); + const T scale_x_rstd = __ldg(scale + c) * __ldg(rstd + c); + const T dscale_x_rstd = ds_val * __ldg(rstd + c); + dscale[c] = ds_val; + dbias[c] = db_val; + alpha[c] = scale_x_rstd; + beta[c] = 
-scale_x_rstd * dscale_x_rstd * inv_nhw; + gamma[c] = + scale_x_rstd * (__ldg(mean + c) * dscale_x_rstd - db_val) * inv_nhw; +#else + ds_val = (ds_val - db_val * mean[c]) * rstd[c]; + const T scale_x_rstd = scale[c] * rstd[c]; + const T dscale_x_rstd = ds_val * rstd[c]; + dscale[c] = ds_val; + dbias[c] = db_val; + alpha[c] = scale_x_rstd; + beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw; + gamma[c] = scale_x_rstd * (mean[c] * dscale_x_rstd - db_val) * inv_nhw; +#endif + } +} + +template +__global__ void ComputeScaleGradientAndFusedParamsNHWCCUDAKernel( + const int C, + const T inv_nhw, + const T* dYxX, + const T* dbias, + const T* scale, + const T* mean, + const T* rstd, + T* dscale, + T* alpha, + T* beta, + T* gamma); + +template <> +__global__ void ComputeScaleGradientAndFusedParamsNHWCCUDAKernel( + const int C, + const float inv_nhw, + const float* dYxX, + const float* dbias, + const float* scale, + const float* mean, + const float* rstd, + float* dscale, + float* alpha, + float* beta, + float* gamma) { + const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (c < C) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + const float ds = fmaf(-__ldg(dbias + c), __ldg(mean + c), __ldg(dYxX + c)) * + __ldg(rstd + c); + dscale[c] = ds; + const float scale_x_rstd = __ldg(scale + c) * __ldg(rstd + c); + const float dscale_x_rstd = ds * __ldg(rstd + c); + alpha[c] = scale_x_rstd; + beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw; + gamma[c] = scale_x_rstd * + fmaf(__ldg(mean + c), dscale_x_rstd, -__ldg(dbias + c)) * inv_nhw; +#else + const float ds = fmaf(-dbias[c], mean[c], dYxX[c]) * rstd[c]; + dscale[c] = ds; + const float scale_x_rstd = scale[c] * rstd[c]; + const float dscale_x_rstd = ds * rstd[c]; + alpha[c] = scale_x_rstd; + beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw; + gamma[c] = scale_x_rstd * fmaf(mean[c], dscale_x_rstd, -dbias[c]) * inv_nhw; +#endif + } +} + +template +__global__ void ComputeXGradientNCHWCUDAKernel( + const int C, + const int HxW, + const int K, + const T* dY, + const T* X, + const T* alpha, + const T* beta, + const T* gamma, + T* dX); + +template <> +__global__ void ComputeXGradientNCHWCUDAKernel( + const int C, + const int HxW, + const int K, + const float* dY, + const float* X, + const float* alpha, + const float* beta, + const float* gamma, + float* dX) { + const int nc = blockIdx.x / K; + const int block = blockIdx.x % K; + const int c = nc % C; + const int w = block * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (w < HxW) { + const int index = nc * HxW + w; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dX[index] = fmaf( + __ldg(alpha + c), + __ldg(dY + index), + fmaf(__ldg(beta + c), __ldg(X + index), __ldg(gamma + c))); +#else + dX[index] = fmaf(alpha[c], dY[index], fma(beta[c], X[index], gamma[c])); +#endif + } +} + +template +__global__ void ComputeXGradientNHWCCUDAKernel( + const int C, + const T* dY, + const T* X, + const T* alpha, + const T* beta, + const T* gamma, + T* dX); + +template <> +__global__ void ComputeXGradientNHWCCUDAKernel( + const int C, + const float* dY, + const float* X, + const float* alpha, + const float* beta, + const float* gamma, + float* dX) { + for (int c = threadIdx.x; c < C; c += blockDim.x) { + const int index = blockIdx.x * C + c; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dX[index] = fmaf( + __ldg(alpha + c), + __ldg(dY + index), + fmaf(__ldg(beta + c), __ldg(X + index), __ldg(gamma + c))); +#else + dX[index] = fmaf(alpha[c], dY[index], fma(beta[c], X[index], 
gamma[c])); +#endif + } +} + +} // namespace + +template <> +template <> +void SpatialBNOp::ComputeFusedParam( + const int C, + const float* scale, + const float* bias, + const float* mean, + const float* var, + float* alpha, + float* beta) { + const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS); + ComputeFusedParamCUDAKernel + <<>>( + C, static_cast(epsilon_), scale, bias, mean, var, alpha, beta); +} + +template <> +template <> +void SpatialBNOp::ComputeBatchMoments( + const int N, + const int C, + const int HxW, + const float* batch_mean_sum, + const float* batch_var_sum, + float* mean, + float* var) { + const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS); + const float scale = 1.0f / static_cast(num_batches_ * N * HxW); + ComputeBatchMomentsCUDAKernel + <<>>( + C, scale, batch_mean_sum, batch_var_sum, mean, var); +} + +template <> +template <> +void SpatialBNOp::ComputeRunningMomentsAndFusedParam( + const int C, + const float* scale, + const float* bias, + const float* mean, + const float* var, + float* running_mean, + float* running_var, + float* rstd, + float* alpha, + float* beta) { + const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS); + ComputeRunningMomentsAndFusedParamCUDAKernel + <<>>( + C, + static_cast(momentum_), + static_cast(epsilon_), + scale, + bias, + mean, + var, + running_mean, + running_var, + rstd, + alpha, + beta); +} + +template <> +template <> +void SpatialBNGradientOp:: + ComputeMultiBatchScaleBiasGradientsAndFusedParams( + const int N, + const int C, + const int HxW, + const float* scale, + const float* mean, + const float* rstd, + const float* dscale_sum, + const float* dbias_sum, + float* dscale, + float* dbias, + float* alpha, + float* beta, + float* gamma) { + const float inv_num_batches = 1.0f / static_cast(num_batches_); + const float inv_nhw = 1.0f / static_cast(N * HxW); + const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS); + ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel + <<>>( + C, + inv_num_batches, + inv_nhw, + scale, + mean, + rstd, + dscale_sum, + dbias_sum, + dscale, + dbias, + alpha, + beta, + gamma); +} + +template <> +template <> +void SpatialBNGradientOp::ComputeScaleBiasGradientsAndFusedParams< + float>( + const int N, + const int C, + const int HxW, + const float* dY, + const float* X, + const float* scale, + const float* mean, + const float* rstd, + float* dscale, + float* dbias, + float* alpha, + float* beta, + float* gamma, + float* scratch) { + if (order_ == StorageOrder::NCHW) { + if (HxW >= 128) { + ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel + <<>>( + N, + C, + HxW, + dY, + X, + scale, + mean, + rstd, + dscale, + dbias, + alpha, + beta, + gamma); + } else if (HxW >= 64) { + ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel + <<>>( + N, + C, + HxW, + dY, + X, + scale, + mean, + rstd, + dscale, + dbias, + alpha, + beta, + gamma); + } else if (HxW >= 32) { + ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel + <<>>( + N, + C, + HxW, + dY, + X, + scale, + mean, + rstd, + dscale, + dbias, + alpha, + beta, + gamma); + } else { + ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel + <<>>( + N, + C, + HxW, + dY, + X, + scale, + mean, + rstd, + dscale, + dbias, + alpha, + beta, + gamma); + } + } else { + ReinitializeTensor(&ones_, {N * HxW}, at::dtype().device(CUDA)); + math::Set( + N * HxW, 1.0f, ones_.mutable_data(), &context_); + const float* ones_data = ones_.data(); + math::Mul(N * C * HxW, dY, X, scratch, &context_); + math::Gemm( + CblasTrans, + CblasNoTrans, + C, + 1, + N * HxW, + 1.0f, + scratch, 
+ ones_data, + 0.0f, + dscale, + &context_); + math::Gemm( + CblasTrans, + CblasNoTrans, + C, + 1, + N * HxW, + 1.0f, + dY, + ones_data, + 0.0f, + dbias, + &context_); + const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS); + ComputeScaleGradientAndFusedParamsNHWCCUDAKernel + <<>>( + C, + 1.0f / static_cast(N * HxW), + dscale, + dbias, + scale, + mean, + rstd, + dscale, + alpha, + beta, + gamma); + } +} + +template <> +template <> +void SpatialBNGradientOp::ComputeXGradient( + const int N, + const int C, + const int HxW, + const float* dY, + const float* X, + const float* alpha, + const float* beta, + const float* gamma, + float* dX) { + if (order_ == StorageOrder::NCHW) { + const int K = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS); + ComputeXGradientNCHWCUDAKernel + <<>>( + C, HxW, K, dY, X, alpha, beta, gamma, dX); + } else { + ComputeXGradientNHWCCUDAKernel + <<>>( + C, dY, X, alpha, beta, gamma, dX); + } +} + REGISTER_CUDA_OPERATOR(SpatialBN, SpatialBNOp); REGISTER_CUDA_OPERATOR(SpatialBNGradient, SpatialBNGradientOp); diff --git a/caffe2/operators/spatial_batch_norm_op.h b/caffe2/operators/spatial_batch_norm_op.h index d56ff1a..5aa4989 100644 --- a/caffe2/operators/spatial_batch_norm_op.h +++ b/caffe2/operators/spatial_batch_norm_op.h @@ -204,15 +204,7 @@ class SpatialBNOp : public Operator { const T* mean, const T* var, T* alpha, - T* beta) { - EigenVectorArrayMap alpha_arr(alpha, C); - EigenVectorArrayMap beta_arr(beta, C); - alpha_arr = ConstEigenVectorArrayMap(scale, C) * - (ConstEigenVectorArrayMap(var, C) + static_cast(epsilon_)) - .rsqrt(); - beta_arr = ConstEigenVectorArrayMap(bias, C) - - alpha_arr * ConstEigenVectorArrayMap(mean, C); - } + T* beta); template void ComputeBatchMoments( @@ -222,14 +214,7 @@ class SpatialBNOp : public Operator { const T* batch_mean_sum, const T* batch_var_sum, T* mean, - T* var) { - const T scale = T(1) / static_cast(num_batches_ * N * HxW); - EigenVectorArrayMap mean_arr(mean, C); - EigenVectorArrayMap var_arr(var, C); - mean_arr = ConstEigenVectorArrayMap(batch_mean_sum, C) * scale; - var_arr = ConstEigenVectorArrayMap(batch_var_sum, C) * scale - - mean_arr.square(); - } + T* var); template void ComputeRunningMomentsAndFusedParam( @@ -242,19 +227,7 @@ class SpatialBNOp : public Operator { T* running_var, T* rstd, T* alpha, - T* beta) { - const T a = T(1) - static_cast(momentum_); - const T b = static_cast(momentum_); - math::Axpby(C, a, mean, b, running_mean, &context_); - math::Axpby(C, a, var, b, running_var, &context_); - math::InvStd(C, static_cast(epsilon_), var, rstd, &context_); - EigenVectorArrayMap alpha_arr(alpha, C); - EigenVectorArrayMap beta_arr(beta, C); - alpha_arr = ConstEigenVectorArrayMap(scale, C) * - ConstEigenVectorArrayMap(rstd, C); - beta_arr = ConstEigenVectorArrayMap(bias, C) - - alpha_arr * ConstEigenVectorArrayMap(mean, C); - } + T* beta); const bool is_test_; double epsilon_; @@ -392,7 +365,8 @@ class SpatialBNGradientOp : public Operator { dbias_data, alpha_data, beta_data, - gamma_data); + gamma_data, + dX_data); } ComputeXGradient( N, C, HxW, dY_data, X_data, alpha_data, beta_data, gamma_data, dX_data); @@ -431,7 +405,8 @@ class SpatialBNGradientOp : public Operator { T* dbias, T* alpha, T* beta, - T* gamma); + T* gamma, + T* scratch); template void ComputeXGradient( @@ -452,6 +427,7 @@ class SpatialBNGradientOp : public Operator { Tensor alpha_; Tensor beta_; Tensor gamma_; + Tensor ones_; INPUT_TAGS( INPUT, diff --git a/caffe2/operators/spatial_batch_norm_op_cudnn.cu 
b/caffe2/operators/spatial_batch_norm_op_cudnn.cc similarity index 99% rename from caffe2/operators/spatial_batch_norm_op_cudnn.cu rename to caffe2/operators/spatial_batch_norm_op_cudnn.cc index 16d112b..bc61ca5 100644 --- a/caffe2/operators/spatial_batch_norm_op_cudnn.cu +++ b/caffe2/operators/spatial_batch_norm_op_cudnn.cc @@ -7,7 +7,6 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/core/cudnn_wrappers.h" -#include "caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh" #include "caffe2/utils/math.h" #if CUDNN_VERSION_MIN(5, 0, 0) diff --git a/caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh b/caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh deleted file mode 100644 index 7d87e81..0000000 --- a/caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh +++ /dev/null @@ -1,439 +0,0 @@ -#include "caffe2/operators/spatial_batch_norm_op.h" - -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math.h" - -namespace caffe2 { - -namespace { - -template -using BlockReduce = cub::BlockReduce; - -template -__global__ void ComputeFusedParamCUDAKernel( - const int C, - const T epsilon, - const T* scale, - const T* bias, - const T* mean, - const T* var, - T* alpha, - T* beta); - -template <> -__global__ void ComputeFusedParamCUDAKernel( - const int C, - const float epsilon, - const float* scale, - const float* bias, - const float* mean, - const float* var, - float* alpha, - float* beta) { - CUDA_1D_KERNEL_LOOP(i, C) { -#if __CUDA_ARCH__ >= 350 - const float scale_x_rstd = - __ldg(scale + i) * rsqrtf(__ldg(var + i) + epsilon); - alpha[i] = scale_x_rstd; - beta[i] = __ldg(bias + i) - scale_x_rstd * __ldg(mean + i); -#else - const float scale_x_rstd = scale[i] * rsqrtf(var[i] + epsilon); - alpha[i] = scale_x_rstd; - beta[i] = bias[i] - scale_x_rstd * mean[i]; -#endif - } -} - -template -__global__ void ComputeBatchMomentsCUDAKernel( - const int C, - const T scale, - const T* batch_mean_sum, - const T* batch_var_sum, - T* mean, - T* var) { - CUDA_1D_KERNEL_LOOP(i, C) { -#if __CUDA_ARCH__ >= 350 - const T mu = __ldg(batch_mean_sum + i) * scale; - mean[i] = mu; - var[i] = __ldg(batch_var_sum + i) * scale - mu * mu; -#else - const T mu = batch_mean_sum[i] * scale; - mean[i] = mu; - var[i] = batch_var_sum[i] * scale - mu * mu; -#endif - } -} - -template -__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel( - const int C, - const T momentum, - const T epsilon, - const T* scale, - const T* bias, - const T* mean, - const T* var, - T* running_mean, - T* running_var, - T* rstd, - T* alpha, - T* beta); - -template <> -__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel( - const int C, - const float momentum, - const float epsilon, - const float* scale, - const float* bias, - const float* mean, - const float* var, - float* running_mean, - float* running_var, - float* rstd, - float* alpha, - float* beta) { - const float a = 1.0f - momentum; - const float b = momentum; - CUDA_1D_KERNEL_LOOP(i, C) { -#if __CUDA_ARCH__ >= 350 - running_mean[i] = a * __ldg(mean + i) + b * __ldg(running_mean + i); - running_var[i] = a * __ldg(var + i) + b * __ldg(running_var + i); - const float rstd_val = rsqrtf(__ldg(var + i) + epsilon); - const float scale_x_rstd = __ldg(scale + i) * rstd_val; - rstd[i] = rstd_val; - alpha[i] = scale_x_rstd; - beta[i] = bias[i] - scale_x_rstd * __ldg(mean + i); -#else - running_mean[i] = a * mean[i] + b * running_mean[i]; - running_var[i] = a * var[i] + b * running_var[i]; - const float rstd_val = rsqrtf(var[i] + epsilon); - const float scale_x_rstd = 
scale[i] * rstd_val; - rstd[i] = rstd_val; - alpha[i] = scale_x_rstd; - beta[i] = bias[i] - scale_x_rstd * mean[i]; -#endif - } -} - -template -__global__ void ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel( - const int C, - const T inv_num_batches, - const T inv_nhw, - const T* scale, - const T* mean, - const T* rstd, - const T* dscale_sum, - const T* dbias_sum, - T* dscale, - T* dbias, - T* alpha, - T* beta, - T* gamma) { - CUDA_1D_KERNEL_LOOP(i, C) { -#if __CUDA_ARCH__ >= 350 - const T dscale_val = __ldg(dscale_sum + i) * inv_num_batches; - const T dbias_val = __ldg(dbias_sum + i) * inv_num_batches; - const T scale_x_rstd = __ldg(scale + i) * __ldg(rstd + i); - const T dscale_x_rstd = dscale_val * __ldg(rstd + i); - dscale[i] = dscale_val; - dbias[i] = dbias_val; - alpha[i] = scale_x_rstd; - beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw; - gamma[i] = - scale_x_rstd * (__ldg(mean + i) * dscale_x_rstd - dbias_val) * inv_nhw; -#else - const T dscale_val = dscale_sum[i] * inv_num_batches; - const T dbias_val = dbias_sum[i] * inv_num_batches; - const T scale_x_rstd = scale[i] * rstd[i]; - const T dscale_x_rstd = dscale_val * rstd[i]; - dscale[i] = dscale_val; - dbias[i] = dbias_val; - alpha[i] = scale_x_rstd; - beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw; - gamma[i] = scale_x_rstd * (mean[i] * dscale_x_rstd - dbias_val) * inv_nhw; -#endif - } -} - -template -__global__ void ComputeScaleBiasGradientsAndFusedParamsCUDAKernel( - const int N, - const int C, - const int HxW, - const T* dY, - const T* X, - const T* scale, - const T* mean, - const T* rstd, - T* dscale, - T* dbias, - T* alpha, - T* beta, - T* gamma) { - const int outer_size = C; - const int inner_size = N * HxW; - const T inv_nhw = T(1) / static_cast(N * HxW); - __shared__ typename BlockReduce::TempStorage ds_storage; - __shared__ typename BlockReduce::TempStorage db_storage; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - T ds_val = 0; - T db_val = 0; - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int index = kOrder == StorageOrder::NCHW - ? (j / HxW * C + i) * HxW + j % HxW - : j * C + i; - ds_val += dY[index] * (X[index] - mean[i]) * rstd[i]; - db_val += dY[index]; - } - ds_val = BlockReduce(ds_storage).Sum(ds_val); - db_val = BlockReduce(db_storage).Sum(db_val); - if (threadIdx.x == 0) { -#if __CUDA_ARCH__ >= 350 - const T scale_x_rstd = __ldg(scale + i) * __ldg(rstd + i); - const T dscale_x_rstd = ds_val * __ldg(rstd + i); - dscale[i] = ds_val; - dbias[i] = db_val; - alpha[i] = scale_x_rstd; - beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw; - gamma[i] = - scale_x_rstd * (__ldg(mean + i) * dscale_x_rstd - db_val) * inv_nhw; -#else - const T scale_x_rstd = scale[i] * rstd[i]; - const T dscale_x_rstd = ds_val * rstd[i]; - dscale[i] = ds_val; - dbias[i] = db_val; - alpha[i] = scale_x_rstd; - beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw; - gamma[i] = scale_x_rstd * (mean[i] * dscale_x_rstd - db_val) * inv_nhw; -#endif - } - __syncthreads(); - } -} - -template -__global__ void ComputeXGradientCUDAKernel( - const int size, - const int C, - const int HxW, - const T* dY, - const T* X, - const T* alpha, - const T* beta, - const T* gamma, - T* dX) { - CUDA_1D_KERNEL_LOOP(i, size) { - const int c = kOrder == StorageOrder::NCHW ? 
i / HxW % C : i % C; -#if __CUDA_ARCH__ >= 350 - dX[i] = __ldg(alpha + c) * __ldg(dY + i) + __ldg(beta + c) * __ldg(X + i) + - __ldg(gamma + c); -#else - dX[i] = alpha[c] * dY[i] + beta[c] * X[i] + gamma[c]; -#endif - } -} - -} // namespace - -template <> -template -void SpatialBNOp::ComputeFusedParam( - const int C, - const T* scale, - const T* bias, - const T* mean, - const T* var, - T* alpha, - T* beta) { - ComputeFusedParamCUDAKernel - <<>>( - C, static_cast(epsilon_), scale, bias, mean, var, alpha, beta); -} - -template <> -template -void SpatialBNOp::ComputeBatchMoments( - const int N, - const int C, - const int HxW, - const T* batch_mean_sum, - const T* batch_var_sum, - T* mean, - T* var) { - const T scale = T(1) / static_cast(num_batches_ * N * HxW); - ComputeBatchMomentsCUDAKernel - <<>>( - C, scale, batch_mean_sum, batch_var_sum, mean, var); -} - -template <> -template -void SpatialBNOp::ComputeRunningMomentsAndFusedParam( - const int C, - const T* scale, - const T* bias, - const T* mean, - const T* var, - T* running_mean, - T* running_var, - T* rstd, - T* alpha, - T* beta) { - ComputeRunningMomentsAndFusedParamCUDAKernel - <<>>( - C, - static_cast(momentum_), - static_cast(epsilon_), - scale, - bias, - mean, - var, - running_mean, - running_var, - rstd, - alpha, - beta); -} - -template <> -template -void SpatialBNGradientOp:: - ComputeMultiBatchScaleBiasGradientsAndFusedParams( - const int N, - const int C, - const int HxW, - const T* scale, - const T* mean, - const T* rstd, - const T* dscale_sum, - const T* dbias_sum, - T* dscale, - T* dbias, - T* alpha, - T* beta, - T* gamma) { - const T inv_num_batches = T(1) / static_cast(num_batches_); - const T inv_nhw = T(1) / static_cast(N * HxW); - ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel - <<>>( - C, - inv_num_batches, - inv_nhw, - scale, - mean, - rstd, - dscale_sum, - dbias_sum, - dscale, - dbias, - alpha, - beta, - gamma); -} - -template <> -template -void SpatialBNGradientOp::ComputeScaleBiasGradientsAndFusedParams( - const int N, - const int C, - const int HxW, - const T* dY, - const T* X, - const T* scale, - const T* mean, - const T* rstd, - T* dscale, - T* dbias, - T* alpha, - T* beta, - T* gamma) { - if (order_ == StorageOrder::NCHW) { - ComputeScaleBiasGradientsAndFusedParamsCUDAKernel - <<>>( - N, - C, - HxW, - dY, - X, - scale, - mean, - rstd, - dscale, - dbias, - alpha, - beta, - gamma); - } else { - ComputeScaleBiasGradientsAndFusedParamsCUDAKernel - <<>>( - N, - C, - HxW, - dY, - X, - scale, - mean, - rstd, - dscale, - dbias, - alpha, - beta, - gamma); - } -} - -template <> -template -void SpatialBNGradientOp::ComputeXGradient( - const int N, - const int C, - const int HxW, - const T* dY, - const T* X, - const T* alpha, - const T* beta, - const T* gamma, - T* dX) { - const int size = N * C * HxW; - if (order_ == StorageOrder::NCHW) { - ComputeXGradientCUDAKernel - <<>>( - size, C, HxW, dY, X, alpha, beta, gamma, dX); - } else { - ComputeXGradientCUDAKernel - <<>>( - size, C, HxW, dY, X, alpha, beta, gamma, dX); - } -} - -} // namespace caffe2
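
Note on the fused forward transform (an illustrative standalone sketch, not code
from this patch): ComputeFusedParam collapses inference-time batch norm into a
single per-channel affine map, alpha[c] = scale[c] / sqrt(var[c] + epsilon) and
beta[c] = bias[c] - alpha[c] * mean[c], so that Y = alpha * X + beta equals
scale * (X - mean) / sqrt(var + epsilon) + bias. A minimal CPU reference under
those formulas follows; the function name FusedBNForwardNCHW and the example in
main() are hypothetical.

#include <cmath>
#include <cstdio>
#include <vector>

// Y[n][c][i] = alpha[c] * X[n][c][i] + beta[c] for NCHW data, where
//   alpha[c] = scale[c] / sqrt(var[c] + epsilon)
//   beta[c]  = bias[c] - alpha[c] * mean[c]
void FusedBNForwardNCHW(
    int N, int C, int HxW, float epsilon,
    const std::vector<float>& X, const std::vector<float>& scale,
    const std::vector<float>& bias, const std::vector<float>& mean,
    const std::vector<float>& var, std::vector<float>* Y) {
  std::vector<float> alpha(C), beta(C);
  for (int c = 0; c < C; ++c) {
    alpha[c] = scale[c] / std::sqrt(var[c] + epsilon);
    beta[c] = bias[c] - alpha[c] * mean[c];
  }
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int i = 0; i < HxW; ++i) {
        const int idx = (n * C + c) * HxW + i;
        // Same result as scale * (X - mean) * rstd + bias.
        (*Y)[idx] = alpha[c] * X[idx] + beta[c];
      }
    }
  }
}

int main() {
  // Tiny example: N = 1, C = 2, HxW = 3.
  const std::vector<float> X = {1, 2, 3, 4, 5, 6};
  const std::vector<float> scale = {1.0f, 0.5f}, bias = {0.0f, 1.0f};
  const std::vector<float> mean = {2.0f, 5.0f}, var = {1.0f, 1.0f};
  std::vector<float> Y(X.size());
  FusedBNForwardNCHW(1, 2, 3, 1e-5f, X, scale, bias, mean, var, &Y);
  for (const float y : Y) {
    std::printf("%g ", y);
  }
  std::printf("\n");
  return 0;
}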
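
Note on the fused backward coefficients (again an illustrative sketch, not code
from this patch): the gradient kernels reduce dY and dY * X per channel, derive
dscale and dbias, and then express the input gradient as
dX = alpha * dY + beta * X + gamma with alpha = scale * rstd,
beta = -alpha * (dscale * rstd) / (N * HxW), and
gamma = alpha * (mean * dscale * rstd - dbias) / (N * HxW), which is the
standard batch-norm backward rewritten as one fused per-channel affine map.
A CPU reference under those formulas follows; FusedBNBackwardNCHW is a
hypothetical name, and the plain loops stand in for the BlockReduce2D
reductions and the ComputeXGradientNCHWCUDAKernel launch.

#include <vector>

void FusedBNBackwardNCHW(
    int N, int C, int HxW,
    const std::vector<float>& dY, const std::vector<float>& X,
    const std::vector<float>& scale, const std::vector<float>& mean,
    const std::vector<float>& rstd,
    std::vector<float>* dscale, std::vector<float>* dbias,
    std::vector<float>* dX) {
  const float inv_nhw = 1.0f / static_cast<float>(N * HxW);
  for (int c = 0; c < C; ++c) {
    // Per-channel reductions over N and HxW (BlockReduce2D in the kernel).
    float ds = 0.0f;
    float db = 0.0f;
    for (int n = 0; n < N; ++n) {
      for (int i = 0; i < HxW; ++i) {
        const int idx = (n * C + c) * HxW + i;
        ds += dY[idx] * X[idx];
        db += dY[idx];
      }
    }
    ds = (ds - db * mean[c]) * rstd[c];
    (*dscale)[c] = ds;
    (*dbias)[c] = db;
    // Fused coefficients, then dX = alpha * dY + beta * X + gamma.
    const float alpha = scale[c] * rstd[c];
    const float beta = -alpha * ds * rstd[c] * inv_nhw;
    const float gamma = alpha * (mean[c] * ds * rstd[c] - db) * inv_nhw;
    for (int n = 0; n < N; ++n) {
      for (int i = 0; i < HxW; ++i) {
        const int idx = (n * C + c) * HxW + i;
        (*dX)[idx] = alpha * dY[idx] + beta * X[idx] + gamma;
      }
    }
  }
}

The same coefficients drive the NHWC path; only the indexing
(index = (n * HxW + i) * C + c) and the reduction strategy (a Gemm against a
vector of ones) differ.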