Use macro for reduce on 2d blocks (#16344)
authorXiaomeng Yang <yangxm@fb.com>
Sat, 2 Feb 2019 07:45:38 +0000 (23:45 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 2 Feb 2019 07:49:07 +0000 (23:49 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16344

Use macro for reduce on 2d blocks

i-am-not-moving-c2-to-c10

Reviewed By: houseroad

Differential Revision: D13808988

fbshipit-source-id: b68c0fb6079c1b6e203a072083aba7a95c202bc2

caffe2/operators/group_norm_op.cu
caffe2/operators/spatial_batch_norm_op_impl.cuh
caffe2/utils/math/elementwise.cu
caffe2/utils/math/reduce.cu
caffe2/utils/math/reduce.cuh [new file with mode: 0644]
tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py

index 62436d2..ad35e12 100644 (file)
@@ -9,22 +9,17 @@
 #include "caffe2/operators/group_norm_op.h"
 
 #include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
 
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/math.h"
+#include "caffe2/utils/math/reduce.cuh"
 
 namespace caffe2 {
 
 namespace {
 
 template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
-    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
 __global__ void ComputeFusedParamsCUDAKernel(
     const int G,
     const int K,
@@ -54,7 +49,7 @@ __global__ void ComputeFusedParamsCUDAKernel(
 
 template <typename T>
 __global__ void GroupNormForwardNCHWCUDAKernel(
-    const int K,
+    const int M,
     const int HxW,
     const T* X,
     const T* scale,
@@ -63,14 +58,14 @@ __global__ void GroupNormForwardNCHWCUDAKernel(
 
 template <>
 __global__ void GroupNormForwardNCHWCUDAKernel<float>(
-    const int W,
+    const int M,
     const int HxW,
     const float* X,
     const float* scale,
     const float* bias,
     float* Y) {
-  const int nc = blockIdx.x / W;
-  const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  const int nc = blockIdx.x / M;
+  const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
   if (hw < HxW) {
     const int index = nc * HxW + hw;
 #if __CUDA_ARCH__ >= 350
@@ -99,7 +94,8 @@ __global__ void GroupNormForwardNHWCCUDAKernel<float>(
     const float* bias,
     float* Y) {
   const int n = blockIdx.x / HxW;
-  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+  const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (c < C) {
     const int index = blockIdx.x * C + c;
     const int nc = n * C + c;
 #if __CUDA_ARCH__ >= 350
@@ -206,7 +202,7 @@ template <typename T>
 __global__ void GroupNormBackwardNCHWCUDAKernel(
     const int G,
     const int K,
-    const int W,
+    const int M,
     const int HxW,
     const T* dY,
     const T* X,
@@ -218,12 +214,12 @@ __global__ void GroupNormBackwardNCHWCUDAKernel(
     T* dX) {
   const int C = G * K;
   const T denom = T(1) / static_cast<T>(K * HxW);
-  const int nc = blockIdx.x / W;
+  const int nc = blockIdx.x / M;
   const int n = nc / C;
   const int c = nc % C;
   const int g = c / K;
   const int ng = n * G + g;
-  const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
   const int index = nc * HxW + hw;
   if (hw < HxW) {
 #if __CUDA_ARCH__ >= 350
@@ -261,7 +257,8 @@ __global__ void GroupNormBackwardNHWCCUDAKernel(
   const int g = blockIdx.y;
   const int n = x / HxW;
   const int ng = n * G + g;
-  for (int i = threadIdx.x; i < K; i += blockDim.x) {
+  const int i = blockIdx.z * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < K) {
     const int c = g * K + i;
     const int index = x * C + c;
 #if __CUDA_ARCH__ >= 350
@@ -393,10 +390,10 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNCHW(
     const float* scale,
     const float* bias,
     float* Y) {
-  const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+  const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
   GroupNormForwardNCHWCUDAKernel<float>
-      <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          W, HxW, X, scale, bias, Y);
+      <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          M, HxW, X, scale, bias, Y);
 }
 
 template <>
@@ -408,8 +405,9 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNHWC(
     const float* scale,
     const float* bias,
     float* Y) {
+  const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
   GroupNormForwardNHWCCUDAKernel<float>
-      <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      <<<dim3(N * HxW, M), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
           C, HxW, X, scale, bias, Y);
 }
 
@@ -440,31 +438,28 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
     // Computes dL/ds and dL/db.
     // dL/ds = Sum(dL/dY * gamma * X)
     // dL/db = Sum(dL/dY * gamma)
-    if (HxW >= 128) {
-      ComputeInternalGradientsNCHWCUDAKernel<float, 1, 128>
-          <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    } else if (HxW >= 64) {
-      ComputeInternalGradientsNCHWCUDAKernel<float, 2, 64>
-          <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    } else if (HxW >= 32) {
-      ComputeInternalGradientsNCHWCUDAKernel<float, 4, 32>
-          <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    } else {
-      ComputeInternalGradientsNCHWCUDAKernel<float, 8, 16>
-          <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    }
+    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+        HxW,
+        ComputeInternalGradientsNCHWCUDAKernel,
+        float,
+        dim3(N, G),
+        context_.cuda_stream(),
+        G,
+        K,
+        HxW,
+        dY_data,
+        X_data,
+        gamma_data,
+        ds_data,
+        db_data);
 
     // Computes dL/dX.
-    const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+    const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
     GroupNormBackwardNCHWCUDAKernel<float>
-        <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+        <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
             G,
             K,
-            W,
+            M,
             HxW,
             dY_data,
             X_data,
@@ -476,84 +471,45 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
             dX_data);
 
     // Computes dL/dgamma and dL/dbeta.
-    if (HxW >= 128) {
-      GammaBetaBackwardNCHWCUDAKernel<float, 1, 128>
-          <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    } else if (HxW >= 64) {
-      GammaBetaBackwardNCHWCUDAKernel<float, 2, 64>
-          <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    } else if (HxW >= 32) {
-      GammaBetaBackwardNCHWCUDAKernel<float, 4, 32>
-          <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    } else {
-      GammaBetaBackwardNCHWCUDAKernel<float, 8, 16>
-          <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    }
+    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+        HxW,
+        GammaBetaBackwardNCHWCUDAKernel,
+        float,
+        C,
+        context_.cuda_stream(),
+        N,
+        G,
+        K,
+        HxW,
+        dY_data,
+        X_data,
+        mu_data,
+        rsig_data,
+        dgamma_data,
+        dbeta_data);
   } else {
     // Computes dL/ds and dL/db.
     // dL/ds = Sum(dL/dY * gamma * X)
     // dL/db = Sum(dL/dY * gamma)
-    if (K >= 128) {
-      ComputeInternalGradientsNHWCCUDAKernel<float, 1, 128>
-          <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    } else if (K >= 64) {
-      ComputeInternalGradientsNHWCCUDAKernel<float, 2, 64>
-          <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    } else if (K >= 32) {
-      ComputeInternalGradientsNHWCCUDAKernel<float, 4, 32>
-          <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    } else {
-      ComputeInternalGradientsNHWCCUDAKernel<float, 8, 16>
-          <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
-              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
-    }
+    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+        K,
+        ComputeInternalGradientsNHWCCUDAKernel,
+        float,
+        dim3(N, G),
+        context_.cuda_stream(),
+        G,
+        K,
+        HxW,
+        dY_data,
+        X_data,
+        gamma_data,
+        ds_data,
+        db_data);
 
     // Computes dL/dX.
+    const int M = math::DivUp(K, CAFFE_CUDA_NUM_THREADS);
     GroupNormBackwardNHWCCUDAKernel<float>
-        <<<dim3(N * HxW, G),
+        <<<dim3(N * HxW, G, M),
            CAFFE_CUDA_NUM_THREADS,
            0,
            context_.cuda_stream()>>>(
@@ -570,59 +526,22 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
             dX_data);
 
     // Computes dL/dgamma and dL/dbeta.
-    if (HxW >= 128) {
-      GammaBetaBackwardNHWCCUDAKernel<float, 1, 128>
-          <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    } else if (HxW >= 64) {
-      GammaBetaBackwardNHWCCUDAKernel<float, 2, 64>
-          <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    } else if (HxW >= 32) {
-      GammaBetaBackwardNHWCCUDAKernel<float, 4, 32>
-          <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    } else {
-      GammaBetaBackwardNHWCCUDAKernel<float, 8, 16>
-          <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
-              N,
-              G,
-              K,
-              HxW,
-              dY_data,
-              X_data,
-              mu_data,
-              rsig_data,
-              dgamma_data,
-              dbeta_data);
-    }
+    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+        HxW,
+        GammaBetaBackwardNHWCCUDAKernel,
+        float,
+        C,
+        context_.cuda_stream(),
+        N,
+        G,
+        K,
+        HxW,
+        dY_data,
+        X_data,
+        mu_data,
+        rsig_data,
+        dgamma_data,
+        dbeta_data);
   }
   return true;
 }
index 6be58d2..94a4697 100644 (file)
@@ -8,19 +8,13 @@
 
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/math.h"
+#include "caffe2/utils/math/reduce.cuh"
 
 namespace caffe2 {
 
 namespace {
 
 template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
-    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
 __global__ void ComputeFusedParamCUDAKernel(
     const int C,
     const T epsilon,
@@ -316,7 +310,7 @@ __global__ void ComputeScaleGradientAndFusedParamsNHWCCUDAKernel<float>(
 template <typename T>
 __global__ void ComputeXGradientNCHWCUDAKernel(
     const int C,
-    const int K,
+    const int M,
     const int HxW,
     const T* dY,
     const T* X,
@@ -328,7 +322,7 @@ __global__ void ComputeXGradientNCHWCUDAKernel(
 template <>
 __global__ void ComputeXGradientNCHWCUDAKernel<float>(
     const int C,
-    const int K,
+    const int M,
     const int HxW,
     const float* dY,
     const float* X,
@@ -336,9 +330,9 @@ __global__ void ComputeXGradientNCHWCUDAKernel<float>(
     const float* beta,
     const float* gamma,
     float* dX) {
-  const int nc = blockIdx.x / K;
+  const int nc = blockIdx.x / M;
   const int c = nc % C;
-  const int x = blockIdx.x % K * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  const int x = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
   if (x < HxW) {
     const int index = nc * HxW + x;
 #if __CUDA_ARCH__ >= 350
@@ -399,9 +393,9 @@ void SpatialBNOp<CUDAContext>::ComputeFusedParam(
     const T* var,
     T* alpha,
     T* beta) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+  const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
   ComputeFusedParamCUDAKernel<T>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
           C, static_cast<T>(epsilon_), scale, bias, mean, var, alpha, beta);
 }
 
@@ -415,10 +409,10 @@ void SpatialBNOp<CUDAContext>::ComputeBatchMoments(
     const T* batch_var_sum,
     T* mean,
     T* var) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+  const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
   const T scale = T(1) / static_cast<T>(num_batches_ * N * HxW);
   ComputeBatchMomentsCUDAKernel<T>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
           C, scale, batch_mean_sum, batch_var_sum, mean, var);
 }
 
@@ -435,9 +429,9 @@ void SpatialBNOp<CUDAContext>::ComputeRunningMomentsAndFusedParam(
     T* rstd,
     T* alpha,
     T* beta) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+  const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
   ComputeRunningMomentsAndFusedParamCUDAKernel<T>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
           C,
           static_cast<T>(momentum_),
           static_cast<T>(epsilon_),
@@ -469,11 +463,11 @@ void SpatialBNGradientOp<CUDAContext>::
         T* alpha,
         T* beta,
         T* gamma) {
-  const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+  const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
   const T batch_scale = T(1) / static_cast<T>(num_batches_);
   const T mean_scale = T(1) / static_cast<T>(N * HxW);
   ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel<T>
-      <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
           C,
           batch_scale,
           mean_scale,
@@ -507,71 +501,25 @@ void SpatialBNGradientOp<CUDAContext>::ComputeScaleBiasGradientsAndFusedParams(
     T* gamma,
     T* scratch) {
   if (order_ == StorageOrder::NCHW) {
-    if (HxW >= 128) {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 1, 128>
-          <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    } else if (HxW >= 64) {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 2, 64>
-          <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    } else if (HxW >= 32) {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 4, 32>
-          <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    } else {
-      ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 8, 16>
-          <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
-              N,
-              C,
-              HxW,
-              dY,
-              X,
-              scale,
-              mean,
-              rstd,
-              dscale,
-              dbias,
-              alpha,
-              beta,
-              gamma);
-    }
+    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+        HxW,
+        ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel,
+        T,
+        C,
+        context_.cuda_stream(),
+        N,
+        C,
+        HxW,
+        dY,
+        X,
+        scale,
+        mean,
+        rstd,
+        dscale,
+        dbias,
+        alpha,
+        beta,
+        gamma);
   } else {
     ReinitializeTensor(&ones_, N * HxW, at::dtype<T>().device(CUDA));
     math::Set<T, CUDAContext>(
@@ -602,9 +550,9 @@ void SpatialBNGradientOp<CUDAContext>::ComputeScaleBiasGradientsAndFusedParams(
         0.0f,
         dbias,
         &context_);
-    const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+    const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
     ComputeScaleGradientAndFusedParamsNHWCCUDAKernel<T>
-        <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
             C,
             T(1) / static_cast<T>(N * HxW),
             dscale,
@@ -632,15 +580,17 @@ void SpatialBNGradientOp<CUDAContext>::ComputeXGradient(
     const T* gamma,
     T* dX) {
   if (order_ == StorageOrder::NCHW) {
-    const int K = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+    const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
     ComputeXGradientNCHWCUDAKernel<T>
-        <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            C, K, HxW, dY, X, alpha, beta, gamma, dX);
+        <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+            C, M, HxW, dY, X, alpha, beta, gamma, dX);
   } else {
-    const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+    const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
     ComputeXGradientNHWCCUDAKernel<T>
-        <<<dim3(N * HxW, K), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            C, HxW, dY, X, alpha, beta, gamma, dX);
+        <<<dim3(N * HxW, M),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context_.cuda_stream()>>>(C, HxW, dY, X, alpha, beta, gamma, dX);
   }
 }
 
index 0798b6f..cc7613c 100644 (file)
@@ -23,8 +23,8 @@ __global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) {
 template <typename T>
 __global__ void AffineChannelNCHWCUDAKernel(
     const int C,
+    const int M,
     const int HxW,
-    const int K,
     const T* X,
     const T* scale,
     const T* bias,
@@ -33,15 +33,15 @@ __global__ void AffineChannelNCHWCUDAKernel(
 template <>
 __global__ void AffineChannelNCHWCUDAKernel<float>(
     const int C,
+    const int M,
     const int HxW,
-    const int K,
     const float* X,
     const float* scale,
     const float* bias,
     float* Y) {
-  const int nc = blockIdx.x / K;
+  const int nc = blockIdx.x / M;
   const int c = nc % C;
-  const int w = blockIdx.x % K * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
   if (w < HxW) {
     const int index = nc * HxW + w;
 #if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
@@ -180,10 +180,10 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(double)
       const T* bias,                                                         \
       T* Y,                                                                  \
       CUDAContext* context) {                                                \
-    const int K = DivUp(HxW, CAFFE_CUDA_NUM_THREADS);                        \
+    const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS);                        \
     AffineChannelNCHWCUDAKernel<T>                                           \
-        <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(  \
-            C, HxW, K, X, scale, bias, Y);                                   \
+        <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(  \
+            C, M, HxW, X, scale, bias, Y);                                   \
   }                                                                          \
   template <>                                                                \
   CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \
@@ -195,9 +195,9 @@ CAFFE2_SPECIALIZED_CUDA_SINCOS(double)
       const T* bias,                                                         \
       T* Y,                                                                  \
       CUDAContext* context) {                                                \
-    const int K = DivUp(C, CAFFE_CUDA_NUM_THREADS);                          \
+    const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS);                          \
     AffineChannelNHWCCUDAKernel<T>                                           \
-        <<<dim3(N* HxW, K),                                                  \
+        <<<dim3(N* HxW, M),                                                  \
            CAFFE_CUDA_NUM_THREADS,                                           \
            0,                                                                \
            context->cuda_stream()>>>(C, X, scale, bias, Y);                  \
index f597ec7..31a6539 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/fixed_divisor.h"
+#include "caffe2/utils/math/reduce.cuh"
 #include "caffe2/utils/math_utils.h"
 
 namespace caffe2 {
@@ -18,13 +19,6 @@ namespace math {
 namespace {
 
 template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
-    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
 __global__ void
 RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) {
   __shared__ typename BlockReduce<T>::TempStorage m_storage;
@@ -229,23 +223,18 @@ CAFFE2_CUDA_EXPORT void MomentsCUDA(
   int N;
   int K;
   if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) {
-    if (K >= 128) {
-      BothEndsMomentsCUDAKernel<T, 1, 128>
-          <<<N, dim3(1, 128), 0, context->cuda_stream()>>>(
-              M, N, K, X, mean, var);
-    } else if (K >= 64) {
-      BothEndsMomentsCUDAKernel<T, 2, 64>
-          <<<N, dim3(2, 64), 0, context->cuda_stream()>>>(
-              M, N, K, X, mean, var);
-    } else if (K >= 32) {
-      BothEndsMomentsCUDAKernel<T, 4, 32>
-          <<<N, dim3(4, 32), 0, context->cuda_stream()>>>(
-              M, N, K, X, mean, var);
-    } else {
-      BothEndsMomentsCUDAKernel<T, 8, 16>
-          <<<N, dim3(8, 16), 0, context->cuda_stream()>>>(
-              M, N, K, X, mean, var);
-    }
+    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+        K,
+        BothEndsMomentsCUDAKernel,
+        T,
+        N,
+        context->cuda_stream(),
+        M,
+        N,
+        K,
+        X,
+        mean,
+        var);
     return;
   }
   std::vector<int> axes(ndim);
diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh
new file mode 100644 (file)
index 0000000..d191cbc
--- /dev/null
@@ -0,0 +1,35 @@
+#ifndef CAFFE2_UTILS_MATH_REDUCE_CUH_
+#define CAFFE2_UTILS_MATH_REDUCE_CUH_
+
+#include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
+
+#include "caffe2/core/common_gpu.h"
+
+namespace caffe2 {
+
+template <typename T>
+using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+using BlockReduce2D = cub::
+    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
+
+#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(                                   \
+    size, Func, T, grid_dim, cuda_stream, ...)                                \
+  do {                                                                        \
+    if (size >= 128) {                                                        \
+      Func<T, 1, 128>                                                         \
+          <<<grid_dim, dim3(1, 128), 0, cuda_stream>>>(__VA_ARGS__);          \
+    } else if (size >= 64) {                                                  \
+      Func<T, 2, 64><<<grid_dim, dim3(2, 64), 0, cuda_stream>>>(__VA_ARGS__); \
+    } else if (size >= 32) {                                                  \
+      Func<T, 4, 32><<<grid_dim, dim3(4, 32), 0, cuda_stream>>>(__VA_ARGS__); \
+    } else {                                                                  \
+      Func<T, 8, 16><<<grid_dim, dim3(8, 16), 0, cuda_stream>>>(__VA_ARGS__); \
+    }                                                                         \
+  } while (false)
+
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_REDUCE_CUH_
index bfd4f7f..926a797 100644 (file)
@@ -2258,6 +2258,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([
     ("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)),
     ("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)),
     ("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)),
+    ("/math/reduce.cuh", ("/math/hip/reduce.cuh", API_CAFFE2)),
     ("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)),
     ("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)),
     ("REGISTER_CUDA_OPERATOR" , ("REGISTER_HIP_OPERATOR", API_CAFFE2)),