Add count_include_pad arg for AveragePoolOp on GPU (#15787)
authorXiaomeng Yang <yangxm@fb.com>
Tue, 8 Jan 2019 05:33:44 +0000 (21:33 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 8 Jan 2019 05:36:26 +0000 (21:36 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15787

Add count_include_pad arg for AveragePoolOp on GPU

Reviewed By: houseroad

Differential Revision: D13589185

fbshipit-source-id: 235a84cfcd2033ee796c13e338fc3d03e832b5b1

caffe2/operators/pool_gradient_op.cc
caffe2/operators/pool_op.cu

index c34335e..e33b408 100644 (file)
@@ -6,9 +6,6 @@
 
 namespace caffe2 {
 
-using std::max;
-using std::min;
-
 namespace {
 
 template <typename T, StorageOrder kOrder>
index 55c59cb..25e2b2c 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffe2/core/context_gpu.h"
 
 namespace caffe2 {
+
 namespace {
 
 struct AveragePool {
@@ -16,256 +17,262 @@ struct MaxPool {
   explicit MaxPool(const OperatorBase& /* op */) {}
 };
 
-}  // namespace
-
-namespace {
 template <typename T>
-__global__ void Average1DPoolForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int channels,
-    const int height,
-    const int pooled_height,
-    const int kernel_h,
-    const int stride_h,
-    const int pad_t,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int ph = n % pooled_height;
-    n /= pooled_height;
-    int c = n % channels;
-    n /= channels;
-    int hstart = ph * stride_h - pad_t;
-    int hend = min(hstart + kernel_h, height);
-    hstart = max(hstart, 0);
-    top_data[index] = 0;
-    int bottom_offset = (n * channels + c) * height;
-    for (int h = hstart; h < hend; ++h) {
-      top_data[index] += bottom_data[bottom_offset + h];
+__global__ void AveragePool1DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_size;
+  T* Y_ptr = Y + nc * Y_size;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_size) {
+    const int x = y * stride;
+    const int l = max(x - pad, 0);
+    const int r = min(x - pad + kernel, X_size);
+    const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+      sum += X_ptr[i];
     }
-    top_data[index] /= (hend - hstart);
+    Y_ptr[y] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average2DPoolForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int kernel_h,
-    const int kernel_w,
-    const int stride_h,
-    const int stride_w,
-    const int pad_t,
-    const int pad_l,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int pw = n % pooled_width;
-    n /= pooled_width;
-    int ph = n % pooled_height;
-    n /= pooled_height;
-    int c = n % channels;
-    n /= channels;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    top_data[index] = 0;
-    int bottom_offset = (n * channels + c) * height * width;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        top_data[index] += bottom_data[bottom_offset + h * width + w];
-      }
+__global__ void AveragePool1DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int n = blockIdx.x / Y_size;
+  const int y = blockIdx.x % Y_size;
+  const int x = y * stride;
+  const int l = max(x - pad, 0);
+  const int r = min(x - pad + kernel, X_size);
+  const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
+  const T* X_ptr = X + n * X_size * C;
+  T* Y_ptr = Y + n * Y_size * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+      sum += X_ptr[i * C + c];
     }
-    top_data[index] /= (hend - hstart) * (wend - wstart);
+    Y_ptr[y * C + c] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average3DPoolForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int depth,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void AveragePool2DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int pd = n % pooled_depth;
-    n /= pooled_depth;
-    int pw = n % pooled_width;
-    n /= pooled_width;
-    int ph = n % pooled_height;
-    n /= pooled_height;
-    int c = n % channels;
-    n /= channels;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int dstart = pd * stride_d - pad_f;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    int dend = min(dstart + kernel_d, depth);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dstart = max(dstart, 0);
-    top_data[index] = 0;
-    int bottom_offset = (n * channels + c) * height * width * depth;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        for (int d = dstart; d < dend; ++d) {
-          const int input_index =
-              bottom_offset + h * width * depth + w * depth + d;
-          top_data[index] += bottom_data[input_index];
-        }
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_HxW;
+  T* Y_ptr = Y + nc * Y_HxW;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_HxW) {
+    const int yh = y / Y_W;
+    const int yw = y % Y_W;
+    const int xh = yh * stride_h;
+    const int xw = yw * stride_w;
+    const int t = max(xh - pad_t, 0);
+    const int b = min(xh - pad_t + kernel_h, X_H);
+    const int l = max(xw - pad_l, 0);
+    const int r = min(xw - pad_l + kernel_w, X_W);
+    const T scale = T(1) /
+        static_cast<T>(count_include_pad ? kernel_h * kernel_w
+                                         : (b - t) * (r - l));
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        sum += X_ptr[i * X_W + j];
       }
     }
-    top_data[index] /= (hend - hstart) * (wend - wstart) * (dend - dstart);
+    Y_ptr[y] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average1DPoolForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int height,
-    const int channels,
-    const int pooled_height,
+__global__ void AveragePool2DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
+    const int kernel_w,
     const int stride_h,
+    const int stride_w,
     const int pad_t,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int c = index % channels;
-    int ph = (index / channels) % pooled_height;
-    int n = index / channels / pooled_height;
-    int hstart = ph * stride_h - pad_t;
-    int hend = min(hstart + kernel_h, height);
-    hstart = max(hstart, 0);
-    T output = 0;
-    int bottom_offset = n * height * channels + c;
-    for (int h = hstart; h < hend; ++h) {
-      output += bottom_data[bottom_offset + h * channels];
+    const int pad_l,
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int n = blockIdx.x / Y_HxW;
+  const int y = blockIdx.x % Y_HxW;
+  const int yh = y / Y_W;
+  const int yw = y % Y_W;
+  const int xh = yh * stride_h;
+  const int xw = yw * stride_w;
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  const int l = max(xw - pad_l, 0);
+  const int r = min(xw - pad_l + kernel_w, X_W);
+  const T scale = T(1) /
+      static_cast<T>(count_include_pad ? kernel_h * kernel_w
+                                       : (b - t) * (r - l));
+  const T* X_ptr = X + n * X_HxW * C;
+  T* Y_ptr = Y + n * Y_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        sum += X_ptr[(i * X_W + j) * C + c];
+      }
     }
-    int pool_size = (hend - hstart);
-    top_data[index] = output / pool_size;
+    Y_ptr[y * C + c] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average2DPoolForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int height,
-    const int width,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
+__global__ void AveragePool3DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int c = index % channels;
-    int pw = (index / channels) % pooled_width;
-    int ph = (index / channels / pooled_width) % pooled_height;
-    int n = index / channels / pooled_width / pooled_height;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    T output = 0;
-    int bottom_offset = n * height * width * channels + c;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        output += bottom_data[bottom_offset + (h * width + w) * channels];
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_HxW;
+  T* Y_ptr = Y + nc * Y_HxW;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_HxW) {
+    const int yy = y / Y_W;
+    const int yw = y % Y_W;
+    const int yh = yy % Y_H;
+    const int yd = yy / Y_H;
+    const int xd = yd * stride_d;
+    const int xh = yh * stride_h;
+    const int xw = yw * stride_w;
+    const int p = max(xd - pad_p, 0);
+    const int a = min(xd - pad_p + kernel_h, X_D);
+    const int t = max(xh - pad_t, 0);
+    const int b = min(xh - pad_t + kernel_h, X_H);
+    const int l = max(xw - pad_l, 0);
+    const int r = min(xw - pad_l + kernel_w, X_W);
+    const T scale = T(1) /
+        static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
+                                         : (a - p) * (b - t) * (r - l));
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          sum += X_ptr[(i * X_H + j) * X_W + k];
+        }
       }
     }
-    int pool_size = (hend - hstart) * (wend - wstart);
-    top_data[index] = output / pool_size;
+    Y_ptr[y] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average3DPoolForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int height,
-    const int width,
-    const int depth,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void AveragePool3DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int c = index % channels;
-    int pd = (index / channels) % pooled_depth;
-    int pw = (index / channels / pooled_depth) % pooled_width;
-    int ph = (index / channels / pooled_depth / pooled_width) % pooled_height;
-    int n = index / channels / pooled_depth / pooled_width / pooled_height;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int dstart = pd * stride_d - pad_f;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    int dend = min(dstart + kernel_d, depth);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dstart = max(dstart, 0);
-    T output = 0;
-    int bottom_offset = n * height * width * depth * channels + c;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        for (int d = dstart; d < dend; ++d) {
-          const int bottom_index =
-              bottom_offset + (h * depth * width + w * depth + d) * channels;
-          output += bottom_data[bottom_index];
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int n = blockIdx.x / Y_HxW;
+  const int y = blockIdx.x % Y_HxW;
+  const int yy = y / Y_W;
+  const int yw = y % Y_W;
+  const int yh = yy % Y_H;
+  const int yd = yy / Y_H;
+  const int xd = yd * stride_d;
+  const int xh = yh * stride_h;
+  const int xw = yw * stride_w;
+  const int p = max(xd - pad_p, 0);
+  const int a = min(xd - pad_p + kernel_h, X_D);
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  const int l = max(xw - pad_l, 0);
+  const int r = min(xw - pad_l + kernel_w, X_W);
+  const T scale = T(1) /
+      static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
+                                       : (a - p) * (b - t) * (r - l));
+  const T* X_ptr = X + n * X_HxW * C;
+  T* Y_ptr = Y + n * Y_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          sum += X_ptr[((i * X_H + j) * X_W + k) * C + c];
         }
       }
     }
-    int pool_size = (hend - hstart) * (wend - wstart) * (dend - dstart);
-    top_data[index] = output / pool_size;
+    Y_ptr[y * C + c] = sum * scale;
   }
 }
 
@@ -562,163 +569,211 @@ __global__ void Ave3DPoolBackwardNHWC(
   }
 }
 
-}  // namespace
+} // namespace
 
 template <>
-bool PoolOp<float, CUDAContext, AveragePool>::RunOnDeviceWithOrderNCHW() {
-  auto& X = Input(0);
-  auto* Y = Output(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
-  int output_size = Y->size();
-  switch (kernel_.size()) {
-    case 1:
-      Average1DPoolForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              Y->dim32(2),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              Y->template mutable_data<float>());
-      break;
-    case 2:
-      Average2DPoolForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              Y->dim32(2),
-              Y->dim32(3),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              Y->template mutable_data<float>());
-      break;
-    case 3:
-      Average3DPoolForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              Y->dim32(2),
-              Y->dim32(3),
-              Y->dim32(4),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              Y->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
-  }
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW, 1>(
+    const int N,
+    const int C,
+    const std::array<int, 1>& X_dims,
+    const std::array<int, 1>& Y_dims,
+    const std::array<int, 1>& kernel,
+    const std::array<int, 1>& /* dilation */,
+    const std::array<int, 1>& stride,
+    const std::array<int, 2>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int K =
+      (Y_dims[0] + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  AveragePool1DForwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K,
+          X_dims[0],
+          Y_dims[0],
+          kernel[0],
+          stride[0],
+          pads[0],
+          count_include_pad,
+          X,
+          Y);
   return true;
 }
 
 template <>
-bool PoolOp<float, CUDAContext, AveragePool>::RunOnDeviceWithOrderNHWC() {
-  auto& X = Input(0);
-  auto* Y = Output(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(X.ndim() - 1));
-  int output_size = Y->size();
-  switch (kernel_.size()) {
-    case 1:
-      Average1DPoolForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              Y->dim32(1),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              Y->template mutable_data<float>());
-      break;
-    case 2:
-      Average2DPoolForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              Y->dim32(1),
-              Y->dim32(2),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              Y->template mutable_data<float>());
-      break;
-    case 3:
-      Average3DPoolForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              Y->dim32(1),
-              Y->dim32(2),
-              Y->dim32(3),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              Y->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
-  }
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC, 1>(
+    const int N,
+    const int C,
+    const std::array<int, 1>& X_dims,
+    const std::array<int, 1>& Y_dims,
+    const std::array<int, 1>& kernel,
+    const std::array<int, 1>& /* dilation */,
+    const std::array<int, 1>& stride,
+    const std::array<int, 2>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  AveragePool1DForwardNHWCCUDAKernel<float>
+      <<<N * Y_dims[0], CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C,
+          X_dims[0],
+          Y_dims[0],
+          kernel[0],
+          stride[0],
+          pads[0],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW, 2>(
+    const int N,
+    const int C,
+    const std::array<int, 2>& X_dims,
+    const std::array<int, 2>& Y_dims,
+    const std::array<int, 2>& kernel,
+    const std::array<int, 2>& /* dilation */,
+    const std::array<int, 2>& stride,
+    const std::array<int, 4>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1];
+  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  AveragePool2DForwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K,
+          X_dims[0],
+          X_dims[1],
+          Y_dims[0],
+          Y_dims[1],
+          kernel[0],
+          kernel[1],
+          stride[0],
+          stride[1],
+          pads[0],
+          pads[1],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC, 2>(
+    const int N,
+    const int C,
+    const std::array<int, 2>& X_dims,
+    const std::array<int, 2>& Y_dims,
+    const std::array<int, 2>& kernel,
+    const std::array<int, 2>& /* dilation */,
+    const std::array<int, 2>& stride,
+    const std::array<int, 4>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1];
+  AveragePool2DForwardNHWCCUDAKernel<float>
+      <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C,
+          X_dims[0],
+          X_dims[1],
+          Y_dims[0],
+          Y_dims[1],
+          kernel[0],
+          kernel[1],
+          stride[0],
+          stride[1],
+          pads[0],
+          pads[1],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW, 3>(
+    const int N,
+    const int C,
+    const std::array<int, 3>& X_dims,
+    const std::array<int, 3>& Y_dims,
+    const std::array<int, 3>& kernel,
+    const std::array<int, 3>& /* dilation */,
+    const std::array<int, 3>& stride,
+    const std::array<int, 6>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1] * Y_dims[2];
+  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  AveragePool3DForwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K,
+          X_dims[0],
+          X_dims[1],
+          X_dims[2],
+          Y_dims[0],
+          Y_dims[1],
+          Y_dims[2],
+          kernel[0],
+          kernel[1],
+          kernel[2],
+          stride[0],
+          stride[1],
+          stride[2],
+          pads[0],
+          pads[1],
+          pads[2],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC, 3>(
+    const int N,
+    const int C,
+    const std::array<int, 3>& X_dims,
+    const std::array<int, 3>& Y_dims,
+    const std::array<int, 3>& kernel,
+    const std::array<int, 3>& /* dilation */,
+    const std::array<int, 3>& stride,
+    const std::array<int, 6>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1] * Y_dims[2];
+  AveragePool3DForwardNHWCCUDAKernel<float>
+      <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C,
+          X_dims[0],
+          X_dims[1],
+          X_dims[2],
+          Y_dims[0],
+          Y_dims[1],
+          Y_dims[2],
+          kernel[0],
+          kernel[1],
+          kernel[2],
+          stride[0],
+          stride[1],
+          stride[2],
+          pads[0],
+          pads[1],
+          pads[2],
+          count_include_pad,
+          X,
+          Y);
   return true;
 }
 
@@ -889,7 +944,6 @@ bool PoolGradientOp<float, CUDAContext, AveragePool>::
   return true;
 }
 
-
 namespace {
 
 template <typename T>
@@ -1335,8 +1389,7 @@ __global__ void MaxPool2DBackwardNHWC(
     const int phend = min(h / stride_h + 1, pooled_height);
     const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
     const int pwend = min(w / stride_w + 1, pooled_width);
-    const int top_offset =
-        n * pooled_height * pooled_width * channels + c;
+    const int top_offset = n * pooled_height * pooled_width * channels + c;
     bottom_diff[index] = 0;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
@@ -1403,7 +1456,7 @@ __global__ void MaxPool3DBackwardNHWC(
     }
   }
 }
-}  // namespace
+} // namespace
 
 template <>
 bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
@@ -1735,28 +1788,38 @@ bool PoolGradientOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNHWC() {
   return true;
 }
 
-REGISTER_CUDA_OPERATOR(AveragePool, PoolOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(AveragePoolGradient,
-                       PoolGradientOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
+REGISTER_CUDA_OPERATOR(
+    AveragePoolGradient,
+    PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(AveragePool1D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool1D,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool1DGradient,
     PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(AveragePool2D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool2D,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool2DGradient,
     PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(AveragePool3D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool3D,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool3DGradient,
     PoolGradientOp<float, CUDAContext, AveragePool>);
 
 REGISTER_CUDA_OPERATOR(MaxPool, PoolOp<float, CUDAContext, MaxPool>);
-REGISTER_CUDA_OPERATOR(MaxPoolGradient,
-                       PoolGradientOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+    MaxPoolGradient,
+    PoolGradientOp<float, CUDAContext, MaxPool>);
 
 REGISTER_CUDA_OPERATOR(MaxPool1D, PoolOp<float, CUDAContext, MaxPool>);
 REGISTER_CUDA_OPERATOR(
@@ -1772,4 +1835,4 @@ REGISTER_CUDA_OPERATOR(MaxPool3D, PoolOp<float, CUDAContext, MaxPool>);
 REGISTER_CUDA_OPERATOR(
     MaxPool3DGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
-}  // namespace caffe2
+} // namespace caffe2