Add global pooling specialization and also update MaxPooling on GPU (#15824)
authorXiaomeng Yang <yangxm@fb.com>
Sat, 12 Jan 2019 06:35:12 +0000 (22:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 12 Jan 2019 06:37:48 +0000 (22:37 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15824

Add global pooling specialization and also update MaxPooling on GPU

Reviewed By: houseroad

Differential Revision: D13596340

fbshipit-source-id: c8a42aa69ee92c383c9f19d3ed57b77cb3e5bd28

caffe2/operators/pool_op.cc
caffe2/operators/pool_op.cu
caffe2/operators/pool_op.h

index 0c58431..8f5d522 100644 (file)
@@ -6,6 +6,7 @@
 
 #include "caffe2/operators/pool_op_util.h"
 #include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
@@ -561,6 +562,48 @@ void RunMaxPool3D(
 
 } // namespace
 
+template <>
+template <>
+bool AveragePoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  const std::array<int, 2> dims = {N * C, HxW};
+  const int axis = 1;
+  math::ReduceMean<float, CPUContext>(
+      2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  math::Set<float, CPUContext>(N * C, 0.0f, Y, context);
+  const float* X_ptr = X;
+  float* Y_ptr = Y;
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < HxW; ++j) {
+      math::Add<float, CPUContext>(C, Y_ptr, X_ptr + j * C, Y_ptr, context);
+    }
+    X_ptr += HxW * C;
+    Y_ptr += C;
+  }
+  math::Scale<float, float, CPUContext>(
+      N * C, 1.0f / static_cast<float>(HxW), Y, Y, context);
+  return true;
+}
+
 #define CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD(T, kOrder)           \
   template <>                                                                \
   template <>                                                                \
@@ -667,6 +710,49 @@ CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD(float, StorageOrder::NCHW)
 CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD(float, StorageOrder::NHWC)
 #undef CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD
 
+template <>
+template <>
+bool MaxPoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  const std::array<int, 2> dims = {N * C, HxW};
+  const int axis = 1;
+  math::ReduceMax<float, CPUContext>(
+      2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  return true;
+}
+
+template <>
+template <>
+bool MaxPoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  math::Set<float, CPUContext>(
+      N * C, std::numeric_limits<float>::lowest(), Y, context);
+  const float* X_ptr = X;
+  float* Y_ptr = Y;
+  for (int i = 0; i < N; ++i) {
+    ConstEigenArrayMap<float> X_arr(X_ptr, C, HxW);
+    EigenVectorArrayMap<float> Y_arr(Y_ptr, C);
+    for (int j = 0; j < HxW; ++j) {
+      Y_arr = Y_arr.max(X_arr.col(j));
+    }
+    X_ptr += HxW * C;
+    Y_ptr += C;
+  }
+  return true;
+}
+
 #define CAFFE2_SPECIALIZED_MAX_POOL_FUNCTOR_FORWARD(T, kOrder)                \
   template <>                                                                 \
   template <>                                                                 \
index 2a18be9..bca9e52 100644 (file)
@@ -1,11 +1,13 @@
 // TODO(ataei): reduce the apparent redundancy of all the code below.
 #include "caffe2/operators/pool_op.h"
 
-#include <cfloat>
+#include <array>
 #include <functional>
+#include <limits>
 #include <numeric>
 
 #include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
@@ -574,6 +576,29 @@ __global__ void Ave3DPoolBackwardNHWC(
 } // namespace
 
 template <>
+template <typename T, StorageOrder kOrder>
+bool AveragePoolFunctor<CUDAContext>::GlobalPoolingForward(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* Y,
+    CUDAContext* context) const {
+  if (kOrder == StorageOrder::NCHW) {
+    const std::array<int, 2> dims = {N * C, HxW};
+    const int axis = 1;
+    math::ReduceMean<float, CUDAContext>(
+        2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  } else {
+    const std::array<int, 3> dims = {N, HxW, C};
+    const int axis = 1;
+    math::ReduceMean<float, CUDAContext>(
+        3, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  }
+  return true;
+}
+
+template <>
 template <>
 bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
     const int N,
@@ -587,6 +612,7 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
     const float* X,
     float* Y,
     CUDAContext* context) const {
+  // Split each image into K segments, each CUDA block handles one segment.
   const int ndim = X_dims.size();
   const int Y_HxW = std::accumulate(
       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
@@ -670,6 +696,7 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
     const float* X,
     float* Y,
     CUDAContext* context) const {
+  // Each CUDA block handles one point, one thread per channel.
   const int ndim = X_dims.size();
   const int Y_HxW = std::accumulate(
       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
@@ -908,251 +935,241 @@ bool PoolGradientOp<float, CUDAContext, AveragePool>::
 namespace {
 
 template <typename T>
-__global__ void MaxPool1DForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int channels,
-    const int height,
-    const int pooled_height,
-    const int kernel_h,
-    const int stride_h,
-    const int pad_t,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int ph = index % pooled_height;
-    int c = (index / pooled_height) % channels;
-    int n = index / pooled_height / channels;
-    int hstart = ph * stride_h - pad_t;
-    int hend = min(hstart + kernel_h, height);
-    hstart = max(hstart, 0);
-    T maxval = -FLT_MAX;
-    const T* bdata_offset = bottom_data + n * channels * height;
-    for (int h = hstart; h < hend; ++h) {
-      int idx = c * height + h;
-      if (bdata_offset[idx] > maxval) {
-        maxval = bdata_offset[idx];
-      }
+__global__ void MaxPool1DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const T* X,
+    T* Y) {
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_size;
+  T* Y_ptr = Y + nc * Y_size;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_size) {
+    const int x = y * stride;
+    const int l = max(x - pad, 0);
+    const int r = min(x - pad + kernel, X_size);
+    T val = std::numeric_limits<T>::lowest();
+    for (int i = l; i < r; ++i) {
+      val = max(val, X_ptr[i]);
     }
-    top_data[index] = maxval;
+    Y_ptr[y] = val;
   }
 }
 
 template <typename T>
-__global__ void MaxPool2DForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int kernel_h,
-    const int kernel_w,
-    const int stride_h,
-    const int stride_w,
-    const int pad_t,
-    const int pad_l,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    T maxval = -FLT_MAX;
-    const T* bdata_offset = bottom_data + n * channels * height * width;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        int idx = c * height * width + h * width + w;
-        if (bdata_offset[idx] > maxval) {
-          maxval = bdata_offset[idx];
-        }
-      }
+__global__ void MaxPool1DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const T* X,
+    T* Y) {
+  const int n = blockIdx.x / Y_size;
+  const int y = blockIdx.x % Y_size;
+  const int x = y * stride;
+  const int l = max(x - pad, 0);
+  const int r = min(x - pad + kernel, X_size);
+  const T* X_ptr = X + n * X_size * C;
+  T* Y_ptr = Y + n * Y_size * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T val = std::numeric_limits<T>::lowest();
+    for (int i = l; i < r; ++i) {
+      val = max(val, X_ptr[i * C + c]);
     }
-    top_data[index] = maxval;
+    Y_ptr[y * C + c] = val;
   }
 }
 
 template <typename T>
-__global__ void MaxPool3DForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int channels,
-    const int height,
-    const int width,
-    const int depth,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void MaxPool2DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int pd = index % pooled_depth;
-    int pw = (index / pooled_depth) % pooled_width;
-    int ph = (index / pooled_depth / pooled_width) % pooled_height;
-    int c = (index / pooled_depth / pooled_width / pooled_height) % channels;
-    int n = index / pooled_depth / pooled_width / pooled_height / channels;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    int dstart = pd * stride_d - pad_f;
-    int dend = min(dstart + kernel_d, depth);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dstart = max(dstart, 0);
-    T maxval = -FLT_MAX;
-    const T* bdata_offset = bottom_data + n * channels * height * width * depth;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        for (int d = dstart; d < dend; ++d) {
-          int idx = ((c * height + h) * width + w) * depth + d;
-          if (bdata_offset[idx] > maxval) {
-            maxval = bdata_offset[idx];
-          }
-        }
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_HxW;
+  T* Y_ptr = Y + nc * Y_HxW;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_HxW) {
+    const int yh = y / Y_W;
+    const int yw = y % Y_W;
+    const int xh = yh * stride_h;
+    const int xw = yw * stride_w;
+    const int t = max(xh - pad_t, 0);
+    const int b = min(xh - pad_t + kernel_h, X_H);
+    const int l = max(xw - pad_l, 0);
+    const int r = min(xw - pad_l + kernel_w, X_W);
+    T val = std::numeric_limits<T>::lowest();
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        val = max(val, X_ptr[i * X_W + j]);
       }
     }
-    top_data[index] = maxval;
+    Y_ptr[y] = val;
   }
 }
 
 template <typename T>
-__global__ void MaxPool1DForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int height,
-    const int channels,
-    const int pooled_height,
+__global__ void MaxPool2DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
+    const int kernel_w,
     const int stride_h,
+    const int stride_w,
     const int pad_t,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int c = n % channels;
-    n /= channels;
-    int hstart = (n % pooled_height) * stride_h - pad_t;
-    n /= pooled_height;
-    int hend = min(hstart + kernel_h, height);
-    hstart = max(hstart, 0);
-    T maxval = -FLT_MAX;
-    const T* bdata_offset = bottom_data + n * height * channels;
-    for (int h = hstart; h < hend; ++h) {
-      int idx = h * channels + c;
-      if (bdata_offset[idx] > maxval) {
-        maxval = bdata_offset[idx];
+    const int pad_l,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int n = blockIdx.x / Y_HxW;
+  const int y = blockIdx.x % Y_HxW;
+  const int yh = y / Y_W;
+  const int yw = y % Y_W;
+  const int xh = yh * stride_h;
+  const int xw = yw * stride_w;
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  const int l = max(xw - pad_l, 0);
+  const int r = min(xw - pad_l + kernel_w, X_W);
+  const T* X_ptr = X + n * X_HxW * C;
+  T* Y_ptr = Y + n * Y_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T val = std::numeric_limits<T>::lowest();
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        val = max(val, X_ptr[(i * X_W + j) * C + c]);
       }
     }
-    top_data[index] = maxval;
+    Y_ptr[y * C + c] = val;
   }
 }
 
 template <typename T>
-__global__ void MaxPool2DForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int height,
-    const int width,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
+__global__ void MaxPool3DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int c = n % channels;
-    n /= channels;
-    int wstart = (n % pooled_width) * stride_w - pad_l;
-    n /= pooled_width;
-    int hstart = (n % pooled_height) * stride_h - pad_t;
-    n /= pooled_height;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    T maxval = -FLT_MAX;
-    const T* bdata_offset = bottom_data + n * height * width * channels;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        int idx = (h * width + w) * channels + c;
-        if (bdata_offset[idx] > maxval) {
-          maxval = bdata_offset[idx];
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_HxW;
+  T* Y_ptr = Y + nc * Y_HxW;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_HxW) {
+    const int yy = y / Y_W;
+    const int yw = y % Y_W;
+    const int yh = yy % Y_H;
+    const int yd = yy / Y_H;
+    const int xd = yd * stride_d;
+    const int xh = yh * stride_h;
+    const int xw = yw * stride_w;
+    const int p = max(xd - pad_p, 0);
+    const int a = min(xd - pad_p + kernel_d, X_D);
+    const int t = max(xh - pad_t, 0);
+    const int b = min(xh - pad_t + kernel_h, X_H);
+    const int l = max(xw - pad_l, 0);
+    const int r = min(xw - pad_l + kernel_w, X_W);
+    T val = std::numeric_limits<T>::lowest();
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          val = max(val, X_ptr[(i * X_H + j) * X_W + k]);
         }
       }
     }
-    top_data[index] = maxval;
+    Y_ptr[y] = val;
   }
 }
 
 template <typename T>
-__global__ void MaxPool3DForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int height,
-    const int width,
-    const int depth,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void MaxPool3DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int c = n % channels;
-    n /= channels;
-    int dstart = (n % pooled_depth) * stride_d - pad_f;
-    n /= pooled_depth;
-    int wstart = (n % pooled_width) * stride_w - pad_l;
-    n /= pooled_width;
-    int hstart = (n % pooled_height) * stride_h - pad_t;
-    n /= pooled_height;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    int dend = min(dstart + kernel_d, depth);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dstart = max(dstart, 0);
-    T maxval = -FLT_MAX;
-    const T* bdata_offset = bottom_data + n * height * width * depth * channels;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        for (int d = dstart; d < dend; ++d) {
-          int idx = ((h * width + w) * depth + d) * channels + c;
-          if (bdata_offset[idx] > maxval) {
-            maxval = bdata_offset[idx];
-          }
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int n = blockIdx.x / Y_HxW;
+  const int y = blockIdx.x % Y_HxW;
+  const int yy = y / Y_W;
+  const int yw = y % Y_W;
+  const int yh = yy % Y_H;
+  const int yd = yy / Y_H;
+  const int xd = yd * stride_d;
+  const int xh = yh * stride_h;
+  const int xw = yw * stride_w;
+  const int p = max(xd - pad_p, 0);
+  const int a = min(xd - pad_p + kernel_d, X_D);
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  const int l = max(xw - pad_l, 0);
+  const int r = min(xw - pad_l + kernel_w, X_W);
+  const T* X_ptr = X + n * X_HxW * C;
+  T* Y_ptr = Y + n * Y_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T val = std::numeric_limits<T>::lowest();
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          val = max(val, X_ptr[((i * X_H + j) * X_W + k) * C + c]);
         }
       }
     }
-    top_data[index] = maxval;
+    Y_ptr[y * C + c] = val;
   }
 }
 
@@ -1417,158 +1434,177 @@ __global__ void MaxPool3DBackwardNHWC(
     }
   }
 }
+
 } // namespace
 
 template <>
-bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
-  auto& X = Input(0);
-  auto* Y = Output(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
-  int output_size = Y->size();
-  switch (kernel_.size()) {
-    case 1:
-      MaxPool1DForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(1),
-              X.dim32(2),
-              Y->dim32(2),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              Y->template mutable_data<float>());
-      break;
-    case 2:
-      MaxPool2DForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              Y->dim32(2),
-              Y->dim32(3),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              Y->template mutable_data<float>());
-      break;
-    case 3:
-      MaxPool3DForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              Y->dim32(2),
-              Y->dim32(3),
-              Y->dim32(4),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              Y->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
+template <typename T, StorageOrder kOrder>
+bool MaxPoolFunctor<CUDAContext>::GlobalPoolingForward(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* Y,
+    CUDAContext* context) const {
+  if (kOrder == StorageOrder::NCHW) {
+    const std::array<int, 2> dims = {N * C, HxW};
+    const int axis = 1;
+    math::ReduceMax<float, CUDAContext>(
+        2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  } else {
+    const std::array<int, 3> dims = {N, HxW, C};
+    const int axis = 1;
+    math::ReduceMax<float, CUDAContext>(
+        3, dims.data(), 1, &axis, 1.0f, X, Y, context);
   }
   return true;
 }
 
 template <>
-bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNHWC() {
-  auto& X = Input(0);
-  auto* Y = Output(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(X.ndim() - 1));
-  int output_size = Y->size();
-  switch (kernel_.size()) {
-    case 1:
-      MaxPool1DForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(1),
-              X.dim32(2),
-              Y->dim32(1),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              Y->template mutable_data<float>());
-      break;
-    case 2:
-      MaxPool2DForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              Y->dim32(1),
-              Y->dim32(2),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              Y->template mutable_data<float>());
-      break;
-    case 3:
-      MaxPool3DForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              Y->dim32(1),
-              Y->dim32(2),
-              Y->dim32(3),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              Y->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
+template <>
+bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const std::vector<int>& X_dims,
+    const std::vector<int>& Y_dims,
+    const std::vector<int>& kernel,
+    const std::vector<int>& /* dilation */,
+    const std::vector<int>& stride,
+    const std::vector<int>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  // Split each image into K segments, each CUDA block handles one segment.
+  const int ndim = X_dims.size();
+  const int Y_HxW = std::accumulate(
+      Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
+  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  switch (ndim) {
+    case 1: {
+      MaxPool1DForwardNCHWCUDAKernel<float>
+          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              K, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y);
+      return true;
+    }
+    case 2: {
+      MaxPool2DForwardNCHWCUDAKernel<float>
+          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              K,
+              X_dims[0],
+              X_dims[1],
+              Y_dims[0],
+              Y_dims[1],
+              kernel[0],
+              kernel[1],
+              stride[0],
+              stride[1],
+              pads[0],
+              pads[1],
+              X,
+              Y);
+      return true;
+    }
+    case 3: {
+      MaxPool3DForwardNCHWCUDAKernel<float>
+          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              K,
+              X_dims[0],
+              X_dims[1],
+              X_dims[2],
+              Y_dims[0],
+              Y_dims[1],
+              Y_dims[2],
+              kernel[0],
+              kernel[1],
+              kernel[2],
+              stride[0],
+              stride[1],
+              stride[2],
+              pads[0],
+              pads[1],
+              pads[2],
+              X,
+              Y);
+      return true;
+    }
+    default: {
+      CAFFE_THROW("Unsupported pooling dim: ", ndim);
+      return false;
+    }
+  }
+}
+
+template <>
+template <>
+bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const std::vector<int>& X_dims,
+    const std::vector<int>& Y_dims,
+    const std::vector<int>& kernel,
+    const std::vector<int>& /* dilation */,
+    const std::vector<int>& stride,
+    const std::vector<int>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  // Each CUDA block handles one point, one thread per channel.
+  const int ndim = X_dims.size();
+  const int Y_HxW = std::accumulate(
+      Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
+  switch (ndim) {
+    case 1: {
+      MaxPool1DForwardNHWCCUDAKernel<float>
+          <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              C, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y);
+      return true;
+    }
+    case 2: {
+      MaxPool2DForwardNHWCCUDAKernel<float>
+          <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              C,
+              X_dims[0],
+              X_dims[1],
+              Y_dims[0],
+              Y_dims[1],
+              kernel[0],
+              kernel[1],
+              stride[0],
+              stride[1],
+              pads[0],
+              pads[1],
+              X,
+              Y);
+      return true;
+    }
+    case 3: {
+      MaxPool3DForwardNHWCCUDAKernel<float>
+          <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              C,
+              X_dims[0],
+              X_dims[1],
+              X_dims[2],
+              Y_dims[0],
+              Y_dims[1],
+              Y_dims[2],
+              kernel[0],
+              kernel[1],
+              kernel[2],
+              stride[0],
+              stride[1],
+              stride[2],
+              pads[0],
+              pads[1],
+              pads[2],
+              X,
+              Y);
+      return true;
+    }
+    default: {
+      CAFFE_THROW("Unsupported pooling dim: ", ndim);
+      return false;
+    }
   }
-  return true;
 }
 
 template <>
@@ -1777,23 +1813,32 @@ REGISTER_CUDA_OPERATOR(
     AveragePool3DGradient,
     PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(MaxPool, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+    MaxPool,
+    PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPoolGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
 
-REGISTER_CUDA_OPERATOR(MaxPool1D, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+    MaxPool1D,
+    PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPool1DGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
 
-REGISTER_CUDA_OPERATOR(MaxPool2D, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+    MaxPool2D,
+    PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPool2DGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
 
-REGISTER_CUDA_OPERATOR(MaxPool3D, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+    MaxPool3D,
+    PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPool3DGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
+
 } // namespace caffe2
index a9758f4..4880a50 100644 (file)
@@ -40,6 +40,13 @@ class PoolOp final : public ConvPoolOpBase<Context> {
     const int N = X.dim32(0);
     const int C = X.dim32(1);
     ConvPoolOpBase<Context>::SetOutputSize(X, Y, C);
+    const T* X_data = X.template data<T>();
+    T* Y_data = Y->template mutable_data<T>();
+    if (global_pooling_) {
+      const int HxW = X.numel() / (N * C);
+      return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>(
+          N, C, HxW, X_data, Y_data, &context_);
+    }
     const std::vector<int> X_HW_dims = GetDims(X);
     const std::vector<int> Y_HW_dims = GetDims(*Y);
     return functor_.template Forward<T, StorageOrder::NCHW>(
@@ -63,6 +70,13 @@ class PoolOp final : public ConvPoolOpBase<Context> {
     const int N = X.dim32(0);
     const int C = X.dim32(ndim - 1);
     ConvPoolOpBase<Context>::SetOutputSize(X, Y, C);
+    const T* X_data = X.template data<T>();
+    T* Y_data = Y->template mutable_data<T>();
+    if (global_pooling_) {
+      const int HxW = X.numel() / (N * C);
+      return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>(
+          N, C, HxW, X_data, Y_data, &context_);
+    }
     const std::vector<int> X_HW_dims = GetDims(X);
     const std::vector<int> Y_HW_dims = GetDims(*Y);
     return functor_.template Forward<T, StorageOrder::NHWC>(
@@ -156,6 +170,15 @@ struct AveragePoolFunctor {
             op.template GetSingleArgument<bool>("count_include_pad", false)) {}
 
   template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingForward(
+      int N,
+      int C,
+      int HxW,
+      const T* X,
+      T* Y,
+      Context* context) const;
+
+  template <typename T, StorageOrder kOrder>
   bool Forward(
       int N,
       int C,
@@ -193,6 +216,15 @@ struct MaxPoolFunctor {
   explicit MaxPoolFunctor(const OperatorBase& /* op */) {}
 
   template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingForward(
+      int N,
+      int C,
+      int HxW,
+      const T* X,
+      T* Y,
+      Context* context) const;
+
+  template <typename T, StorageOrder kOrder>
   bool Forward(
       int N,
       int C,