Revert D13747581: Optimize SpatialBN on GPU
authorHoa Dinh <dvh@fb.com>
Thu, 24 Jan 2019 23:20:09 +0000 (15:20 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 24 Jan 2019 23:26:37 +0000 (15:26 -0800)
Differential Revision:
D13747581

Original commit changeset: 48a885a240ef

fbshipit-source-id: 58cec6023843d7459865eb80c9db8dac463cb96c

caffe2/operators/hip/spatial_batch_norm_op_miopen.hip
caffe2/operators/spatial_batch_norm_gradient_op.cc
caffe2/operators/spatial_batch_norm_op.cc
caffe2/operators/spatial_batch_norm_op.cu
caffe2/operators/spatial_batch_norm_op.h
caffe2/operators/spatial_batch_norm_op_cudnn.cu [moved from caffe2/operators/spatial_batch_norm_op_cudnn.cc with 99% similarity]
caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh [new file with mode: 0644]

index d4ab900..ae905da 100644 (file)
@@ -18,6 +18,7 @@
 #include "caffe2/core/hip/context_gpu.h"
 #include "caffe2/core/hip/miopen_wrapper.h"
 #include "caffe2/operators/spatial_batch_norm_op.h"
+#include "caffe2/operators/hip/spatial_batch_norm_op_gpu_impl.cuh"
 #include "caffe2/utils/math.h"
 
 const double MIOPEN_BN_MIN_EPSILON = 1e-6;
index e7b27db..49d07f6 100644 (file)
@@ -58,8 +58,7 @@ void SpatialBNGradientOp<CPUContext>::ComputeScaleBiasGradientsAndFusedParams(
     T* dbias,
     T* alpha,
     T* beta,
-    T* gamma,
-    T* /* scratch */) {
+    T* gamma) {
   ConstEigenVectorArrayMap<T> scale_arr(scale, C);
   ConstEigenVectorArrayMap<T> mean_arr(mean, C);
   ConstEigenVectorArrayMap<T> rstd_arr(rstd, C);
index 2e6c9fc..f1b3698 100644 (file)
@@ -6,70 +6,6 @@
 
 namespace caffe2 {
 
-template <>
-template <>
-void SpatialBNOp<CPUContext>::ComputeFusedParam<float>(
-    const int C,
-    const float* scale,
-    const float* bias,
-    const float* mean,
-    const float* var,
-    float* alpha,
-    float* beta) {
-  EigenVectorArrayMap<float> alpha_arr(alpha, C);
-  EigenVectorArrayMap<float> beta_arr(beta, C);
-  alpha_arr = ConstEigenVectorArrayMap<float>(scale, C) *
-      (ConstEigenVectorArrayMap<float>(var, C) + static_cast<float>(epsilon_))
-          .rsqrt();
-  beta_arr = ConstEigenVectorArrayMap<float>(bias, C) -
-      alpha_arr * ConstEigenVectorArrayMap<float>(mean, C);
-}
-
-template <>
-template <>
-void SpatialBNOp<CPUContext>::ComputeBatchMoments<float>(
-    const int N,
-    const int C,
-    const int HxW,
-    const float* batch_mean_sum,
-    const float* batch_var_sum,
-    float* mean,
-    float* var) {
-  const float scale = 1.0f / static_cast<float>(num_batches_ * N * HxW);
-  EigenVectorArrayMap<float> mean_arr(mean, C);
-  EigenVectorArrayMap<float> var_arr(var, C);
-  mean_arr = ConstEigenVectorArrayMap<float>(batch_mean_sum, C) * scale;
-  var_arr = ConstEigenVectorArrayMap<float>(batch_var_sum, C) * scale -
-      mean_arr.square();
-}
-
-template <>
-template <>
-void SpatialBNOp<CPUContext>::ComputeRunningMomentsAndFusedParam<float>(
-    const int C,
-    const float* scale,
-    const float* bias,
-    const float* mean,
-    const float* var,
-    float* running_mean,
-    float* running_var,
-    float* rstd,
-    float* alpha,
-    float* beta) {
-  const float a = 1.0f - momentum_;
-  const float b = momentum_;
-  math::Axpby<float, float, CPUContext>(C, a, mean, b, running_mean, &context_);
-  math::Axpby<float, float, CPUContext>(C, a, var, b, running_var, &context_);
-  math::InvStd<float, CPUContext>(
-      C, static_cast<float>(epsilon_), var, rstd, &context_);
-  EigenVectorArrayMap<float> alpha_arr(alpha, C);
-  EigenVectorArrayMap<float> beta_arr(beta, C);
-  alpha_arr = ConstEigenVectorArrayMap<float>(scale, C) *
-      ConstEigenVectorArrayMap<float>(rstd, C);
-  beta_arr = ConstEigenVectorArrayMap<float>(bias, C) -
-      alpha_arr * ConstEigenVectorArrayMap<float>(mean, C);
-}
-
 namespace {
 
 OpSchema::Cost CostInferenceForSpatialBN(
index d7ec4df..ecfc200 100644 (file)
@@ -1,642 +1,9 @@
 #include "caffe2/operators/spatial_batch_norm_op.h"
 
-#include <cub/block/block_reduce.cuh>
-
-#include "caffe2/core/context_gpu.h"
-#include "caffe2/utils/math.h"
+#include "caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh"
 
 namespace caffe2 {
 
-namespace {
-
-template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
-    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
-__global__ void ComputeFusedParamCUDAKernel(
-    const int C,
-    const T epsilon,
-    const T* scale,
-    const T* bias,
-    const T* mean,
-    const T* var,
-    T* alpha,
-    T* beta);
-
-template <>
-__global__ void ComputeFusedParamCUDAKernel<float>(
-    const int C,
-    const float epsilon,
-    const float* scale,
-    const float* bias,
-    const float* mean,
-    const float* var,
-    float* alpha,
-    float* beta) {
-  const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (c < C) {
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    const float scale_x_rstd =
-        __ldg(scale + c) * rsqrtf(__ldg(var + c) + epsilon);
-    alpha[c] = scale_x_rstd;
-    beta[c] = fmaf(-scale_x_rstd, __ldg(mean + c), __ldg(bias + c));
-#else
-    const float scale_x_rstd = scale[c] * rsqrtf(var[c] + epsilon);
-    alpha[c] = scale_x_rstd;
-    beta[c] = fmaf(-scale_x_rstd, mean[c], bias[c]);
-#endif
-  }
-}
-
-template <typename T>
-__global__ void ComputeBatchMomentsCUDAKernel(
-    const int C,
-    const T scale,
-    const T* batch_mean_sum,
-    const T* batch_var_sum,
-    T* mean,
-    T* var);
-
-template <>
-__global__ void ComputeBatchMomentsCUDAKernel<float>(
-    const int C,
-    const float scale,
-    const float* batch_mean_sum,
-    const float* batch_var_sum,
-    float* mean,
-    float* var) {
-  const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (c < C) {
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    const float mu = __ldg(batch_mean_sum + c) * scale;
-    mean[c] = mu;
-    var[c] = fmaf(__ldg(batch_var_sum + c), scale, -mu * mu);
-#else
-    const float mu = batch_mean_sum[c] * scale;
-    mean[c] = mu;
-    var[c] = fmaf(batch_var_sum[c], scale, -mu * mu);
-#endif
-  }
-}
-
-template <typename T>
-__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel(
-    const int C,
-    const T momentum,
-    const T epsilon,
-    const T* scale,
-    const T* bias,
-    const T* mean,
-    const T* var,
-    T* running_mean,
-    T* running_var,
-    T* rstd,
-    T* alpha,
-    T* beta);
-
-template <>
-__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel<float>(
-    const int C,
-    const float momentum,
-    const float epsilon,
-    const float* scale,
-    const float* bias,
-    const float* mean,
-    const float* var,
-    float* running_mean,
-    float* running_var,
-    float* rstd,
-    float* alpha,
-    float* beta) {
-  const float a = 1.0f - momentum;
-  const float b = momentum;
-  const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (c < C) {
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    running_mean[c] = fmaf(a, __ldg(mean + c), b * __ldg(running_mean + c));
-    running_var[c] = fmaf(a, __ldg(var + c), b * __ldg(running_var + c));
-    const float rstd_val = rsqrtf(__ldg(var + c) + epsilon);
-    const float scale_x_rstd = __ldg(scale + c) * rstd_val;
-    rstd[c] = rstd_val;
-    alpha[c] = scale_x_rstd;
-    beta[c] = fmaf(-scale_x_rstd, __ldg(mean + c), __ldg(bias + c));
-#else
-    running_mean[c] = fmaf(a, mean[c], b * running_mean[c]);
-    running_var[c] = fmaf(a, var[c], b * running_var[c]);
-    const float rstd_val = rsqrtf(var[c] + epsilon);
-    const float scale_x_rstd = scale[c] * rstd_val;
-    rstd[c] = rstd_val;
-    alpha[c] = scale_x_rstd;
-    beta[c] = fmaf(-scale_x_rstd, mean[c], bias[c]);
-#endif
-  }
-}
-
-template <typename T>
-__global__ void ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel(
-    const int C,
-    const T inv_num_batches,
-    const T inv_nhw,
-    const T* scale,
-    const T* mean,
-    const T* rstd,
-    const T* dscale_sum,
-    const T* dbias_sum,
-    T* dscale,
-    T* dbias,
-    T* alpha,
-    T* beta,
-    T* gamma);
-
-template <>
-__global__ void
-ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel<float>(
-    const int C,
-    const float inv_num_batches,
-    const float inv_nhw,
-    const float* scale,
-    const float* mean,
-    const float* rstd,
-    const float* dscale_sum,
-    const float* dbias_sum,
-    float* dscale,
-    float* dbias,
-    float* alpha,
-    float* beta,
-    float* gamma) {
-  const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (c < C) {
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    const float dscale_val = __ldg(dscale_sum + c) * inv_num_batches;
-    const float dbias_val = __ldg(dbias_sum + c) * inv_num_batches;
-    const float scale_x_rstd = __ldg(scale + c) * __ldg(rstd + c);
-    const float dscale_x_rstd = dscale_val * __ldg(rstd + c);
-    dscale[c] = dscale_val;
-    dbias[c] = dbias_val;
-    alpha[c] = scale_x_rstd;
-    beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
-    gamma[c] = scale_x_rstd * fmaf(__ldg(mean + c), dscale_x_rstd, -dbias_val) *
-        inv_nhw;
-#else
-    const float dscale_val = dscale_sum[c] * inv_num_batches;
-    const float dbias_val = dbias_sum[c] * inv_num_batches;
-    const float scale_x_rstd = scale[c] * rstd[c];
-    const float dscale_x_rstd = dscale_val * rstd[c];
-    dscale[c] = dscale_val;
-    dbias[c] = dbias_val;
-    alpha[c] = scale_x_rstd;
-    beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
-    gamma[c] =
-        scale_x_rstd * fmaf(mean[c], dscale_x_rstd, -dbias_val) * inv_nhw;
-#endif
-  }
-}
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-__global__ void ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel(
-    const int N,
-    const int C,
-    const int HxW,
-    const T* dY,
-    const T* X,
-    const T* scale,
-    const T* mean,
-    const T* rstd,
-    T* dscale,
-    T* dbias,
-    T* alpha,
-    T* beta,
-    T* gamma) {
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage ds_storage;
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
-  const T inv_nhw = T(1) / static_cast<T>(N * HxW);
-  const int c = blockIdx.x;
-  T ds_val = 0;
-  T db_val = 0;
-  for (int i = threadIdx.x; i < N; i += blockDim.x) {
-    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
-      const int index = (i * C + c) * HxW + j;
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-      ds_val += __ldg(dY + index) * __ldg(X + index);
-      db_val += __ldg(dY + index);
-#else
-      ds_val += dY[index] * X[index];
-      db_val += dY[index];
-#endif
-    }
-  }
-  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(ds_storage).Sum(ds_val);
-  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    ds_val = (ds_val - db_val * __ldg(mean + c)) * __ldg(rstd + c);
-    const T scale_x_rstd = __ldg(scale + c) * __ldg(rstd + c);
-    const T dscale_x_rstd = ds_val * __ldg(rstd + c);
-    dscale[c] = ds_val;
-    dbias[c] = db_val;
-    alpha[c] = scale_x_rstd;
-    beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
-    gamma[c] =
-        scale_x_rstd * (__ldg(mean + c) * dscale_x_rstd - db_val) * inv_nhw;
-#else
-    ds_val = (ds_val - db_val * mean[c]) * rstd[c];
-    const T scale_x_rstd = scale[c] * rstd[c];
-    const T dscale_x_rstd = ds_val * rstd[c];
-    dscale[c] = ds_val;
-    dbias[c] = db_val;
-    alpha[c] = scale_x_rstd;
-    beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
-    gamma[c] = scale_x_rstd * (mean[c] * dscale_x_rstd - db_val) * inv_nhw;
-#endif
-  }
-}
-
-template <typename T>
-__global__ void ComputeScaleGradientAndFusedParamsNHWCCUDAKernel(
-    const int C,
-    const T inv_nhw,
-    const T* dYxX,
-    const T* dbias,
-    const T* scale,
-    const T* mean,
-    const T* rstd,
-    T* dscale,
-    T* alpha,
-    T* beta,
-    T* gamma);
-
-template <>
-__global__ void ComputeScaleGradientAndFusedParamsNHWCCUDAKernel<float>(
-    const int C,
-    const float inv_nhw,
-    const float* dYxX,
-    const float* dbias,
-    const float* scale,
-    const float* mean,
-    const float* rstd,
-    float* dscale,
-    float* alpha,
-    float* beta,
-    float* gamma) {
-  const int c = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (c < C) {
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    const float ds = fmaf(-__ldg(dbias + c), __ldg(mean + c), __ldg(dYxX + c)) *
-        __ldg(rstd + c);
-    dscale[c] = ds;
-    const float scale_x_rstd = __ldg(scale + c) * __ldg(rstd + c);
-    const float dscale_x_rstd = ds * __ldg(rstd + c);
-    alpha[c] = scale_x_rstd;
-    beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
-    gamma[c] = scale_x_rstd *
-        fmaf(__ldg(mean + c), dscale_x_rstd, -__ldg(dbias + c)) * inv_nhw;
-#else
-    const float ds = fmaf(-dbias[c], mean[c], dYxX[c]) * rstd[c];
-    dscale[c] = ds;
-    const float scale_x_rstd = scale[c] * rstd[c];
-    const float dscale_x_rstd = ds * rstd[c];
-    alpha[c] = scale_x_rstd;
-    beta[c] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
-    gamma[c] = scale_x_rstd * fmaf(mean[c], dscale_x_rstd, -dbias[c]) * inv_nhw;
-#endif
-  }
-}
-
-template <typename T>
-__global__ void ComputeXGradientNCHWCUDAKernel(
-    const int C,
-    const int HxW,
-    const int K,
-    const T* dY,
-    const T* X,
-    const T* alpha,
-    const T* beta,
-    const T* gamma,
-    T* dX);
-
-template <>
-__global__ void ComputeXGradientNCHWCUDAKernel<float>(
-    const int C,
-    const int HxW,
-    const int K,
-    const float* dY,
-    const float* X,
-    const float* alpha,
-    const float* beta,
-    const float* gamma,
-    float* dX) {
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
-  const int c = nc % C;
-  const int w = block * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (w < HxW) {
-    const int index = nc * HxW + w;
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    dX[index] = fmaf(
-        __ldg(alpha + c),
-        __ldg(dY + index),
-        fmaf(__ldg(beta + c), __ldg(X + index), __ldg(gamma + c)));
-#else
-    dX[index] = fmaf(alpha[c], dY[index], fma(beta[c], X[index], gamma[c]));
-#endif
-  }
-}
-
-template <typename T>
-__global__ void ComputeXGradientNHWCCUDAKernel(
-    const int C,
-    const T* dY,
-    const T* X,
-    const T* alpha,
-    const T* beta,
-    const T* gamma,
-    T* dX);
-
-template <>
-__global__ void ComputeXGradientNHWCCUDAKernel<float>(
-    const int C,
-    const float* dY,
-    const float* X,
-    const float* alpha,
-    const float* beta,
-    const float* gamma,
-    float* dX) {
-  for (int c = threadIdx.x; c < C; c += blockDim.x) {
-    const int index = blockIdx.x * C + c;
-#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
-    dX[index] = fmaf(
-        __ldg(alpha + c),
-        __ldg(dY + index),
-        fmaf(__ldg(beta + c), __ldg(X + index), __ldg(gamma + c)));
-#else
-    dX[index] = fmaf(alpha[c], dY[index], fma(beta[c], X[index], gamma[c]));
-#endif
-  }
-}
-
-} // namespace
-
-template <>
-template <>
-void SpatialBNOp<CUDAContext>::ComputeFusedParam<float>(
-    const int C,
-    const float* scale,
-    const float* bias,
-    const float* mean,
-    const float* var,
-    float* alpha,
-    float* beta) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
-  ComputeFusedParamCUDAKernel<float>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          C, static_cast<float>(epsilon_), scale, bias, mean, var, alpha, beta);
-}
-
-template <>
-template <>
-void SpatialBNOp<CUDAContext>::ComputeBatchMoments<float>(
-    const int N,
-    const int C,
-    const int HxW,
-    const float* batch_mean_sum,
-    const float* batch_var_sum,
-    float* mean,
-    float* var) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
-  const float scale = 1.0f / static_cast<float>(num_batches_ * N * HxW);
-  ComputeBatchMomentsCUDAKernel<float>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          C, scale, batch_mean_sum, batch_var_sum, mean, var);
-}
-
-template <>
-template <>
-void SpatialBNOp<CUDAContext>::ComputeRunningMomentsAndFusedParam<float>(
-    const int C,
-    const float* scale,
-    const float* bias,
-    const float* mean,
-    const float* var,
-    float* running_mean,
-    float* running_var,
-    float* rstd,
-    float* alpha,
-    float* beta) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
-  ComputeRunningMomentsAndFusedParamCUDAKernel<float>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          C,
-          static_cast<float>(momentum_),
-          static_cast<float>(epsilon_),
-          scale,
-          bias,
-          mean,
-          var,
-          running_mean,
-          running_var,
-          rstd,
-          alpha,
-          beta);
-}
-
-template <>
-template <>
-void SpatialBNGradientOp<CUDAContext>::
-    ComputeMultiBatchScaleBiasGradientsAndFusedParams<float>(
-        const int N,
-        const int C,
-        const int HxW,
-        const float* scale,
-        const float* mean,
-        const float* rstd,
-        const float* dscale_sum,
-        const float* dbias_sum,
-        float* dscale,
-        float* dbias,
-        float* alpha,
-        float* beta,
-        float* gamma) {
-  const float inv_num_batches = 1.0f / static_cast<float>(num_batches_);
-  const float inv_nhw = 1.0f / static_cast<float>(N * HxW);
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
-  ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel<float>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          C,
-          inv_num_batches,
-          inv_nhw,
-          scale,
-          mean,
-          rstd,
-          dscale_sum,
-          dbias_sum,
-          dscale,
-          dbias,
-          alpha,
-          beta,
-          gamma);
-}
-
-template <>
-template <>
-void SpatialBNGradientOp<CUDAContext>::ComputeScaleBiasGradientsAndFusedParams<
-    float>(
-    const int N,
-    const int C,
-    const int HxW,
-    const float* dY,
-    const float* X,
-    const float* scale,
-    const float* mean,
-    const float* rstd,
-    float* dscale,
-    float* dbias,
-    float* alpha,
-    float* beta,
-    float* gamma,
-    float* scratch) {
-  if (order_ == StorageOrder::NCHW) {
-    if (HxW >= 128) {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<float, 1, 128>
-          <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    } else if (HxW >= 64) {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<float, 2, 64>
-          <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    } else if (HxW >= 32) {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<float, 4, 32>
-          <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    } else {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<float, 8, 16>
-          <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    }
-  } else {
-    ReinitializeTensor(&ones_, {N * HxW}, at::dtype<float>().device(CUDA));
-    math::Set<float, CUDAContext>(
-        N * HxW, 1.0f, ones_.mutable_data<float>(), &context_);
-    const float* ones_data = ones_.data<float>();
-    math::Mul<float, CUDAContext>(N * C * HxW, dY, X, scratch, &context_);
-    math::Gemm<float, CUDAContext>(
-        CblasTrans,
-        CblasNoTrans,
-        C,
-        1,
-        N * HxW,
-        1.0f,
-        scratch,
-        ones_data,
-        0.0f,
-        dscale,
-        &context_);
-    math::Gemm<float, CUDAContext>(
-        CblasTrans,
-        CblasNoTrans,
-        C,
-        1,
-        N * HxW,
-        1.0f,
-        dY,
-        ones_data,
-        0.0f,
-        dbias,
-        &context_);
-    const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
-    ComputeScaleGradientAndFusedParamsNHWCCUDAKernel<float>
-        <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            C,
-            1.0f / static_cast<float>(N * HxW),
-            dscale,
-            dbias,
-            scale,
-            mean,
-            rstd,
-            dscale,
-            alpha,
-            beta,
-            gamma);
-  }
-}
-
-template <>
-template <>
-void SpatialBNGradientOp<CUDAContext>::ComputeXGradient<float>(
-    const int N,
-    const int C,
-    const int HxW,
-    const float* dY,
-    const float* X,
-    const float* alpha,
-    const float* beta,
-    const float* gamma,
-    float* dX) {
-  if (order_ == StorageOrder::NCHW) {
-    const int K = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
-    ComputeXGradientNCHWCUDAKernel<float>
-        <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            C, HxW, K, dY, X, alpha, beta, gamma, dX);
-  } else {
-    ComputeXGradientNHWCCUDAKernel<float>
-        <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            C, dY, X, alpha, beta, gamma, dX);
-  }
-}
-
 REGISTER_CUDA_OPERATOR(SpatialBN, SpatialBNOp<CUDAContext>);
 REGISTER_CUDA_OPERATOR(SpatialBNGradient, SpatialBNGradientOp<CUDAContext>);
 
index 5aa4989..d56ff1a 100644 (file)
@@ -204,7 +204,15 @@ class SpatialBNOp : public Operator<Context> {
       const T* mean,
       const T* var,
       T* alpha,
-      T* beta);
+      T* beta) {
+    EigenVectorArrayMap<T> alpha_arr(alpha, C);
+    EigenVectorArrayMap<T> beta_arr(beta, C);
+    alpha_arr = ConstEigenVectorArrayMap<T>(scale, C) *
+        (ConstEigenVectorArrayMap<T>(var, C) + static_cast<T>(epsilon_))
+            .rsqrt();
+    beta_arr = ConstEigenVectorArrayMap<T>(bias, C) -
+        alpha_arr * ConstEigenVectorArrayMap<T>(mean, C);
+  }
 
   template <typename T>
   void ComputeBatchMoments(
@@ -214,7 +222,14 @@ class SpatialBNOp : public Operator<Context> {
       const T* batch_mean_sum,
       const T* batch_var_sum,
       T* mean,
-      T* var);
+      T* var) {
+    const T scale = T(1) / static_cast<T>(num_batches_ * N * HxW);
+    EigenVectorArrayMap<T> mean_arr(mean, C);
+    EigenVectorArrayMap<T> var_arr(var, C);
+    mean_arr = ConstEigenVectorArrayMap<T>(batch_mean_sum, C) * scale;
+    var_arr = ConstEigenVectorArrayMap<T>(batch_var_sum, C) * scale -
+        mean_arr.square();
+  }
 
   template <typename T>
   void ComputeRunningMomentsAndFusedParam(
@@ -227,7 +242,19 @@ class SpatialBNOp : public Operator<Context> {
       T* running_var,
       T* rstd,
       T* alpha,
-      T* beta);
+      T* beta) {
+    const T a = T(1) - static_cast<T>(momentum_);
+    const T b = static_cast<T>(momentum_);
+    math::Axpby<T, T, Context>(C, a, mean, b, running_mean, &context_);
+    math::Axpby<T, T, Context>(C, a, var, b, running_var, &context_);
+    math::InvStd<T, Context>(C, static_cast<T>(epsilon_), var, rstd, &context_);
+    EigenVectorArrayMap<T> alpha_arr(alpha, C);
+    EigenVectorArrayMap<T> beta_arr(beta, C);
+    alpha_arr = ConstEigenVectorArrayMap<T>(scale, C) *
+        ConstEigenVectorArrayMap<T>(rstd, C);
+    beta_arr = ConstEigenVectorArrayMap<T>(bias, C) -
+        alpha_arr * ConstEigenVectorArrayMap<T>(mean, C);
+  }
 
   const bool is_test_;
   double epsilon_;
@@ -365,8 +392,7 @@ class SpatialBNGradientOp : public Operator<Context> {
           dbias_data,
           alpha_data,
           beta_data,
-          gamma_data,
-          dX_data);
+          gamma_data);
     }
     ComputeXGradient<T>(
         N, C, HxW, dY_data, X_data, alpha_data, beta_data, gamma_data, dX_data);
@@ -405,8 +431,7 @@ class SpatialBNGradientOp : public Operator<Context> {
       T* dbias,
       T* alpha,
       T* beta,
-      T* gamma,
-      T* scratch);
+      T* gamma);
 
   template <typename T>
   void ComputeXGradient(
@@ -427,7 +452,6 @@ class SpatialBNGradientOp : public Operator<Context> {
   Tensor alpha_;
   Tensor beta_;
   Tensor gamma_;
-  Tensor ones_;
 
   INPUT_TAGS(
       INPUT,
@@ -7,6 +7,7 @@
 
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/core/cudnn_wrappers.h"
+#include "caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh"
 #include "caffe2/utils/math.h"
 
 #if CUDNN_VERSION_MIN(5, 0, 0)
diff --git a/caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh b/caffe2/operators/spatial_batch_norm_op_gpu_impl.cuh
new file mode 100644 (file)
index 0000000..7d87e81
--- /dev/null
@@ -0,0 +1,439 @@
+#include "caffe2/operators/spatial_batch_norm_op.h"
+
+#include <cub/block/block_reduce.cuh>
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace {
+
+template <typename T>
+using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
+
+template <typename T>
+__global__ void ComputeFusedParamCUDAKernel(
+    const int C,
+    const T epsilon,
+    const T* scale,
+    const T* bias,
+    const T* mean,
+    const T* var,
+    T* alpha,
+    T* beta);
+
+template <>
+__global__ void ComputeFusedParamCUDAKernel<float>(
+    const int C,
+    const float epsilon,
+    const float* scale,
+    const float* bias,
+    const float* mean,
+    const float* var,
+    float* alpha,
+    float* beta) {
+  CUDA_1D_KERNEL_LOOP(i, C) {
+#if __CUDA_ARCH__ >= 350
+    const float scale_x_rstd =
+        __ldg(scale + i) * rsqrtf(__ldg(var + i) + epsilon);
+    alpha[i] = scale_x_rstd;
+    beta[i] = __ldg(bias + i) - scale_x_rstd * __ldg(mean + i);
+#else
+    const float scale_x_rstd = scale[i] * rsqrtf(var[i] + epsilon);
+    alpha[i] = scale_x_rstd;
+    beta[i] = bias[i] - scale_x_rstd * mean[i];
+#endif
+  }
+}
+
+template <typename T>
+__global__ void ComputeBatchMomentsCUDAKernel(
+    const int C,
+    const T scale,
+    const T* batch_mean_sum,
+    const T* batch_var_sum,
+    T* mean,
+    T* var) {
+  CUDA_1D_KERNEL_LOOP(i, C) {
+#if __CUDA_ARCH__ >= 350
+    const T mu = __ldg(batch_mean_sum + i) * scale;
+    mean[i] = mu;
+    var[i] = __ldg(batch_var_sum + i) * scale - mu * mu;
+#else
+    const T mu = batch_mean_sum[i] * scale;
+    mean[i] = mu;
+    var[i] = batch_var_sum[i] * scale - mu * mu;
+#endif
+  }
+}
+
+template <typename T>
+__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel(
+    const int C,
+    const T momentum,
+    const T epsilon,
+    const T* scale,
+    const T* bias,
+    const T* mean,
+    const T* var,
+    T* running_mean,
+    T* running_var,
+    T* rstd,
+    T* alpha,
+    T* beta);
+
+template <>
+__global__ void ComputeRunningMomentsAndFusedParamCUDAKernel<float>(
+    const int C,
+    const float momentum,
+    const float epsilon,
+    const float* scale,
+    const float* bias,
+    const float* mean,
+    const float* var,
+    float* running_mean,
+    float* running_var,
+    float* rstd,
+    float* alpha,
+    float* beta) {
+  const float a = 1.0f - momentum;
+  const float b = momentum;
+  CUDA_1D_KERNEL_LOOP(i, C) {
+#if __CUDA_ARCH__ >= 350
+    running_mean[i] = a * __ldg(mean + i) + b * __ldg(running_mean + i);
+    running_var[i] = a * __ldg(var + i) + b * __ldg(running_var + i);
+    const float rstd_val = rsqrtf(__ldg(var + i) + epsilon);
+    const float scale_x_rstd = __ldg(scale + i) * rstd_val;
+    rstd[i] = rstd_val;
+    alpha[i] = scale_x_rstd;
+    beta[i] = bias[i] - scale_x_rstd * __ldg(mean + i);
+#else
+    running_mean[i] = a * mean[i] + b * running_mean[i];
+    running_var[i] = a * var[i] + b * running_var[i];
+    const float rstd_val = rsqrtf(var[i] + epsilon);
+    const float scale_x_rstd = scale[i] * rstd_val;
+    rstd[i] = rstd_val;
+    alpha[i] = scale_x_rstd;
+    beta[i] = bias[i] - scale_x_rstd * mean[i];
+#endif
+  }
+}
+
+template <typename T>
+__global__ void ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel(
+    const int C,
+    const T inv_num_batches,
+    const T inv_nhw,
+    const T* scale,
+    const T* mean,
+    const T* rstd,
+    const T* dscale_sum,
+    const T* dbias_sum,
+    T* dscale,
+    T* dbias,
+    T* alpha,
+    T* beta,
+    T* gamma) {
+  CUDA_1D_KERNEL_LOOP(i, C) {
+#if __CUDA_ARCH__ >= 350
+    const T dscale_val = __ldg(dscale_sum + i) * inv_num_batches;
+    const T dbias_val = __ldg(dbias_sum + i) * inv_num_batches;
+    const T scale_x_rstd = __ldg(scale + i) * __ldg(rstd + i);
+    const T dscale_x_rstd = dscale_val * __ldg(rstd + i);
+    dscale[i] = dscale_val;
+    dbias[i] = dbias_val;
+    alpha[i] = scale_x_rstd;
+    beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
+    gamma[i] =
+        scale_x_rstd * (__ldg(mean + i) * dscale_x_rstd - dbias_val) * inv_nhw;
+#else
+    const T dscale_val = dscale_sum[i] * inv_num_batches;
+    const T dbias_val = dbias_sum[i] * inv_num_batches;
+    const T scale_x_rstd = scale[i] * rstd[i];
+    const T dscale_x_rstd = dscale_val * rstd[i];
+    dscale[i] = dscale_val;
+    dbias[i] = dbias_val;
+    alpha[i] = scale_x_rstd;
+    beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
+    gamma[i] = scale_x_rstd * (mean[i] * dscale_x_rstd - dbias_val) * inv_nhw;
+#endif
+  }
+}
+
+template <typename T, StorageOrder kOrder>
+__global__ void ComputeScaleBiasGradientsAndFusedParamsCUDAKernel(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* scale,
+    const T* mean,
+    const T* rstd,
+    T* dscale,
+    T* dbias,
+    T* alpha,
+    T* beta,
+    T* gamma) {
+  const int outer_size = C;
+  const int inner_size = N * HxW;
+  const T inv_nhw = T(1) / static_cast<T>(N * HxW);
+  __shared__ typename BlockReduce<T>::TempStorage ds_storage;
+  __shared__ typename BlockReduce<T>::TempStorage db_storage;
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    T ds_val = 0;
+    T db_val = 0;
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = kOrder == StorageOrder::NCHW
+          ? (j / HxW * C + i) * HxW + j % HxW
+          : j * C + i;
+      ds_val += dY[index] * (X[index] - mean[i]) * rstd[i];
+      db_val += dY[index];
+    }
+    ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
+    db_val = BlockReduce<T>(db_storage).Sum(db_val);
+    if (threadIdx.x == 0) {
+#if __CUDA_ARCH__ >= 350
+      const T scale_x_rstd = __ldg(scale + i) * __ldg(rstd + i);
+      const T dscale_x_rstd = ds_val * __ldg(rstd + i);
+      dscale[i] = ds_val;
+      dbias[i] = db_val;
+      alpha[i] = scale_x_rstd;
+      beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
+      gamma[i] =
+          scale_x_rstd * (__ldg(mean + i) * dscale_x_rstd - db_val) * inv_nhw;
+#else
+      const T scale_x_rstd = scale[i] * rstd[i];
+      const T dscale_x_rstd = ds_val * rstd[i];
+      dscale[i] = ds_val;
+      dbias[i] = db_val;
+      alpha[i] = scale_x_rstd;
+      beta[i] = -scale_x_rstd * dscale_x_rstd * inv_nhw;
+      gamma[i] = scale_x_rstd * (mean[i] * dscale_x_rstd - db_val) * inv_nhw;
+#endif
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, StorageOrder kOrder>
+__global__ void ComputeXGradientCUDAKernel(
+    const int size,
+    const int C,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* alpha,
+    const T* beta,
+    const T* gamma,
+    T* dX) {
+  CUDA_1D_KERNEL_LOOP(i, size) {
+    const int c = kOrder == StorageOrder::NCHW ? i / HxW % C : i % C;
+#if __CUDA_ARCH__ >= 350
+    dX[i] = __ldg(alpha + c) * __ldg(dY + i) + __ldg(beta + c) * __ldg(X + i) +
+        __ldg(gamma + c);
+#else
+    dX[i] = alpha[c] * dY[i] + beta[c] * X[i] + gamma[c];
+#endif
+  }
+}
+
+} // namespace
+
+template <>
+template <typename T>
+void SpatialBNOp<CUDAContext>::ComputeFusedParam(
+    const int C,
+    const T* scale,
+    const T* bias,
+    const T* mean,
+    const T* var,
+    T* alpha,
+    T* beta) {
+  ComputeFusedParamCUDAKernel<T>
+      <<<CAFFE_GET_BLOCKS(C),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context_.cuda_stream()>>>(
+          C, static_cast<T>(epsilon_), scale, bias, mean, var, alpha, beta);
+}
+
+template <>
+template <typename T>
+void SpatialBNOp<CUDAContext>::ComputeBatchMoments(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* batch_mean_sum,
+    const T* batch_var_sum,
+    T* mean,
+    T* var) {
+  const T scale = T(1) / static_cast<T>(num_batches_ * N * HxW);
+  ComputeBatchMomentsCUDAKernel<T>
+      <<<CAFFE_GET_BLOCKS(C),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context_.cuda_stream()>>>(
+          C, scale, batch_mean_sum, batch_var_sum, mean, var);
+}
+
+template <>
+template <typename T>
+void SpatialBNOp<CUDAContext>::ComputeRunningMomentsAndFusedParam(
+    const int C,
+    const T* scale,
+    const T* bias,
+    const T* mean,
+    const T* var,
+    T* running_mean,
+    T* running_var,
+    T* rstd,
+    T* alpha,
+    T* beta) {
+  ComputeRunningMomentsAndFusedParamCUDAKernel<T>
+      <<<CAFFE_GET_BLOCKS(C),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context_.cuda_stream()>>>(
+          C,
+          static_cast<T>(momentum_),
+          static_cast<T>(epsilon_),
+          scale,
+          bias,
+          mean,
+          var,
+          running_mean,
+          running_var,
+          rstd,
+          alpha,
+          beta);
+}
+
+template <>
+template <typename T>
+void SpatialBNGradientOp<CUDAContext>::
+    ComputeMultiBatchScaleBiasGradientsAndFusedParams(
+        const int N,
+        const int C,
+        const int HxW,
+        const T* scale,
+        const T* mean,
+        const T* rstd,
+        const T* dscale_sum,
+        const T* dbias_sum,
+        T* dscale,
+        T* dbias,
+        T* alpha,
+        T* beta,
+        T* gamma) {
+  const T inv_num_batches = T(1) / static_cast<T>(num_batches_);
+  const T inv_nhw = T(1) / static_cast<T>(N * HxW);
+  ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel<T>
+      <<<CAFFE_GET_BLOCKS(C),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context_.cuda_stream()>>>(
+          C,
+          inv_num_batches,
+          inv_nhw,
+          scale,
+          mean,
+          rstd,
+          dscale_sum,
+          dbias_sum,
+          dscale,
+          dbias,
+          alpha,
+          beta,
+          gamma);
+}
+
+template <>
+template <typename T>
+void SpatialBNGradientOp<CUDAContext>::ComputeScaleBiasGradientsAndFusedParams(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* scale,
+    const T* mean,
+    const T* rstd,
+    T* dscale,
+    T* dbias,
+    T* alpha,
+    T* beta,
+    T* gamma) {
+  if (order_ == StorageOrder::NCHW) {
+    ComputeScaleBiasGradientsAndFusedParamsCUDAKernel<T, StorageOrder::NCHW>
+        <<<std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context_.cuda_stream()>>>(
+            N,
+            C,
+            HxW,
+            dY,
+            X,
+            scale,
+            mean,
+            rstd,
+            dscale,
+            dbias,
+            alpha,
+            beta,
+            gamma);
+  } else {
+    ComputeScaleBiasGradientsAndFusedParamsCUDAKernel<T, StorageOrder::NHWC>
+        <<<std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context_.cuda_stream()>>>(
+            N,
+            C,
+            HxW,
+            dY,
+            X,
+            scale,
+            mean,
+            rstd,
+            dscale,
+            dbias,
+            alpha,
+            beta,
+            gamma);
+  }
+}
+
+template <>
+template <typename T>
+void SpatialBNGradientOp<CUDAContext>::ComputeXGradient(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* alpha,
+    const T* beta,
+    const T* gamma,
+    T* dX) {
+  const int size = N * C * HxW;
+  if (order_ == StorageOrder::NCHW) {
+    ComputeXGradientCUDAKernel<T, StorageOrder::NCHW>
+        <<<CAFFE_GET_BLOCKS(size),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context_.cuda_stream()>>>(
+            size, C, HxW, dY, X, alpha, beta, gamma, dX);
+  } else {
+    ComputeXGradientCUDAKernel<T, StorageOrder::NHWC>
+        <<<CAFFE_GET_BLOCKS(size),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context_.cuda_stream()>>>(
+            size, C, HxW, dY, X, alpha, beta, gamma, dX);
+  }
+}
+
+} // namespace caffe2