Add count_include_pad to average_pool_gradient_op (#15997)
authorXiaomeng Yang <yangxm@fb.com>
Wed, 16 Jan 2019 00:44:33 +0000 (16:44 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 16 Jan 2019 00:56:40 +0000 (16:56 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15997

Add count_include_pad to average_pool_gradient_op

Reviewed By: houseroad

Differential Revision: D13648339

fbshipit-source-id: 205cb2acb32dc24a85256b628298b1a11f0ffa2c

caffe2/operators/pool_gradient_op.cc
caffe2/operators/pool_op.cu
caffe2/operators/pool_op.h
caffe2/python/operator_test/pooling_test.py

index 8fb3dff..dff8669 100644 (file)
@@ -618,6 +618,48 @@ void RunMaxPoolGradient3D(
 } // namespace
 
 template <>
+template <>
+bool AveragePoolFunctor<CPUContext>::
+    GlobalPoolingBackward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* /* X */,
+        const float* /* Y */,
+        float* dX,
+        CPUContext* /* context */) const {
+  const int NxC = N * C;
+  EigenArrayMap<float> dX_arr(dX, HxW, NxC);
+  const float scale = 1.0f / static_cast<float>(HxW);
+  for (int i = 0; i < NxC; ++i) {
+    dX_arr.col(i).setConstant(dY[i] * scale);
+  }
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CPUContext>::
+    GlobalPoolingBackward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* /* X */,
+        const float* /* Y */,
+        float* dX,
+        CPUContext* /* context */) const {
+  ConstEigenArrayMap<float> dY_arr(dY, C, N);
+  const float scale = 1.0f / static_cast<float>(HxW);
+  for (int i = 0; i < N; ++i) {
+    EigenArrayMap<float>(dX + i * HxW * C, C, HxW).colwise() =
+        dY_arr.col(i) * scale;
+  }
+  return true;
+}
+
+template <>
 template <typename T, StorageOrder kOrder>
 bool AveragePoolFunctor<CPUContext>::Backward(
     const int N,
@@ -700,6 +742,52 @@ bool AveragePoolFunctor<CPUContext>::Backward(
 }
 
 template <>
+template <>
+bool MaxPoolFunctor<CPUContext>::
+    GlobalPoolingBackward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* X,
+        const float* Y,
+        float* dX,
+        CPUContext* /* context */) const {
+  const int NxC = N * C;
+  ConstEigenArrayMap<float> X_arr(X, HxW, NxC);
+  EigenArrayMap<float> dX_arr(dX, HxW, NxC);
+  for (int i = 0; i < NxC; ++i) {
+    dX_arr.col(i) = (X_arr.col(i) == Y[i]).template cast<float>() * dY[i];
+  }
+  return true;
+}
+
+template <>
+template <>
+bool MaxPoolFunctor<CPUContext>::
+    GlobalPoolingBackward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* X,
+        const float* Y,
+        float* dX,
+        CPUContext* /* context */) const {
+  ConstEigenArrayMap<float> Y_arr(Y, C, N);
+  ConstEigenArrayMap<float> dY_arr(dY, C, N);
+  for (int i = 0; i < N; ++i) {
+    ConstEigenArrayMap<float> X_arr(X + i * HxW * C, C, HxW);
+    EigenArrayMap<float> dX_arr(dX + i * HxW * C, C, HxW);
+    for (int j = 0; j < HxW; ++j) {
+      dX_arr.col(j) =
+          (X_arr.col(j) == Y_arr.col(i)).template cast<float>() * dY_arr.col(i);
+    }
+  }
+  return true;
+}
+
+template <>
 template <typename T, StorageOrder kOrder>
 bool MaxPoolFunctor<CPUContext>::Backward(
     const int N,
index bca9e52..83f2759 100644 (file)
@@ -1,4 +1,3 @@
-// TODO(ataei): reduce the apparent redundancy of all the code below.
 #include "caffe2/operators/pool_op.h"
 
 #include <array>
@@ -13,17 +12,8 @@ namespace caffe2 {
 
 namespace {
 
-struct AveragePool {
-  explicit AveragePool(const OperatorBase& /* op */) {}
-};
-
-struct MaxPool {
-  explicit MaxPool(const OperatorBase& /* op */) {}
-};
-
 template <typename T>
 __global__ void AveragePool1DForwardNCHWCUDAKernel(
-    const int K,
     const int X_size,
     const int Y_size,
     const int kernel,
@@ -32,19 +22,21 @@ __global__ void AveragePool1DForwardNCHWCUDAKernel(
     const bool count_include_pad,
     const T* X,
     T* Y) {
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
+  const int nc = blockIdx.x;
   const T* X_ptr = X + nc * X_size;
   T* Y_ptr = Y + nc * Y_size;
-  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
-  if (y < Y_size) {
-    const int x = y * stride;
-    const int l = max(x - pad, 0);
-    const int r = min(x - pad + kernel, X_size);
+  for (int y = threadIdx.x; y < Y_size; y += blockDim.x) {
+    const int x = y * stride - pad;
+    const int l = max(x, 0);
+    const int r = min(x + kernel, X_size);
     const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
     T sum = 0;
     for (int i = l; i < r; ++i) {
+#if __CUDA_ARCH__ >= 350
+      sum += __ldg(X_ptr + i);
+#else
       sum += X_ptr[i];
+#endif
     }
     Y_ptr[y] = sum * scale;
   }
@@ -63,16 +55,20 @@ __global__ void AveragePool1DForwardNHWCCUDAKernel(
     T* Y) {
   const int n = blockIdx.x / Y_size;
   const int y = blockIdx.x % Y_size;
-  const int x = y * stride;
-  const int l = max(x - pad, 0);
-  const int r = min(x - pad + kernel, X_size);
+  const int x = y * stride - pad;
+  const int l = max(x, 0);
+  const int r = min(x + kernel, X_size);
   const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
   const T* X_ptr = X + n * X_size * C;
   T* Y_ptr = Y + n * Y_size * C;
   for (int c = threadIdx.x; c < C; c += blockDim.x) {
     T sum = 0;
     for (int i = l; i < r; ++i) {
+#if __CUDA_ARCH__ >= 350
+      sum += __ldg(X_ptr + i * C + c);
+#else
       sum += X_ptr[i * C + c];
+#endif
     }
     Y_ptr[y * C + c] = sum * scale;
   }
@@ -80,7 +76,6 @@ __global__ void AveragePool1DForwardNHWCCUDAKernel(
 
 template <typename T>
 __global__ void AveragePool2DForwardNCHWCUDAKernel(
-    const int K,
     const int X_H,
     const int X_W,
     const int Y_H,
@@ -96,30 +91,31 @@ __global__ void AveragePool2DForwardNCHWCUDAKernel(
     T* Y) {
   const int X_HxW = X_H * X_W;
   const int Y_HxW = Y_H * Y_W;
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
+  const int nc = blockIdx.x / Y_H;
+  const int yh = blockIdx.x % Y_H;
   const T* X_ptr = X + nc * X_HxW;
   T* Y_ptr = Y + nc * Y_HxW;
-  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
-  if (y < Y_HxW) {
-    const int yh = y / Y_W;
-    const int yw = y % Y_W;
-    const int xh = yh * stride_h;
-    const int xw = yw * stride_w;
-    const int t = max(xh - pad_t, 0);
-    const int b = min(xh - pad_t + kernel_h, X_H);
-    const int l = max(xw - pad_l, 0);
-    const int r = min(xw - pad_l + kernel_w, X_W);
+  const int xh = yh * stride_h - pad_t;
+  const int t = max(xh, 0);
+  const int b = min(xh + kernel_h, X_H);
+  for (int yw = threadIdx.x; yw < Y_W; yw += blockDim.x) {
+    const int xw = yw * stride_w - pad_l;
+    const int l = max(xw, 0);
+    const int r = min(xw + kernel_w, X_W);
     const T scale = T(1) /
         static_cast<T>(count_include_pad ? kernel_h * kernel_w
                                          : (b - t) * (r - l));
     T sum = 0;
     for (int i = t; i < b; ++i) {
       for (int j = l; j < r; ++j) {
+#if __CUDA_ARCH__ >= 350
+        sum += __ldg(X_ptr + i * X_W + j);
+#else
         sum += X_ptr[i * X_W + j];
+#endif
       }
     }
-    Y_ptr[y] = sum * scale;
+    Y_ptr[yh * Y_W + yw] = sum * scale;
   }
 }
 
@@ -145,12 +141,12 @@ __global__ void AveragePool2DForwardNHWCCUDAKernel(
   const int y = blockIdx.x % Y_HxW;
   const int yh = y / Y_W;
   const int yw = y % Y_W;
-  const int xh = yh * stride_h;
-  const int xw = yw * stride_w;
-  const int t = max(xh - pad_t, 0);
-  const int b = min(xh - pad_t + kernel_h, X_H);
-  const int l = max(xw - pad_l, 0);
-  const int r = min(xw - pad_l + kernel_w, X_W);
+  const int xh = yh * stride_h - pad_t;
+  const int xw = yw * stride_w - pad_l;
+  const int t = max(xh, 0);
+  const int b = min(xh + kernel_h, X_H);
+  const int l = max(xw, 0);
+  const int r = min(xw + kernel_w, X_W);
   const T scale = T(1) /
       static_cast<T>(count_include_pad ? kernel_h * kernel_w
                                        : (b - t) * (r - l));
@@ -160,7 +156,11 @@ __global__ void AveragePool2DForwardNHWCCUDAKernel(
     T sum = 0;
     for (int i = t; i < b; ++i) {
       for (int j = l; j < r; ++j) {
+#if __CUDA_ARCH__ >= 350
+        sum += __ldg(X_ptr + (i * X_W + j) * C + c);
+#else
         sum += X_ptr[(i * X_W + j) * C + c];
+#endif
       }
     }
     Y_ptr[y * C + c] = sum * scale;
@@ -169,7 +169,6 @@ __global__ void AveragePool2DForwardNHWCCUDAKernel(
 
 template <typename T>
 __global__ void AveragePool3DForwardNCHWCUDAKernel(
-    const int K,
     const int X_D,
     const int X_H,
     const int X_W,
@@ -190,25 +189,22 @@ __global__ void AveragePool3DForwardNCHWCUDAKernel(
     T* Y) {
   const int X_HxW = X_D * X_H * X_W;
   const int Y_HxW = Y_D * Y_H * Y_W;
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
+  const int yy = blockIdx.x / Y_H;
+  const int nc = yy / Y_D;
+  const int yd = yy % Y_D;
+  const int yh = blockIdx.x % Y_H;
   const T* X_ptr = X + nc * X_HxW;
   T* Y_ptr = Y + nc * Y_HxW;
-  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
-  if (y < Y_HxW) {
-    const int yy = y / Y_W;
-    const int yw = y % Y_W;
-    const int yh = yy % Y_H;
-    const int yd = yy / Y_H;
-    const int xd = yd * stride_d;
-    const int xh = yh * stride_h;
-    const int xw = yw * stride_w;
-    const int p = max(xd - pad_p, 0);
-    const int a = min(xd - pad_p + kernel_d, X_D);
-    const int t = max(xh - pad_t, 0);
-    const int b = min(xh - pad_t + kernel_h, X_H);
-    const int l = max(xw - pad_l, 0);
-    const int r = min(xw - pad_l + kernel_w, X_W);
+  const int xd = yd * stride_d - pad_p;
+  const int xh = yh * stride_h - pad_t;
+  const int p = max(xd, 0);
+  const int a = min(xd + kernel_d, X_D);
+  const int t = max(xh, 0);
+  const int b = min(xh + kernel_h, X_H);
+  for (int yw = threadIdx.x; yw < Y_W; yw += blockDim.x) {
+    const int xw = yw * stride_w - pad_l;
+    const int l = max(xw, 0);
+    const int r = min(xw + kernel_w, X_W);
     const T scale = T(1) /
         static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
                                          : (a - p) * (b - t) * (r - l));
@@ -216,11 +212,15 @@ __global__ void AveragePool3DForwardNCHWCUDAKernel(
     for (int i = p; i < a; ++i) {
       for (int j = t; j < b; ++j) {
         for (int k = l; k < r; ++k) {
+#if __CUDA_ARCH__ >= 350
+          sum += __ldg(X_ptr + (i * X_H + j) * X_W + k);
+#else
           sum += X_ptr[(i * X_H + j) * X_W + k];
+#endif
         }
       }
     }
-    Y_ptr[y] = sum * scale;
+    Y_ptr[(yd * Y_H + yh) * Y_W + yw] = sum * scale;
   }
 }
 
@@ -250,18 +250,18 @@ __global__ void AveragePool3DForwardNHWCCUDAKernel(
   const int n = blockIdx.x / Y_HxW;
   const int y = blockIdx.x % Y_HxW;
   const int yy = y / Y_W;
-  const int yw = y % Y_W;
-  const int yh = yy % Y_H;
   const int yd = yy / Y_H;
-  const int xd = yd * stride_d;
-  const int xh = yh * stride_h;
-  const int xw = yw * stride_w;
-  const int p = max(xd - pad_p, 0);
-  const int a = min(xd - pad_p + kernel_d, X_D);
-  const int t = max(xh - pad_t, 0);
-  const int b = min(xh - pad_t + kernel_h, X_H);
-  const int l = max(xw - pad_l, 0);
-  const int r = min(xw - pad_l + kernel_w, X_W);
+  const int yh = yy % Y_H;
+  const int yw = y % Y_W;
+  const int xd = yd * stride_d - pad_p;
+  const int xh = yh * stride_h - pad_t;
+  const int xw = yw * stride_w - pad_l;
+  const int p = max(xd, 0);
+  const int a = min(xd + kernel_d, X_D);
+  const int t = max(xh, 0);
+  const int b = min(xh + kernel_h, X_H);
+  const int l = max(xw, 0);
+  const int r = min(xw + kernel_w, X_W);
   const T scale = T(1) /
       static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
                                        : (a - p) * (b - t) * (r - l));
@@ -272,7 +272,11 @@ __global__ void AveragePool3DForwardNHWCCUDAKernel(
     for (int i = p; i < a; ++i) {
       for (int j = t; j < b; ++j) {
         for (int k = l; k < r; ++k) {
+#if __CUDA_ARCH__ >= 350
+          sum += __ldg(X_ptr + ((i * X_H + j) * X_W + k) * C + c);
+#else
           sum += X_ptr[((i * X_H + j) * X_W + k) * C + c];
+#endif
         }
       }
     }
@@ -281,295 +285,404 @@ __global__ void AveragePool3DForwardNHWCCUDAKernel(
 }
 
 template <typename T>
-__global__ void Ave1DPoolBackwardNCHW(
-    const int nthreads,
-    const T* const top_diff,
-    const int num,
-    const int channels,
-    const int height,
-    const int pooled_height,
-    const int kernel_h,
-    const int stride_h,
-    const int pad_t,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int h = index % height + pad_t;
-    const int c = (index / height) % channels;
-    const int n = index / height / channels;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    T gradient = 0;
-    const T* const top_diff_slice =
-        top_diff + (n * channels + c) * pooled_height;
-    for (int ph = phstart; ph < phend; ++ph) {
-      // figure out the pooling size
-      int hstart = ph * stride_h - pad_t;
-      int hend = min(hstart + kernel_h, height);
-      hstart = max(hstart, 0);
-      int pool_size = (hend - hstart);
-      gradient += top_diff_slice[ph] / pool_size;
-    }
-    bottom_diff[index] = gradient;
+__global__ void GlobalAveragePoolingBackwardNCHWCUDAKernel(
+    const int K,
+    const int HxW,
+    const T scale,
+    const T* dY,
+    T* dX) {
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const int x = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (x < HxW) {
+#if __CUDA_ARCH__ >= 350
+    dX[nc * HxW + x] = __ldg(dY + nc) * scale;
+#else
+    dX[nc * HxW + x] = dY[nc] * scale;
+#endif
   }
 }
 
 template <typename T>
-__global__ void Ave2DPoolBackwardNCHW(
-    const int nthreads,
-    const T* const top_diff,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int kernel_h,
-    const int kernel_w,
-    const int stride_h,
-    const int stride_w,
-    const int pad_t,
-    const int pad_l,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int w = index % width + pad_l;
-    const int h = (index / width) % height + pad_t;
-    const int c = (index / width / height) % channels;
-    const int n = index / width / height / channels;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    T gradient = 0;
-    const T* const top_diff_slice =
-        top_diff + (n * channels + c) * pooled_height * pooled_width;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        // figure out the pooling size
-        int hstart = ph * stride_h - pad_t;
-        int wstart = pw * stride_w - pad_l;
-        int hend = min(hstart + kernel_h, height);
-        int wend = min(wstart + kernel_w, width);
-        hstart = max(hstart, 0);
-        wstart = max(wstart, 0);
-        int pool_size = (hend - hstart) * (wend - wstart);
-        gradient += top_diff_slice[ph * pooled_width + pw] / pool_size;
+__global__ void GlobalAveragePoolingBackwardNHWCCUDAKernel(
+    const int C,
+    const int HxW,
+    const T scale,
+    const T* dY,
+    T* dX) {
+  const int n = blockIdx.x / HxW;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+#if __CUDA_ARCH__ >= 350
+    dX[blockIdx.x * C + c] = __ldg(dY + n * C + c) * scale;
+#else
+    dX[blockIdx.x * C + c] = dY[n * C + c] * scale;
+#endif
+  }
+}
+
+template <typename T, bool kCountIncludePad>
+__global__ void AveragePool1DBackwardNCHWCUDAKernel(
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const T* dY,
+    T* dX) {
+  const int nc = blockIdx.x;
+  const T* dY_ptr = dY + nc * Y_size;
+  T* dX_ptr = dX + nc * X_size;
+  for (int x = threadIdx.x; x < X_size; x += blockDim.x) {
+    const int w = x + pad;
+    const int l = w < kernel ? 0 : (w - kernel) / stride + 1;
+    const int r = min(w / stride + 1, Y_size);
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+      if (kCountIncludePad) {
+#if __CUDA_ARCH__ >= 350
+        sum += __ldg(dY_ptr + i);
+#else
+        sum += dY_ptr[i];
+#endif
+      } else {
+        const int xx = i * stride - pad;
+        const int xl = max(xx, 0);
+        const int xr = min(xx + kernel, X_size);
+#if __CUDA_ARCH__ >= 350
+        sum += __ldg(dY_ptr + i) / static_cast<T>(xr - xl);
+#else
+        sum += dY_ptr[i] / static_cast<T>(xr - xl);
+#endif
       }
     }
-    bottom_diff[index] = gradient;
+    if (kCountIncludePad) {
+      dX_ptr[x] = sum / static_cast<T>(kernel);
+    } else {
+      dX_ptr[x] = sum;
+    }
   }
 }
 
-template <typename T>
-__global__ void Ave3DPoolBackwardNCHW(
-    const int nthreads,
-    const T* const top_diff,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int depth,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+template <typename T, bool kCountIncludePad>
+__global__ void AveragePool1DBackwardNHWCCUDAKernel(
+    const int C,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const T* dY,
+    T* dX) {
+  const int n = blockIdx.x / X_size;
+  const int x = blockIdx.x % X_size;
+  const int w = x + pad;
+  const int l = w < kernel ? 0 : (w - kernel) / stride + 1;
+  const int r = min(w / stride + 1, Y_size);
+  const T scale = T(1) / static_cast<T>(kernel);
+  const T* dY_ptr = dY + n * Y_size * C;
+  T* dX_ptr = dX + n * X_size * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+      if (kCountIncludePad) {
+#if __CUDA_ARCH__ >= 350
+        sum += __ldg(dY_ptr + i * C + c);
+#else
+        sum += dY_ptr[i * C + c];
+#endif
+      } else {
+        const int xx = i * stride - pad;
+        const int xl = max(xx, 0);
+        const int xr = min(xx + kernel, X_size);
+#if __CUDA_ARCH__ >= 350
+        sum += __ldg(dY_ptr + i * C + c) / static_cast<T>(xr - xl);
+#else
+        sum += dY_ptr[i * C + c] / static_cast<T>(xr - xl);
+#endif
+      }
+    }
+    if (kCountIncludePad) {
+      dX_ptr[x * C + c] = sum * scale;
+    } else {
+      dX_ptr[x * C + c] = sum;
+    }
+  }
+}
+
+template <typename T, bool kCountIncludePad>
+__global__ void AveragePool2DBackwardNCHWCUDAKernel(
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int d = index % depth + pad_f;
-    const int w = (index / depth) % width + pad_l;
-    const int h = (index / depth / width) % height + pad_t;
-    const int c = (index / depth / width / height) % channels;
-    const int n = index / depth / width / height / channels;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    const int pdstart = (d < kernel_d) ? 0 : (d - kernel_d) / stride_d + 1;
-    const int pdend = min(d / stride_d + 1, pooled_depth);
-    T gradient = 0;
-    const T* const top_diff_slice = top_diff +
-        (n * channels + c) * pooled_height * pooled_width * pooled_depth;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        for (int pd = pdstart; pd < pdend; ++pd) {
-          // figure out the pooling size
-          int hstart = ph * stride_h - pad_t;
-          int wstart = pw * stride_w - pad_l;
-          int dstart = pd * stride_d - pad_f;
-          int hend = min(hstart + kernel_h, height);
-          int wend = min(wstart + kernel_w, width);
-          int dend = min(dstart + kernel_d, depth);
-          hstart = max(hstart, 0);
-          wstart = max(wstart, 0);
-          dstart = max(dstart, 0);
-          int pool_size = (hend - hstart) * (wend - wstart) * (dend - dstart);
-          const int pooled_index =
-              ph * pooled_depth * pooled_width + pooled_depth * pw + pd;
-          gradient += top_diff_slice[pooled_index] / pool_size;
+    const T* dY,
+    T* dX) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int nc = blockIdx.x / X_H;
+  const int hh = blockIdx.x % X_H;
+  const T* dY_ptr = dY + nc * Y_HxW;
+  T* dX_ptr = dX + nc * X_HxW;
+  const int h = hh + pad_t;
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  for (int ww = threadIdx.x; ww < X_W; ww += blockDim.x) {
+    const int w = ww + pad_l;
+    const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+    const int r = min(w / stride_w + 1, Y_W);
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        if (kCountIncludePad) {
+#if __CUDA_ARCH__ >= 350
+          sum += __ldg(dY_ptr + i * Y_W + j);
+#else
+          sum += dY_ptr[i * Y_W + j];
+#endif
+        } else {
+          const int xh = i * stride_h - pad_t;
+          const int xw = j * stride_w - pad_l;
+          const int xt = max(xh, 0);
+          const int xb = min(xh + kernel_h, X_H);
+          const int xl = max(xw, 0);
+          const int xr = min(xw + kernel_w, X_W);
+#if __CUDA_ARCH__ >= 350
+          sum += __ldg(dY_ptr + i * Y_W + j) /
+              static_cast<T>((xb - xt) * (xr - xl));
+#else
+          sum += dY_ptr[i * Y_W + j] / static_cast<T>((xb - xt) * (xr - xl));
+#endif
         }
       }
     }
-    bottom_diff[index] = gradient;
+    if (kCountIncludePad) {
+      dX_ptr[hh * X_W + ww] = sum / static_cast<T>(kernel_h * kernel_w);
+    } else {
+      dX_ptr[hh * X_W + ww] = sum;
+    }
   }
 }
 
-template <typename T>
-__global__ void Ave1DPoolBackwardNHWC(
-    const int nthreads,
-    const T* const top_diff,
-    const int num,
-    const int height,
-    const int channels,
-    const int pooled_height,
+template <typename T, bool kCountIncludePad>
+__global__ void AveragePool2DBackwardNHWCCUDAKernel(
+    const int C,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
+    const int kernel_w,
     const int stride_h,
+    const int stride_w,
     const int pad_t,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int c = index % channels;
-    const int h = (index / channels) % height + pad_t;
-    const int n = index / channels / height;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    T gradient = 0;
-    const T* const top_diff_slice = top_diff + n * pooled_height * channels + c;
-    for (int ph = phstart; ph < phend; ++ph) {
-      // figure out the pooling size
-      int hstart = ph * stride_h - pad_t;
-      int hend = min(hstart + kernel_h, height);
-      hstart = max(hstart, 0);
-      int pool_size = (hend - hstart);
-      gradient += top_diff_slice[ph * channels] / pool_size;
+    const int pad_l,
+    const T* dY,
+    T* dX) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int n = blockIdx.x / X_HxW;
+  const int x = blockIdx.x % X_HxW;
+  const int h = x / X_W + pad_t;
+  const int w = x % X_W + pad_l;
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+  const int r = min(w / stride_w + 1, Y_W);
+  const T scale = T(1) / static_cast<T>(kernel_h * kernel_w);
+  const T* dY_ptr = dY + n * Y_HxW * C;
+  T* dX_ptr = dX + n * X_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        if (kCountIncludePad) {
+#if __CUDA_ARCH__ >= 350
+          sum += __ldg(dY_ptr + (i * Y_W + j) * C + c);
+#else
+          sum += dY_ptr[(i * Y_W + j) * C + c];
+#endif
+        } else {
+          const int xh = i * stride_h - pad_t;
+          const int xw = j * stride_w - pad_l;
+          const int xt = max(xh, 0);
+          const int xb = min(xh + kernel_h, X_H);
+          const int xl = max(xw, 0);
+          const int xr = min(xw + kernel_w, X_W);
+#if __CUDA_ARCH__ >= 350
+          sum += __ldg(dY_ptr + (i * Y_W + j) * C + c) /
+              static_cast<T>((xb - xt) * (xr - xl));
+#else
+          sum += dY_ptr[(i * Y_W + j) * C + c] /
+              static_cast<T>((xb - xt) * (xr - xl));
+#endif
+        }
+      }
+    }
+    if (kCountIncludePad) {
+      dX_ptr[x * C + c] = sum * scale;
+    } else {
+      dX_ptr[x * C + c] = sum;
     }
-    bottom_diff[index] = gradient;
   }
 }
 
-template <typename T>
-__global__ void Ave2DPoolBackwardNHWC(
-    const int nthreads,
-    const T* const top_diff,
-    const int num,
-    const int height,
-    const int width,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
+template <typename T, bool kCountIncludePad>
+__global__ void AveragePool3DBackwardNCHWCUDAKernel(
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int c = index % channels;
-    const int w = index / channels % width + pad_l;
-    const int h = (index / channels / width) % height + pad_t;
-    const int n = index / channels / width / height;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    T gradient = 0;
-    const T* const top_diff_slice =
-        top_diff + n * pooled_height * pooled_width * channels + c;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        // figure out the pooling size
-        int hstart = ph * stride_h - pad_t;
-        int wstart = pw * stride_w - pad_l;
-        int hend = min(hstart + kernel_h, height);
-        int wend = min(wstart + kernel_w, width);
-        hstart = max(hstart, 0);
-        wstart = max(wstart, 0);
-        int pool_size = (hend - hstart) * (wend - wstart);
-        gradient +=
-            top_diff_slice[(ph * pooled_width + pw) * channels] / pool_size;
+    const T* dY,
+    T* dX) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int xx = blockIdx.x / X_H;
+  const int nc = xx / X_D;
+  const int dd = xx % X_D;
+  const int hh = blockIdx.x % X_H;
+  const T* dY_ptr = dY + nc * Y_HxW;
+  T* dX_ptr = dX + nc * X_HxW;
+  const int d = dd + pad_p;
+  const int h = hh + pad_t;
+  const int p = d < kernel_d ? 0 : (d - kernel_d) / stride_d + 1;
+  const int a = min(d / stride_d + 1, Y_D);
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  for (int ww = threadIdx.x; ww < X_W; ww += blockDim.x) {
+    const int w = ww + pad_l;
+    const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+    const int r = min(w / stride_w + 1, Y_W);
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          if (kCountIncludePad) {
+#if __CUDA_ARCH__ >= 350
+            sum += __ldg(dY_ptr + (i * Y_H + j) * Y_W + k);
+#else
+            sum += dY_ptr[(i * Y_H + j) * Y_W + k];
+#endif
+          } else {
+            const int xd = i * stride_d - pad_p;
+            const int xh = j * stride_h - pad_t;
+            const int xw = k * stride_w - pad_l;
+            const int xp = max(xd, 0);
+            const int xa = min(xd + kernel_d, X_D);
+            const int xt = max(xh, 0);
+            const int xb = min(xh + kernel_h, X_H);
+            const int xl = max(xw, 0);
+            const int xr = min(xw + kernel_w, X_W);
+#if __CUDA_ARCH__ >= 350
+            sum += __ldg(dY_ptr + (i * Y_H + j) * Y_W + k) /
+                static_cast<T>((xa - xp) * (xb - xt) * (xr - xl));
+#else
+            sum += dY_ptr[(i * Y_H + j) * Y_W + k] /
+                static_cast<T>((xa - xp) * (xb - xt) * (xr - xl));
+#endif
+          }
+        }
       }
     }
-    bottom_diff[index] = gradient;
+    if (kCountIncludePad) {
+      dX_ptr[(dd * X_H + hh) * X_W + ww] =
+          sum / static_cast<T>(kernel_d * kernel_h * kernel_w);
+    } else {
+      dX_ptr[(dd * X_H + hh) * X_W + ww] = sum;
+    }
   }
 }
 
-template <typename T>
-__global__ void Ave3DPoolBackwardNHWC(
-    const int nthreads,
-    const T* const top_diff,
-    const int num,
-    const int height,
-    const int width,
-    const int depth,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+template <typename T, bool kCountIncludePad>
+__global__ void AveragePool3DBackwardNHWCCUDAKernel(
+    const int C,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int c = index % channels;
-    const int d = index / channels % depth + pad_f;
-    const int w = (index / channels / depth) % width + pad_l;
-    const int h = (index / channels / depth / width) % height + pad_t;
-    const int n = index / channels / depth / width / height;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    const int pdstart = (d < kernel_d) ? 0 : (d - kernel_d) / stride_d + 1;
-    const int pdend = min(d / stride_d + 1, pooled_depth);
-    T gradient = 0;
-    const T* const top_diff_slice = top_diff +
-        n * pooled_height * pooled_width * pooled_depth * channels + c;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        for (int pd = pdstart; pd < pdend; ++pd) {
-          // figure out the pooling size
-          int hstart = ph * stride_h - pad_t;
-          int wstart = pw * stride_w - pad_l;
-          int dstart = pd * stride_d - pad_f;
-          int hend = min(hstart + kernel_h, height);
-          int wend = min(wstart + kernel_w, width);
-          int dend = min(dstart + kernel_d, depth);
-          hstart = max(hstart, 0);
-          wstart = max(wstart, 0);
-          dstart = max(dstart, 0);
-          int pool_size = (hend - hstart) * (wend - wstart) * (dend - dstart);
-          const int pooled_index =
-              (ph * pooled_depth * pooled_width + pw * pooled_depth + pd) *
-              channels;
-          gradient += top_diff_slice[pooled_index] / pool_size;
+    const T* dY,
+    T* dX) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int n = blockIdx.x / X_HxW;
+  const int x = blockIdx.x % X_HxW;
+  const int xx = x / X_W;
+  const int d = xx / X_H + pad_p;
+  const int h = xx % X_H + pad_t;
+  const int w = x % X_W + pad_l;
+  const int p = d < kernel_d ? 0 : (d - kernel_d) / stride_d + 1;
+  const int a = min(d / stride_d + 1, Y_D);
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+  const int r = min(w / stride_w + 1, Y_W);
+  const T scale = T(1) / static_cast<T>(kernel_d * kernel_h * kernel_w);
+  const T* dY_ptr = dY + n * Y_HxW * C;
+  T* dX_ptr = dX + n * X_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          if (kCountIncludePad) {
+#if __CUDA_ARCH__ >= 350
+            sum += __ldg(dY_ptr + ((i * Y_H + j) * Y_W + k) * C + c);
+#else
+            sum += dY_ptr[((i * Y_H + j) * Y_W + k) * C + c];
+#endif
+          } else {
+            const int xd = i * stride_d - pad_p;
+            const int xh = j * stride_h - pad_t;
+            const int xw = k * stride_w - pad_l;
+            const int xp = max(xd, 0);
+            const int xa = min(xd + kernel_d, X_D);
+            const int xt = max(xh, 0);
+            const int xb = min(xh + kernel_h, X_H);
+            const int xl = max(xw, 0);
+            const int xr = min(xw + kernel_w, X_W);
+#if __CUDA_ARCH__ >= 350
+            sum += __ldg(dY_ptr + ((i * Y_H + j) * Y_W + k) * C + c) /
+                static_cast<T>((xa - xp) * (xb - xt) * (xr - xl));
+#else
+            sum += dY_ptr[((i * Y_H + j) * Y_W + k) * C + c] /
+                static_cast<T>((xa - xp) * (xb - xt) * (xr - xl));
+#endif
+          }
         }
       }
     }
-    bottom_diff[index] = gradient;
+    if (kCountIncludePad) {
+      dX_ptr[x * C + c] = sum * scale;
+    } else {
+      dX_ptr[x * C + c] = sum;
+    }
   }
 }
 
@@ -612,16 +725,12 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
     const float* X,
     float* Y,
     CUDAContext* context) const {
-  // Split each image into K segments, each CUDA block handles one segment.
   const int ndim = X_dims.size();
-  const int Y_HxW = std::accumulate(
-      Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
-  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
   switch (ndim) {
     case 1: {
+      const int num_blocks = N * C;
       AveragePool1DForwardNCHWCUDAKernel<float>
-          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-              K,
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
               X_dims[0],
               Y_dims[0],
               kernel[0],
@@ -633,9 +742,9 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
       return true;
     }
     case 2: {
+      const int num_blocks = N * C * Y_dims[0];
       AveragePool2DForwardNCHWCUDAKernel<float>
-          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-              K,
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
               X_dims[0],
               X_dims[1],
               Y_dims[0],
@@ -652,9 +761,9 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
       return true;
     }
     case 3: {
+      const int num_blocks = N * C * Y_dims[0] * Y_dims[1];
       AveragePool3DForwardNCHWCUDAKernel<float>
-          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-              K,
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
               X_dims[0],
               X_dims[1],
               X_dims[2],
@@ -766,177 +875,253 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
 }
 
 template <>
-bool PoolGradientOp<float, CUDAContext, AveragePool>::
-    RunOnDeviceWithOrderNCHW() {
-  auto& X = Input(0);
-  auto& dY = Input(2);
-  CAFFE_ENFORCE_EQ(dY.dim32(1), X.dim32(1));
+template <>
+bool AveragePoolFunctor<CUDAContext>::
+    GlobalPoolingBackward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* /* X */,
+        const float* /* Y */,
+        float* dX,
+        CUDAContext* context) const {
+  const float scale = 1.0f / static_cast<float>(HxW);
+  const int K = (HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  GlobalAveragePoolingBackwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K, HxW, scale, dY, dX);
+  return true;
+}
 
-  auto* dX = Output(0, X.sizes(), at::dtype<float>());
-  vector<int> dims(X.sizes().begin() + 2, X.sizes().end());
-  ConvPoolOpBase<CUDAContext>::ComputePads(dims);
-  switch (kernel_.size()) {
-    case 1:
-      Ave1DPoolBackwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              dY.dim32(2),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              dX->template mutable_data<float>());
-      break;
-    case 2:
-      Ave2DPoolBackwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              dY.dim32(2),
-              dY.dim32(3),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              dX->template mutable_data<float>());
-      break;
-    case 3:
-      Ave3DPoolBackwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              dY.dim32(2),
-              dY.dim32(3),
-              dY.dim32(4),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              dX->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
-  }
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::
+    GlobalPoolingBackward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* /* X */,
+        const float* /* Y */,
+        float* dX,
+        CUDAContext* context) const {
+  const float scale = 1.0f / static_cast<float>(HxW);
+  GlobalAveragePoolingBackwardNHWCCUDAKernel<float>
+      <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C, HxW, scale, dY, dX);
   return true;
 }
 
+#define DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(                       \
+    cond, Func, T, num_blocks, threads_per_block, cuda_stream, ...)         \
+  do {                                                                      \
+    if (cond) {                                                             \
+      Func<T, true>                                                         \
+          <<<num_blocks, threads_per_block, 0, cuda_stream>>>(__VA_ARGS__); \
+    } else {                                                                \
+      Func<T, false>                                                        \
+          <<<num_blocks, threads_per_block, 0, cuda_stream>>>(__VA_ARGS__); \
+    }                                                                       \
+  } while (false)
+
+template <>
 template <>
-bool PoolGradientOp<float, CUDAContext, AveragePool>::
-    RunOnDeviceWithOrderNHWC() {
-  auto& X = Input(0);
-  auto& dY = Input(2);
-  CAFFE_ENFORCE_EQ(X.ndim(), dY.ndim());
-  CAFFE_ENFORCE_EQ(X.dim32(X.ndim() - 1), dY.dim32(dY.ndim() - 1));
+bool AveragePoolFunctor<CUDAContext>::Backward<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const std::vector<int>& X_dims,
+    const std::vector<int>& Y_dims,
+    const std::vector<int>& kernel,
+    const std::vector<int>& /* dilation */,
+    const std::vector<int>& stride,
+    const std::vector<int>& pads,
+    const float* dY,
+    const float* /* X */,
+    const float* /* Y */,
+    float* dX,
+    CUDAContext* context) const {
+  const int ndim = X_dims.size();
+  switch (ndim) {
+    case 1: {
+      const int num_blocks = N * C;
+      DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(
+          count_include_pad,
+          AveragePool1DBackwardNCHWCUDAKernel,
+          float,
+          num_blocks,
+          CAFFE_CUDA_NUM_THREADS,
+          context->cuda_stream(),
+          X_dims[0],
+          Y_dims[0],
+          kernel[0],
+          stride[0],
+          pads[0],
+          dY,
+          dX);
+      return true;
+    }
+    case 2: {
+      const int num_blocks = N * C * X_dims[0];
+      DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(
+          count_include_pad,
+          AveragePool2DBackwardNCHWCUDAKernel,
+          float,
+          num_blocks,
+          CAFFE_CUDA_NUM_THREADS,
+          context->cuda_stream(),
+          X_dims[0],
+          X_dims[1],
+          Y_dims[0],
+          Y_dims[1],
+          kernel[0],
+          kernel[1],
+          stride[0],
+          stride[1],
+          pads[0],
+          pads[1],
+          dY,
+          dX);
+      return true;
+    }
+    case 3: {
+      const int num_blocks = N * C * X_dims[0] * X_dims[1];
+      DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(
+          count_include_pad,
+          AveragePool3DBackwardNCHWCUDAKernel,
+          float,
+          num_blocks,
+          CAFFE_CUDA_NUM_THREADS,
+          context->cuda_stream(),
+          X_dims[0],
+          X_dims[1],
+          X_dims[2],
+          Y_dims[0],
+          Y_dims[1],
+          Y_dims[2],
+          kernel[0],
+          kernel[1],
+          kernel[2],
+          stride[0],
+          stride[1],
+          stride[2],
+          pads[0],
+          pads[1],
+          pads[2],
+          dY,
+          dX);
+      return true;
+    }
+    default: {
+      CAFFE_THROW("Unsupported pooling dim: ", ndim);
+      return false;
+    }
+  }
+}
 
-  auto* dX = Output(0, X.sizes(), at::dtype<float>());
-  vector<int> dims(X.sizes().begin() + 1, X.sizes().end() - 1);
-  ConvPoolOpBase<CUDAContext>::ComputePads(dims);
-  switch (kernel_.size()) {
-    case 1:
-      Ave1DPoolBackwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              dY.dim32(1),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              dX->template mutable_data<float>());
-      break;
-    case 2:
-      Ave2DPoolBackwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              dY.dim32(1),
-              dY.dim32(2),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              dX->template mutable_data<float>());
-      break;
-    case 3:
-      Ave3DPoolBackwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              dY.dim32(1),
-              dY.dim32(2),
-              dY.dim32(3),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              dX->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Backward<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const std::vector<int>& X_dims,
+    const std::vector<int>& Y_dims,
+    const std::vector<int>& kernel,
+    const std::vector<int>& /* dilation */,
+    const std::vector<int>& stride,
+    const std::vector<int>& pads,
+    const float* dY,
+    const float* /* X */,
+    const float* /* Y */,
+    float* dX,
+    CUDAContext* context) const {
+  const int ndim = X_dims.size();
+  const int X_HxW = std::accumulate(
+      X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
+  const int num_blocks = N * X_HxW;
+  switch (ndim) {
+    case 1: {
+      DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(
+          count_include_pad,
+          AveragePool1DBackwardNHWCCUDAKernel,
+          float,
+          num_blocks,
+          CAFFE_CUDA_NUM_THREADS,
+          context->cuda_stream(),
+          C,
+          X_dims[0],
+          Y_dims[0],
+          kernel[0],
+          stride[0],
+          pads[0],
+          dY,
+          dX);
+      return true;
+    }
+    case 2: {
+      DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(
+          count_include_pad,
+          AveragePool2DBackwardNHWCCUDAKernel,
+          float,
+          num_blocks,
+          CAFFE_CUDA_NUM_THREADS,
+          context->cuda_stream(),
+          C,
+          X_dims[0],
+          X_dims[1],
+          Y_dims[0],
+          Y_dims[1],
+          kernel[0],
+          kernel[1],
+          stride[0],
+          stride[1],
+          pads[0],
+          pads[1],
+          dY,
+          dX);
+      return true;
+    }
+    case 3: {
+      DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1(
+          count_include_pad,
+          AveragePool3DBackwardNHWCCUDAKernel,
+          float,
+          num_blocks,
+          CAFFE_CUDA_NUM_THREADS,
+          context->cuda_stream(),
+          C,
+          X_dims[0],
+          X_dims[1],
+          X_dims[2],
+          Y_dims[0],
+          Y_dims[1],
+          Y_dims[2],
+          kernel[0],
+          kernel[1],
+          kernel[2],
+          stride[0],
+          stride[1],
+          stride[2],
+          pads[0],
+          pads[1],
+          pads[2],
+          dY,
+          dX);
+      return true;
+    }
+    default: {
+      CAFFE_THROW("Unsupported pooling dim: ", ndim);
+      return false;
+    }
   }
-  return true;
 }
 
+#undef DISPATCH_KERNEL_FUNCTION_BY_BOOL_WITH_TYPE_1
+
 namespace {
 
 template <typename T>
 __global__ void MaxPool1DForwardNCHWCUDAKernel(
-    const int K,
     const int X_size,
     const int Y_size,
     const int kernel,
@@ -944,18 +1129,20 @@ __global__ void MaxPool1DForwardNCHWCUDAKernel(
     const int pad,
     const T* X,
     T* Y) {
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
+  const int nc = blockIdx.x;
   const T* X_ptr = X + nc * X_size;
   T* Y_ptr = Y + nc * Y_size;
-  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
-  if (y < Y_size) {
+  for (int y = threadIdx.x; y < Y_size; y += blockDim.x) {
     const int x = y * stride;
     const int l = max(x - pad, 0);
     const int r = min(x - pad + kernel, X_size);
     T val = std::numeric_limits<T>::lowest();
     for (int i = l; i < r; ++i) {
+#if __CUDA_ARCH__ >= 350
+      val = max(val, __ldg(X_ptr + i));
+#else
       val = max(val, X_ptr[i]);
+#endif
     }
     Y_ptr[y] = val;
   }
@@ -981,7 +1168,11 @@ __global__ void MaxPool1DForwardNHWCCUDAKernel(
   for (int c = threadIdx.x; c < C; c += blockDim.x) {
     T val = std::numeric_limits<T>::lowest();
     for (int i = l; i < r; ++i) {
+#if __CUDA_ARCH__ >= 350
+      val = max(val, __ldg(X_ptr + i * C + c));
+#else
       val = max(val, X_ptr[i * C + c]);
+#endif
     }
     Y_ptr[y * C + c] = val;
   }
@@ -989,7 +1180,6 @@ __global__ void MaxPool1DForwardNHWCCUDAKernel(
 
 template <typename T>
 __global__ void MaxPool2DForwardNCHWCUDAKernel(
-    const int K,
     const int X_H,
     const int X_W,
     const int Y_H,
@@ -1004,27 +1194,28 @@ __global__ void MaxPool2DForwardNCHWCUDAKernel(
     T* Y) {
   const int X_HxW = X_H * X_W;
   const int Y_HxW = Y_H * Y_W;
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
+  const int nc = blockIdx.x / Y_H;
+  const int yh = blockIdx.x % Y_H;
   const T* X_ptr = X + nc * X_HxW;
   T* Y_ptr = Y + nc * Y_HxW;
-  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
-  if (y < Y_HxW) {
-    const int yh = y / Y_W;
-    const int yw = y % Y_W;
-    const int xh = yh * stride_h;
+  const int xh = yh * stride_h;
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  for (int yw = threadIdx.x; yw < Y_W; yw += blockDim.x) {
     const int xw = yw * stride_w;
-    const int t = max(xh - pad_t, 0);
-    const int b = min(xh - pad_t + kernel_h, X_H);
     const int l = max(xw - pad_l, 0);
     const int r = min(xw - pad_l + kernel_w, X_W);
     T val = std::numeric_limits<T>::lowest();
     for (int i = t; i < b; ++i) {
       for (int j = l; j < r; ++j) {
+#if __CUDA_ARCH__ >= 350
+        val = max(val, __ldg(X_ptr + i * X_W + j));
+#else
         val = max(val, X_ptr[i * X_W + j]);
+#endif
       }
     }
-    Y_ptr[y] = val;
+    Y_ptr[yh * Y_W + yw] = val;
   }
 }
 
@@ -1061,7 +1252,11 @@ __global__ void MaxPool2DForwardNHWCCUDAKernel(
     T val = std::numeric_limits<T>::lowest();
     for (int i = t; i < b; ++i) {
       for (int j = l; j < r; ++j) {
+#if __CUDA_ARCH__ >= 350
+        val = max(val, __ldg(X_ptr + (i * X_W + j) * C + c));
+#else
         val = max(val, X_ptr[(i * X_W + j) * C + c]);
+#endif
       }
     }
     Y_ptr[y * C + c] = val;
@@ -1070,7 +1265,6 @@ __global__ void MaxPool2DForwardNHWCCUDAKernel(
 
 template <typename T>
 __global__ void MaxPool3DForwardNCHWCUDAKernel(
-    const int K,
     const int X_D,
     const int X_H,
     const int X_W,
@@ -1090,34 +1284,35 @@ __global__ void MaxPool3DForwardNCHWCUDAKernel(
     T* Y) {
   const int X_HxW = X_D * X_H * X_W;
   const int Y_HxW = Y_D * Y_H * Y_W;
-  const int nc = blockIdx.x / K;
-  const int block = blockIdx.x % K;
+  const int yy = blockIdx.x / Y_H;
+  const int nc = yy / Y_D;
+  const int yd = yy % Y_D;
+  const int yh = blockIdx.x % Y_H;
   const T* X_ptr = X + nc * X_HxW;
   T* Y_ptr = Y + nc * Y_HxW;
-  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
-  if (y < Y_HxW) {
-    const int yy = y / Y_W;
-    const int yw = y % Y_W;
-    const int yh = yy % Y_H;
-    const int yd = yy / Y_H;
-    const int xd = yd * stride_d;
-    const int xh = yh * stride_h;
+  const int xd = yd * stride_d;
+  const int xh = yh * stride_h;
+  const int p = max(xd - pad_p, 0);
+  const int a = min(xd - pad_p + kernel_d, X_D);
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  for (int yw = threadIdx.x; yw < Y_W; yw += blockDim.x) {
     const int xw = yw * stride_w;
-    const int p = max(xd - pad_p, 0);
-    const int a = min(xd - pad_p + kernel_d, X_D);
-    const int t = max(xh - pad_t, 0);
-    const int b = min(xh - pad_t + kernel_h, X_H);
     const int l = max(xw - pad_l, 0);
     const int r = min(xw - pad_l + kernel_w, X_W);
     T val = std::numeric_limits<T>::lowest();
     for (int i = p; i < a; ++i) {
       for (int j = t; j < b; ++j) {
         for (int k = l; k < r; ++k) {
+#if __CUDA_ARCH__ >= 350
+          val = max(val, __ldg(X_ptr + (i * X_H + j) * X_W + k));
+#else
           val = max(val, X_ptr[(i * X_H + j) * X_W + k]);
+#endif
         }
       }
     }
-    Y_ptr[y] = val;
+    Y_ptr[(yd * Y_H + yh) * Y_W + yw] = val;
   }
 }
 
@@ -1165,7 +1360,11 @@ __global__ void MaxPool3DForwardNHWCCUDAKernel(
     for (int i = p; i < a; ++i) {
       for (int j = t; j < b; ++j) {
         for (int k = l; k < r; ++k) {
+#if __CUDA_ARCH__ >= 350
+          val = max(val, __ldg(X_ptr + ((i * X_H + j) * X_W + k) * C + c));
+#else
           val = max(val, X_ptr[((i * X_H + j) * X_W + k) * C + c]);
+#endif
         }
       }
     }
@@ -1174,264 +1373,346 @@ __global__ void MaxPool3DForwardNHWCCUDAKernel(
 }
 
 template <typename T>
-__global__ void MaxPool1DBackwardNCHW(
-    const int nthreads,
-    const T* const bottom_data,
-    const T* const top_data,
-    const T* const top_diff,
-    const int num,
-    const int channels,
-    const int height,
-    const int pooled_height,
-    const int kernel_h,
-    const int stride_h,
-    const int pad_t,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int h = index % height + pad_t;
-    const int c = (index / height) % channels;
-    const int n = index / height / channels;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int top_offset = (n * channels + c) * pooled_height;
-    bottom_diff[index] = 0;
-    for (int ph = phstart; ph < phend; ++ph) {
-      int top_local_offset = top_offset + ph;
-      if (bottom_data[index] == top_data[top_local_offset]) {
-        bottom_diff[index] += top_diff[top_local_offset];
+__global__ void GlobalMaxPoolingBackwardNCHWCUDAKernel(
+    const int K,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const int x = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (x < HxW) {
+#if __CUDA_ARCH__ >= 350
+    dX[nc * HxW + x] =
+        (__ldg(X + nc * HxW + x) == __ldg(Y + nc)) ? __ldg(dY + nc) : T(0);
+#else
+    dX[nc * HxW + x] = (X[nc * HxW + x] == Y[nc]) ? dY[nc] : T(0);
+#endif
+  }
+}
+
+template <typename T>
+__global__ void GlobalMaxPoolingBackwardNHWCCUDAKernel(
+    const int C,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int n = blockIdx.x / HxW;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+#if __CUDA_ARCH__ >= 350
+    dX[blockIdx.x * C + c] =
+        (__ldg(X + blockIdx.x * C + c) == __ldg(Y + n * C + c))
+        ? __ldg(dY + n * C + c)
+        : T(0);
+#else
+    dX[blockIdx.x * C + c] =
+        (X[blockIdx.x * C + c] == Y[n * C + c]) ? dY[n * C + c] : T(0);
+#endif
+  }
+}
+
+template <typename T>
+__global__ void MaxPool1DBackwardNCHWCUDAKernel(
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int nc = blockIdx.x;
+  const T* dY_ptr = dY + nc * Y_size;
+  const T* X_ptr = X + nc * X_size;
+  const T* Y_ptr = Y + nc * Y_size;
+  T* dX_ptr = dX + nc * X_size;
+  for (int x = threadIdx.x; x < X_size; x += blockDim.x) {
+    const int w = x + pad;
+    const int l = w < kernel ? 0 : (w - kernel) / stride + 1;
+    const int r = min(w / stride + 1, Y_size);
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+#if __CUDA_ARCH__ >= 350
+      if (__ldg(X_ptr + x) == __ldg(Y_ptr + i)) {
+        sum += __ldg(dY_ptr + i);
       }
+#else
+      if (X_ptr[x] == Y_ptr[i]) {
+        sum += dY_ptr[i];
+      }
+#endif
     }
+    dX_ptr[x] = sum;
   }
 }
 
 template <typename T>
-__global__ void MaxPool2DBackwardNCHW(
-    const int nthreads,
-    const T* const bottom_data,
-    const T* const top_data,
-    const T* const top_diff,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int kernel_h,
-    const int kernel_w,
-    const int stride_h,
-    const int stride_w,
-    const int pad_t,
-    const int pad_l,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int w = index % width + pad_l;
-    const int h = (index / width) % height + pad_t;
-    const int c = (index / width / height) % channels;
-    const int n = index / width / height / channels;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    const int top_offset = (n * channels + c) * pooled_height * pooled_width;
-    bottom_diff[index] = 0;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        int top_local_offset = top_offset + ph * pooled_width + pw;
-        if (bottom_data[index] == top_data[top_local_offset]) {
-          bottom_diff[index] += top_diff[top_local_offset];
-        }
+__global__ void MaxPool1DBackwardNHWCCUDAKernel(
+    const int C,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int n = blockIdx.x / X_size;
+  const int x = blockIdx.x % X_size;
+  const int w = x + pad;
+  const int l = w < kernel ? 0 : (w - kernel) / stride + 1;
+  const int r = min(w / stride + 1, Y_size);
+  const T* dY_ptr = dY + n * Y_size * C;
+  const T* X_ptr = X + n * X_size * C;
+  const T* Y_ptr = Y + n * Y_size * C;
+  T* dX_ptr = dX + n * X_size * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+#if __CUDA_ARCH__ >= 350
+      if (__ldg(X_ptr + x * C + c) == __ldg(Y_ptr + i * C + c)) {
+        sum += __ldg(dY_ptr + i * C + c);
+      }
+#else
+      if (X_ptr[x * C + c] == Y_ptr[i * C + c]) {
+        sum += dY_ptr[i * C + c];
       }
+#endif
     }
+    dX_ptr[x * C + c] = sum;
   }
 }
 
 template <typename T>
-__global__ void MaxPool3DBackwardNCHW(
-    const int nthreads,
-    const T* const bottom_data,
-    const T* const top_data,
-    const T* const top_diff,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int depth,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void MaxPool2DBackwardNCHWCUDAKernel(
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int d = index % depth + pad_f;
-    const int w = (index / depth) % width + pad_l;
-    const int h = (index / depth / width) % height + pad_t;
-    const int c = (index / depth / width / height) % channels;
-    const int n = index / depth / width / height / channels;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    const int pdstart = (d < kernel_d) ? 0 : (d - kernel_d) / stride_d + 1;
-    const int pdend = min(d / stride_d + 1, pooled_depth);
-    const int top_offset =
-        (n * channels + c) * pooled_height * pooled_width * pooled_depth;
-    bottom_diff[index] = 0;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        for (int pd = pdstart; pd < pdend; ++pd) {
-          int top_local_offset =
-              top_offset + (ph * pooled_width + pw) * pooled_depth + pd;
-          if (bottom_data[index] == top_data[top_local_offset]) {
-            bottom_diff[index] += top_diff[top_local_offset];
-          }
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int nc = blockIdx.x / X_H;
+  const int xh = blockIdx.x % X_H;
+  const T* dY_ptr = dY + nc * Y_HxW;
+  const T* X_ptr = X + nc * X_HxW;
+  const T* Y_ptr = Y + nc * Y_HxW;
+  T* dX_ptr = dX + nc * X_HxW;
+  const int h = xh + pad_t;
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  for (int xw = threadIdx.x; xw < X_W; xw += blockDim.x) {
+    const int w = xw + pad_l;
+    const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+    const int r = min(w / stride_w + 1, Y_W);
+    const int x = xh * X_W + xw;
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        const int y = i * Y_W + j;
+#if __CUDA_ARCH__ >= 350
+        if (__ldg(X_ptr + x) == __ldg(Y_ptr + y)) {
+          sum += __ldg(dY_ptr + y);
         }
+#else
+        if (X_ptr[x] == Y_ptr[y]) {
+          sum += dY_ptr[y];
+        }
+#endif
       }
     }
+    dX_ptr[x] = sum;
   }
 }
 
 template <typename T>
-__global__ void MaxPool1DBackwardNHWC(
-    const int nthreads,
-    const T* const bottom_data,
-    const T* const top_data,
-    const T* const top_diff,
-    const int num,
-    const int height,
-    const int channels,
-    const int pooled_height,
+__global__ void MaxPool2DBackwardNHWCCUDAKernel(
+    const int C,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
+    const int kernel_w,
     const int stride_h,
+    const int stride_w,
     const int pad_t,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int c = index % channels;
-    const int h = (index / channels) % height + pad_t;
-    const int n = index / channels / height;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int top_offset = n * pooled_height * channels + c;
-    bottom_diff[index] = 0;
-    for (int ph = phstart; ph < phend; ++ph) {
-      int top_local_offset = top_offset + ph * channels;
-      if (bottom_data[index] == top_data[top_local_offset]) {
-        bottom_diff[index] += top_diff[top_local_offset];
+    const int pad_l,
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int n = blockIdx.x / X_HxW;
+  const int x = blockIdx.x % X_HxW;
+  const int h = x / X_W + pad_t;
+  const int w = x % X_W + pad_l;
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+  const int r = min(w / stride_w + 1, Y_W);
+  const T* dY_ptr = dY + n * Y_HxW * C;
+  const T* X_ptr = X + n * X_HxW * C;
+  const T* Y_ptr = Y + n * Y_HxW * C;
+  T* dX_ptr = dX + n * X_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        const int y = i * Y_W + j;
+#if __CUDA_ARCH__ >= 350
+        if (__ldg(X_ptr + x * C + c) == __ldg(Y_ptr + y * C + c)) {
+          sum += __ldg(dY_ptr + y * C + c);
+        }
+#else
+        if (X_ptr[x * C + c] == Y_ptr[y * C + c]) {
+          sum += dY_ptr[y * C + c];
+        }
+#endif
       }
     }
+    dX_ptr[x * C + c] = sum;
   }
 }
 
 template <typename T>
-__global__ void MaxPool2DBackwardNHWC(
-    const int nthreads,
-    const T* const bottom_data,
-    const T* const top_data,
-    const T* const top_diff,
-    const int num,
-    const int height,
-    const int width,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
+__global__ void MaxPool3DBackwardNCHWCUDAKernel(
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int c = index % channels;
-    const int w = index / channels % width + pad_l;
-    const int h = (index / channels / width) % height + pad_t;
-    const int n = index / channels / width / height;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    const int top_offset = n * pooled_height * pooled_width * channels + c;
-    bottom_diff[index] = 0;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        int top_local_offset = top_offset + (ph * pooled_width + pw) * channels;
-        if (bottom_data[index] == top_data[top_local_offset]) {
-          bottom_diff[index] += top_diff[top_local_offset];
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int xx = blockIdx.x / X_H;
+  const int nc = xx / X_D;
+  const int xd = xx % X_D;
+  const int xh = blockIdx.x % X_H;
+  const T* dY_ptr = dY + nc * Y_HxW;
+  const T* X_ptr = X + nc * X_HxW;
+  const T* Y_ptr = Y + nc * Y_HxW;
+  T* dX_ptr = dX + nc * X_HxW;
+  const int d = xd + pad_p;
+  const int h = xh + pad_t;
+  const int p = d < kernel_d ? 0 : (d - kernel_d) / stride_d + 1;
+  const int a = min(d / stride_d + 1, Y_D);
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  for (int xw = threadIdx.x; xw < X_W; xw += blockDim.x) {
+    const int w = xw + pad_l;
+    const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+    const int r = min(w / stride_w + 1, Y_W);
+    const int x = (xd * X_H + xh) * X_W + xw;
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          const int y = (i * Y_H + j) * Y_W + k;
+#if __CUDA_ARCH__ >= 350
+          if (__ldg(X_ptr + x) == __ldg(Y_ptr + y)) {
+            sum += __ldg(dY_ptr + y);
+          }
+#else
+          if (X_ptr[x] == Y_ptr[y]) {
+            sum += dY_ptr[y];
+          }
+#endif
         }
       }
     }
+    dX_ptr[x] = sum;
   }
 }
 
 template <typename T>
-__global__ void MaxPool3DBackwardNHWC(
-    const int nthreads,
-    const T* const bottom_data,
-    const T* const top_data,
-    const T* const top_diff,
-    const int num,
-    const int height,
-    const int width,
-    const int depth,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void MaxPool3DBackwardNHWCCUDAKernel(
+    const int C,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* const bottom_diff) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // find out the local index
-    // find out the local offset
-    const int c = index % channels;
-    const int d = index / channels % depth + pad_f;
-    const int w = (index / depth / channels) % width + pad_l;
-    const int h = (index / channels / depth / width) % height + pad_t;
-    const int n = index / channels / depth / width / height;
-    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
-    const int phend = min(h / stride_h + 1, pooled_height);
-    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
-    const int pwend = min(w / stride_w + 1, pooled_width);
-    const int pdstart = (d < kernel_d) ? 0 : (d - kernel_d) / stride_d + 1;
-    const int pdend = min(d / stride_d + 1, pooled_depth);
-    const int top_offset =
-        n * pooled_height * pooled_width * pooled_depth * channels + c;
-    bottom_diff[index] = 0;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        for (int pd = pdstart; pd < pdend; ++pd) {
-          int top_local_offset = top_offset +
-              ((ph * pooled_width + pw) * pooled_depth + d) * channels;
-          if (bottom_data[index] == top_data[top_local_offset]) {
-            bottom_diff[index] += top_diff[top_local_offset];
+    const T* dY,
+    const T* X,
+    const T* Y,
+    T* dX) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int n = blockIdx.x / X_HxW;
+  const int x = blockIdx.x % X_HxW;
+  const int xx = x / X_W;
+  const int d = xx / X_H + pad_p;
+  const int h = xx % X_H + pad_t;
+  const int w = x % X_W + pad_l;
+  const int p = d < kernel_d ? 0 : (d - kernel_d) / stride_d + 1;
+  const int a = min(d / stride_d + 1, Y_D);
+  const int t = h < kernel_h ? 0 : (h - kernel_h) / stride_h + 1;
+  const int b = min(h / stride_h + 1, Y_H);
+  const int l = w < kernel_w ? 0 : (w - kernel_w) / stride_w + 1;
+  const int r = min(w / stride_w + 1, Y_W);
+  const T* dY_ptr = dY + n * Y_HxW * C;
+  const T* X_ptr = X + n * X_HxW * C;
+  const T* Y_ptr = Y + n * Y_HxW * C;
+  T* dX_ptr = dX + n * X_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          const int y = (i * Y_H + j) * Y_W + k;
+#if __CUDA_ARCH__ >= 350
+          if (__ldg(X_ptr + x * C + c) == __ldg(Y_ptr + y * C + c)) {
+            sum += __ldg(dY_ptr + y * C + c);
+          }
+#else
+          if (X_ptr[x * C + c] == Y_ptr[y * C + c]) {
+            sum += dY_ptr[y * C + c];
           }
+#endif
         }
       }
     }
+    dX_ptr[x * C + c] = sum;
   }
 }
 
@@ -1474,22 +1755,19 @@ bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
     const float* X,
     float* Y,
     CUDAContext* context) const {
-  // Split each image into K segments, each CUDA block handles one segment.
   const int ndim = X_dims.size();
-  const int Y_HxW = std::accumulate(
-      Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
-  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
   switch (ndim) {
     case 1: {
+      const int num_blocks = N * C;
       MaxPool1DForwardNCHWCUDAKernel<float>
-          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-              K, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y);
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y);
       return true;
     }
     case 2: {
+      const int num_blocks = N * C * Y_dims[0];
       MaxPool2DForwardNCHWCUDAKernel<float>
-          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-              K,
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
               X_dims[0],
               X_dims[1],
               Y_dims[0],
@@ -1505,9 +1783,9 @@ bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
       return true;
     }
     case 3: {
+      const int num_blocks = N * C * Y_dims[0] * Y_dims[1];
       MaxPool3DForwardNCHWCUDAKernel<float>
-          <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-              K,
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
               X_dims[0],
               X_dims[1],
               X_dims[2],
@@ -1608,181 +1886,212 @@ bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
 }
 
 template <>
-bool PoolGradientOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
-  auto& X = Input(0);
-  auto& Y = Input(1);
-  auto& dY = Input(2);
-  CAFFE_ENFORCE_EQ(dY.ndim(), X.ndim());
+template <>
+bool MaxPoolFunctor<CUDAContext>::
+    GlobalPoolingBackward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* X,
+        const float* Y,
+        float* dX,
+        CUDAContext* context) const {
+  const int K = (HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  GlobalMaxPoolingBackwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K, HxW, dY, X, Y, dX);
+  return true;
+}
 
-  auto* dX = Output(0, X.sizes(), at::dtype<float>());
-  vector<int> dims(X.sizes().begin() + 2, X.sizes().end());
-  ConvPoolOpBase<CUDAContext>::ComputePads(dims);
-  switch (kernel_.size()) {
-    case 1:
-      MaxPool1DBackwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              X.data<float>(),
-              Y.data<float>(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              dY.dim32(2),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              dX->template mutable_data<float>());
-      break;
-    case 2:
-      MaxPool2DBackwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              X.data<float>(),
-              Y.data<float>(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              dY.dim32(2),
-              dY.dim32(3),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              dX->template mutable_data<float>());
-      break;
-    case 3:
-      MaxPool3DBackwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              X.data<float>(),
-              Y.data<float>(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              dY.dim32(2),
-              dY.dim32(3),
-              dY.dim32(4),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              dX->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
-  }
+template <>
+template <>
+bool MaxPoolFunctor<CUDAContext>::
+    GlobalPoolingBackward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* dY,
+        const float* X,
+        const float* Y,
+        float* dX,
+        CUDAContext* context) const {
+  GlobalMaxPoolingBackwardNHWCCUDAKernel<float>
+      <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C, HxW, dY, X, Y, dX);
   return true;
 }
 
 template <>
-bool PoolGradientOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNHWC() {
-  auto& X = Input(0);
-  auto& Y = Input(1);
-  auto& dY = Input(2);
-  CAFFE_ENFORCE_EQ(dY.ndim(), X.ndim());
+template <>
+bool MaxPoolFunctor<CUDAContext>::Backward<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const std::vector<int>& X_dims,
+    const std::vector<int>& Y_dims,
+    const std::vector<int>& kernel,
+    const std::vector<int>& /* dilation */,
+    const std::vector<int>& stride,
+    const std::vector<int>& pads,
+    const float* dY,
+    const float* X,
+    const float* Y,
+    float* dX,
+    CUDAContext* context) const {
+  const int ndim = X_dims.size();
+  switch (ndim) {
+    case 1: {
+      const int num_blocks = N * C;
+      MaxPool1DBackwardNCHWCUDAKernel<float>
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              X_dims[0],
+              Y_dims[0],
+              kernel[0],
+              stride[0],
+              pads[0],
+              dY,
+              X,
+              Y,
+              dX);
+      return true;
+    }
+    case 2: {
+      const int num_blocks = N * C * X_dims[0];
+      MaxPool2DBackwardNCHWCUDAKernel<float>
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              X_dims[0],
+              X_dims[1],
+              Y_dims[0],
+              Y_dims[1],
+              kernel[0],
+              kernel[1],
+              stride[0],
+              stride[1],
+              pads[0],
+              pads[1],
+              dY,
+              X,
+              Y,
+              dX);
+      return true;
+    }
+    case 3: {
+      const int num_blocks = N * C * X_dims[0] * X_dims[1];
+      MaxPool3DBackwardNCHWCUDAKernel<float>
+          <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              X_dims[0],
+              X_dims[1],
+              X_dims[2],
+              Y_dims[0],
+              Y_dims[1],
+              Y_dims[2],
+              kernel[0],
+              kernel[1],
+              kernel[2],
+              stride[0],
+              stride[1],
+              stride[2],
+              pads[0],
+              pads[1],
+              pads[2],
+              dY,
+              X,
+              Y,
+              dX);
+      return true;
+    }
+    default: {
+      CAFFE_THROW("Unsupported pooling dim: ", ndim);
+      return false;
+    }
+  }
+}
 
-  auto* dX = Output(0, X.sizes(), at::dtype<float>());
-  vector<int> dims(X.sizes().begin() + 1, X.sizes().end() - 1);
-  ConvPoolOpBase<CUDAContext>::ComputePads(dims);
-  switch (kernel_.size()) {
-    case 1:
-      MaxPool1DBackwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              X.data<float>(),
-              Y.data<float>(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              dY.dim32(1),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              dX->template mutable_data<float>());
-      break;
-    case 2:
-      MaxPool2DBackwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              X.data<float>(),
-              Y.data<float>(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              dY.dim32(1),
-              dY.dim32(2),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              dX->template mutable_data<float>());
-      break;
-    case 3:
-      MaxPool3DBackwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(X.size()),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              X.size(),
-              X.data<float>(),
-              Y.data<float>(),
-              dY.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              dY.dim32(1),
-              dY.dim32(2),
-              dY.dim32(3),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              dX->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
+template <>
+template <>
+bool MaxPoolFunctor<CUDAContext>::Backward<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const std::vector<int>& X_dims,
+    const std::vector<int>& Y_dims,
+    const std::vector<int>& kernel,
+    const std::vector<int>& /* dilation */,
+    const std::vector<int>& stride,
+    const std::vector<int>& pads,
+    const float* dY,
+    const float* X,
+    const float* Y,
+    float* dX,
+    CUDAContext* context) const {
+  const int ndim = X_dims.size();
+  const int X_HxW = std::accumulate(
+      X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
+  switch (ndim) {
+    case 1: {
+      MaxPool1DBackwardNHWCCUDAKernel<float>
+          <<<N * X_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              C,
+              X_dims[0],
+              Y_dims[0],
+              kernel[0],
+              stride[0],
+              pads[0],
+              dY,
+              X,
+              Y,
+              dX);
+      return true;
+    }
+    case 2: {
+      MaxPool2DBackwardNHWCCUDAKernel<float>
+          <<<N * X_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              C,
+              X_dims[0],
+              X_dims[1],
+              Y_dims[0],
+              Y_dims[1],
+              kernel[0],
+              kernel[1],
+              stride[0],
+              stride[1],
+              pads[0],
+              pads[1],
+              dY,
+              X,
+              Y,
+              dX);
+      return true;
+    }
+    case 3: {
+      MaxPool3DBackwardNHWCCUDAKernel<float>
+          <<<N * X_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+              C,
+              X_dims[0],
+              X_dims[1],
+              X_dims[2],
+              Y_dims[0],
+              Y_dims[1],
+              Y_dims[2],
+              kernel[0],
+              kernel[1],
+              kernel[2],
+              stride[0],
+              stride[1],
+              stride[2],
+              pads[0],
+              pads[1],
+              pads[2],
+              dY,
+              X,
+              Y,
+              dX);
+      return true;
+    }
+    default: {
+      CAFFE_THROW("Unsupported pooling dim: ", ndim);
+      return false;
+    }
   }
-  return true;
 }
 
 REGISTER_CUDA_OPERATOR(
@@ -1790,55 +2099,55 @@ REGISTER_CUDA_OPERATOR(
     PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePoolGradient,
-    PoolGradientOp<float, CUDAContext, AveragePool>);
+    PoolGradientOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     AveragePool1D,
     PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool1DGradient,
-    PoolGradientOp<float, CUDAContext, AveragePool>);
+    PoolGradientOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     AveragePool2D,
     PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool2DGradient,
-    PoolGradientOp<float, CUDAContext, AveragePool>);
+    PoolGradientOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     AveragePool3D,
     PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool3DGradient,
-    PoolGradientOp<float, CUDAContext, AveragePool>);
+    PoolGradientOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     MaxPool,
     PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPoolGradient,
-    PoolGradientOp<float, CUDAContext, MaxPool>);
+    PoolGradientOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     MaxPool1D,
     PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPool1DGradient,
-    PoolGradientOp<float, CUDAContext, MaxPool>);
+    PoolGradientOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     MaxPool2D,
     PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPool2DGradient,
-    PoolGradientOp<float, CUDAContext, MaxPool>);
+    PoolGradientOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 
 REGISTER_CUDA_OPERATOR(
     MaxPool3D,
     PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     MaxPool3DGradient,
-    PoolGradientOp<float, CUDAContext, MaxPool>);
+    PoolGradientOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
 
 } // namespace caffe2
index 4880a50..8c9db86 100644 (file)
@@ -115,7 +115,16 @@ class PoolGradientOp final : public ConvPoolOpBase<Context> {
     const int C = X.dim32(1);
     const std::vector<int> X_HW_dims = GetDims(X);
     const std::vector<int> Y_HW_dims = GetDims(Y);
-    ConvPoolOpBase<CPUContext>::ComputePads(X_HW_dims);
+    ConvPoolOpBase<Context>::ComputePads(X_HW_dims);
+    const T* dY_data = dY.template data<T>();
+    const T* X_data = X.template data<T>();
+    const T* Y_data = Y.template data<T>();
+    T* dX_data = dX->template mutable_data<T>();
+    if (global_pooling_) {
+      const int HxW = X.numel() / (N * C);
+      return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>(
+          N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
+    }
     return functor_.template Backward<T, StorageOrder::NCHW>(
         N,
         C,
@@ -125,10 +134,10 @@ class PoolGradientOp final : public ConvPoolOpBase<Context> {
         dilation_,
         stride_,
         pads_,
-        dY.template data<T>(),
-        X.template data<T>(),
-        Y.template data<T>(),
-        dX->template mutable_data<T>(),
+        dY_data,
+        X_data,
+        Y_data,
+        dX_data,
         &context_);
   }
 
@@ -142,7 +151,16 @@ class PoolGradientOp final : public ConvPoolOpBase<Context> {
     const int C = X.dim32(ndim - 1);
     const std::vector<int> X_HW_dims = GetDims(X);
     const std::vector<int> Y_HW_dims = GetDims(Y);
-    ConvPoolOpBase<CPUContext>::ComputePads(X_HW_dims);
+    ConvPoolOpBase<Context>::ComputePads(X_HW_dims);
+    const T* dY_data = dY.template data<T>();
+    const T* X_data = X.template data<T>();
+    const T* Y_data = Y.template data<T>();
+    T* dX_data = dX->template mutable_data<T>();
+    if (global_pooling_) {
+      const int HxW = X.numel() / (N * C);
+      return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>(
+          N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
+    }
     return functor_.template Backward<T, StorageOrder::NHWC>(
         N,
         C,
@@ -152,10 +170,10 @@ class PoolGradientOp final : public ConvPoolOpBase<Context> {
         dilation_,
         stride_,
         pads_,
-        dY.template data<T>(),
-        X.template data<T>(),
-        Y.template data<T>(),
-        dX->template mutable_data<T>(),
+        dY_data,
+        X_data,
+        Y_data,
+        dX_data,
         &context_);
   }
 
@@ -193,6 +211,17 @@ struct AveragePoolFunctor {
       Context* context) const;
 
   template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingBackward(
+      int N,
+      int C,
+      int HxW,
+      const T* dY,
+      const T* X,
+      const T* Y,
+      T* dX,
+      Context* context) const;
+
+  template <typename T, StorageOrder kOrder>
   bool Backward(
       int N,
       int C,
@@ -239,6 +268,17 @@ struct MaxPoolFunctor {
       Context* context) const;
 
   template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingBackward(
+      int N,
+      int C,
+      int HxW,
+      const T* dY,
+      const T* X,
+      const T* Y,
+      T* dX,
+      Context* context) const;
+
+  template <typename T, StorageOrder kOrder>
   bool Backward(
       int N,
       int C,
index 188006b..2d0ee30 100644 (file)
@@ -335,6 +335,69 @@ class TestPooling(hu.HypothesisTestCase):
         if 'MaxPool' not in op_type:
             self.assertGradientChecks(gc, op, [X], 0, [0])
 
+    @given(op_type=st.sampled_from(["MaxPool", "MaxPoolND"]),
+           dim=st.integers(1, 3),
+           N=st.integers(1, 3),
+           C=st.integers(1, 3),
+           D=st.integers(3, 5),
+           H=st.integers(3, 5),
+           W=st.integers(3, 5),
+           kernel=st.integers(1, 3),
+           stride=st.integers(1, 3),
+           pad=st.integers(0, 2),
+           order=st.sampled_from(["NCHW", "NHWC"]),
+           engine=st.sampled_from(["", "CUDNN"]),
+           **hu.gcs)
+    def test_max_pool_grad(
+            self, op_type, dim, N, C, D, H, W, kernel, stride, pad, order,
+            engine, gc, dc):
+        assume(pad < kernel)
+        assume(dim > 1 or engine == "")
+        if hiputl.run_in_hip(gc, dc):
+            if dim != 2:
+                assume(engine != "CUDNN")
+            elif engine == "CUDNN":
+                assume(order == "NCHW")
+
+        if op_type.endswith("ND"):
+            op_type = op_type.replace("N", str(dim))
+
+        op = core.CreateOperator(
+            op_type,
+            ["X"],
+            ["Y"],
+            kernels=[kernel] * dim,
+            strides=[stride] * dim,
+            pads=[pad] * dim * 2,
+            order=order,
+            engine=engine,
+        )
+
+        if dim == 1:
+            size = W
+            dims = [N, C, W]
+            axes = [0, 2, 1]
+        elif dim == 2:
+            size = H * W
+            dims = [N, C, H, W]
+            axes = [0, 2, 3, 1]
+        else:
+            size = D * H * W
+            dims = [N, C, D, H, W]
+            axes = [0, 2, 3, 4, 1]
+
+        X = np.zeros((N * C, size)).astype(np.float32)
+        for i in range(N * C):
+            X[i, :] = np.arange(size, dtype=np.float32) / size
+            np.random.shuffle(X[i, :])
+        X = X.reshape(dims)
+        if order == "NHWC":
+            X = np.transpose(X, axes)
+
+        self.assertDeviceChecks(dc, op, [X], [0])
+        self.assertGradientChecks(
+            gc, op, [X], 0, [0], threshold=5e-2, stepsize=1e-3)
+
 
 if __name__ == "__main__":
     import unittest