From 586d0303113123f59c1992c57714bec60b90fd76 Mon Sep 17 00:00:00 2001
From: Xiaomeng Yang
Date: Fri, 11 Jan 2019 22:35:12 -0800
Subject: [PATCH] Add global pooling specialization and also update MaxPooling
 on GPU (#15824)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15824

Add global pooling specialization and also update MaxPooling on GPU

Reviewed By: houseroad

Differential Revision: D13596340

fbshipit-source-id: c8a42aa69ee92c383c9f19d3ed57b77cb3e5bd28
---
 caffe2/operators/pool_op.cc |  86 ++++++
 caffe2/operators/pool_op.cu | 725 ++++++++++++++++++++++----------------------
 caffe2/operators/pool_op.h  |  32 ++
 3 files changed, 503 insertions(+), 340 deletions(-)

diff --git a/caffe2/operators/pool_op.cc b/caffe2/operators/pool_op.cc
index 0c58431..8f5d522 100644
--- a/caffe2/operators/pool_op.cc
+++ b/caffe2/operators/pool_op.cc
@@ -6,6 +6,7 @@
 
 #include "caffe2/operators/pool_op_util.h"
 #include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
@@ -561,6 +562,48 @@ void RunMaxPool3D(
 
 } // namespace
 
+template <>
+template <>
+bool AveragePoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  const std::array<int, 2> dims = {N * C, HxW};
+  const int axis = 1;
+  math::ReduceMean<float, CPUContext>(
+      2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  math::Set<float, CPUContext>(N * C, 0.0f, Y, context);
+  const float* X_ptr = X;
+  float* Y_ptr = Y;
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < HxW; ++j) {
+      math::Add<float, CPUContext>(C, Y_ptr, X_ptr + j * C, Y_ptr, context);
+    }
+    X_ptr += HxW * C;
+    Y_ptr += C;
+  }
+  math::Scale<float, float, CPUContext>(
+      N * C, 1.0f / static_cast<float>(HxW), Y, Y, context);
+  return true;
+}
+
 #define CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD(T, kOrder) \
   template <>                                                      \
   template <>                                                      \
@@ -667,6 +710,49 @@ CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD(float, StorageOrder::NCHW)
 CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD(float, StorageOrder::NHWC)
 #undef CAFFE2_SPECIALIZED_AVERAGE_POOL_FUNCTOR_FORWARD
 
+template <>
+template <>
+bool MaxPoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  const std::array<int, 2> dims = {N * C, HxW};
+  const int axis = 1;
+  math::ReduceMax<float, CPUContext>(
+      2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  return true;
+}
+
+template <>
+template <>
+bool MaxPoolFunctor<CPUContext>::
+    GlobalPoolingForward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CPUContext* context) const {
+  math::Set<float, CPUContext>(
+      N * C, std::numeric_limits<float>::lowest(), Y, context);
+  const float* X_ptr = X;
+  float* Y_ptr = Y;
+  for (int i = 0; i < N; ++i) {
+    ConstEigenArrayMap<float> X_arr(X_ptr, C, HxW);
+    EigenVectorArrayMap<float> Y_arr(Y_ptr, C);
+    for (int j = 0; j < HxW; ++j) {
+      Y_arr = Y_arr.max(X_arr.col(j));
+    }
+    X_ptr += HxW * C;
+    Y_ptr += C;
+  }
+  return true;
+}
+
 #define CAFFE2_SPECIALIZED_MAX_POOL_FUNCTOR_FORWARD(T, kOrder) \
   template <>                                                  \
   template <>                                                  \
diff --git a/caffe2/operators/pool_op.cu b/caffe2/operators/pool_op.cu
index 2a18be9..bca9e52 100644
--- a/caffe2/operators/pool_op.cu
+++ b/caffe2/operators/pool_op.cu
@@ -1,11 +1,13 @@
 // TODO(ataei): reduce the apparent redundancy of all the code below.
#include "caffe2/operators/pool_op.h" -#include +#include #include +#include #include #include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math.h" namespace caffe2 { @@ -574,6 +576,29 @@ __global__ void Ave3DPoolBackwardNHWC( } // namespace template <> +template +bool AveragePoolFunctor::GlobalPoolingForward( + const int N, + const int C, + const int HxW, + const T* X, + T* Y, + CUDAContext* context) const { + if (kOrder == StorageOrder::NCHW) { + const std::array dims = {N * C, HxW}; + const int axis = 1; + math::ReduceMean( + 2, dims.data(), 1, &axis, 1.0f, X, Y, context); + } else { + const std::array dims = {N, HxW, C}; + const int axis = 1; + math::ReduceMean( + 3, dims.data(), 1, &axis, 1.0f, X, Y, context); + } + return true; +} + +template <> template <> bool AveragePoolFunctor::Forward( const int N, @@ -587,6 +612,7 @@ bool AveragePoolFunctor::Forward( const float* X, float* Y, CUDAContext* context) const { + // Split each image into K segments, each CUDA block handles one segment. const int ndim = X_dims.size(); const int Y_HxW = std::accumulate( Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies()); @@ -670,6 +696,7 @@ bool AveragePoolFunctor::Forward( const float* X, float* Y, CUDAContext* context) const { + // Each CUDA block handles one point, one thread per channel. const int ndim = X_dims.size(); const int Y_HxW = std::accumulate( Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies()); @@ -908,251 +935,241 @@ bool PoolGradientOp:: namespace { template -__global__ void MaxPool1DForwardNCHW( - const int nthreads, - const T* bottom_data, - const int channels, - const int height, - const int pooled_height, - const int kernel_h, - const int stride_h, - const int pad_t, - T* top_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int ph = index % pooled_height; - int c = (index / pooled_height) % channels; - int n = index / pooled_height / channels; - int hstart = ph * stride_h - pad_t; - int hend = min(hstart + kernel_h, height); - hstart = max(hstart, 0); - T maxval = -FLT_MAX; - const T* bdata_offset = bottom_data + n * channels * height; - for (int h = hstart; h < hend; ++h) { - int idx = c * height + h; - if (bdata_offset[idx] > maxval) { - maxval = bdata_offset[idx]; - } +__global__ void MaxPool1DForwardNCHWCUDAKernel( + const int K, + const int X_size, + const int Y_size, + const int kernel, + const int stride, + const int pad, + const T* X, + T* Y) { + const int nc = blockIdx.x / K; + const int block = blockIdx.x % K; + const T* X_ptr = X + nc * X_size; + T* Y_ptr = Y + nc * Y_size; + const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS; + if (y < Y_size) { + const int x = y * stride; + const int l = max(x - pad, 0); + const int r = min(x - pad + kernel, X_size); + T val = std::numeric_limits::lowest(); + for (int i = l; i < r; ++i) { + val = max(val, X_ptr[i]); } - top_data[index] = maxval; + Y_ptr[y] = val; } } template -__global__ void MaxPool2DForwardNCHW( - const int nthreads, - const T* bottom_data, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int kernel_h, - const int kernel_w, - const int stride_h, - const int stride_w, - const int pad_t, - const int pad_l, - T* top_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; - int hstart = ph * stride_h - pad_t; - int wstart = pw * 
stride_w - pad_l; - int hend = min(hstart + kernel_h, height); - int wend = min(wstart + kernel_w, width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - T maxval = -FLT_MAX; - const T* bdata_offset = bottom_data + n * channels * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int idx = c * height * width + h * width + w; - if (bdata_offset[idx] > maxval) { - maxval = bdata_offset[idx]; - } - } +__global__ void MaxPool1DForwardNHWCCUDAKernel( + const int C, + const int X_size, + const int Y_size, + const int kernel, + const int stride, + const int pad, + const T* X, + T* Y) { + const int n = blockIdx.x / Y_size; + const int y = blockIdx.x % Y_size; + const int x = y * stride; + const int l = max(x - pad, 0); + const int r = min(x - pad + kernel, X_size); + const T* X_ptr = X + n * X_size * C; + T* Y_ptr = Y + n * Y_size * C; + for (int c = threadIdx.x; c < C; c += blockDim.x) { + T val = std::numeric_limits::lowest(); + for (int i = l; i < r; ++i) { + val = max(val, X_ptr[i * C + c]); } - top_data[index] = maxval; + Y_ptr[y * C + c] = val; } } template -__global__ void MaxPool3DForwardNCHW( - const int nthreads, - const T* bottom_data, - const int channels, - const int height, - const int width, - const int depth, - const int pooled_height, - const int pooled_width, - const int pooled_depth, +__global__ void MaxPool2DForwardNCHWCUDAKernel( + const int K, + const int X_H, + const int X_W, + const int Y_H, + const int Y_W, const int kernel_h, const int kernel_w, - const int kernel_d, const int stride_h, const int stride_w, - const int stride_d, const int pad_t, const int pad_l, - const int pad_f, - T* top_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int pd = index % pooled_depth; - int pw = (index / pooled_depth) % pooled_width; - int ph = (index / pooled_depth / pooled_width) % pooled_height; - int c = (index / pooled_depth / pooled_width / pooled_height) % channels; - int n = index / pooled_depth / pooled_width / pooled_height / channels; - int hstart = ph * stride_h - pad_t; - int wstart = pw * stride_w - pad_l; - int hend = min(hstart + kernel_h, height); - int wend = min(wstart + kernel_w, width); - int dstart = pd * stride_d - pad_f; - int dend = min(dstart + kernel_d, depth); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - dstart = max(dstart, 0); - T maxval = -FLT_MAX; - const T* bdata_offset = bottom_data + n * channels * height * width * depth; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - for (int d = dstart; d < dend; ++d) { - int idx = ((c * height + h) * width + w) * depth + d; - if (bdata_offset[idx] > maxval) { - maxval = bdata_offset[idx]; - } - } + const T* X, + T* Y) { + const int X_HxW = X_H * X_W; + const int Y_HxW = Y_H * Y_W; + const int nc = blockIdx.x / K; + const int block = blockIdx.x % K; + const T* X_ptr = X + nc * X_HxW; + T* Y_ptr = Y + nc * Y_HxW; + const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS; + if (y < Y_HxW) { + const int yh = y / Y_W; + const int yw = y % Y_W; + const int xh = yh * stride_h; + const int xw = yw * stride_w; + const int t = max(xh - pad_t, 0); + const int b = min(xh - pad_t + kernel_h, X_H); + const int l = max(xw - pad_l, 0); + const int r = min(xw - pad_l + kernel_w, X_W); + T val = std::numeric_limits::lowest(); + for (int i = t; i < b; ++i) { + for (int j = l; j < r; ++j) { + val = max(val, X_ptr[i * X_W + j]); } } - top_data[index] = maxval; + Y_ptr[y] = val; } } template -__global__ void MaxPool1DForwardNHWC( - 
const int nthreads, - const T* bottom_data, - const int height, - const int channels, - const int pooled_height, +__global__ void MaxPool2DForwardNHWCCUDAKernel( + const int C, + const int X_H, + const int X_W, + const int Y_H, + const int Y_W, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, - T* top_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int n = index; - int c = n % channels; - n /= channels; - int hstart = (n % pooled_height) * stride_h - pad_t; - n /= pooled_height; - int hend = min(hstart + kernel_h, height); - hstart = max(hstart, 0); - T maxval = -FLT_MAX; - const T* bdata_offset = bottom_data + n * height * channels; - for (int h = hstart; h < hend; ++h) { - int idx = h * channels + c; - if (bdata_offset[idx] > maxval) { - maxval = bdata_offset[idx]; + const int pad_l, + const T* X, + T* Y) { + const int X_HxW = X_H * X_W; + const int Y_HxW = Y_H * Y_W; + const int n = blockIdx.x / Y_HxW; + const int y = blockIdx.x % Y_HxW; + const int yh = y / Y_W; + const int yw = y % Y_W; + const int xh = yh * stride_h; + const int xw = yw * stride_w; + const int t = max(xh - pad_t, 0); + const int b = min(xh - pad_t + kernel_h, X_H); + const int l = max(xw - pad_l, 0); + const int r = min(xw - pad_l + kernel_w, X_W); + const T* X_ptr = X + n * X_HxW * C; + T* Y_ptr = Y + n * Y_HxW * C; + for (int c = threadIdx.x; c < C; c += blockDim.x) { + T val = std::numeric_limits::lowest(); + for (int i = t; i < b; ++i) { + for (int j = l; j < r; ++j) { + val = max(val, X_ptr[(i * X_W + j) * C + c]); } } - top_data[index] = maxval; + Y_ptr[y * C + c] = val; } } template -__global__ void MaxPool2DForwardNHWC( - const int nthreads, - const T* bottom_data, - const int height, - const int width, - const int channels, - const int pooled_height, - const int pooled_width, +__global__ void MaxPool3DForwardNCHWCUDAKernel( + const int K, + const int X_D, + const int X_H, + const int X_W, + const int Y_D, + const int Y_H, + const int Y_W, + const int kernel_d, const int kernel_h, const int kernel_w, + const int stride_d, const int stride_h, const int stride_w, + const int pad_p, const int pad_t, const int pad_l, - T* top_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int n = index; - int c = n % channels; - n /= channels; - int wstart = (n % pooled_width) * stride_w - pad_l; - n /= pooled_width; - int hstart = (n % pooled_height) * stride_h - pad_t; - n /= pooled_height; - int hend = min(hstart + kernel_h, height); - int wend = min(wstart + kernel_w, width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - T maxval = -FLT_MAX; - const T* bdata_offset = bottom_data + n * height * width * channels; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int idx = (h * width + w) * channels + c; - if (bdata_offset[idx] > maxval) { - maxval = bdata_offset[idx]; + const T* X, + T* Y) { + const int X_HxW = X_D * X_H * X_W; + const int Y_HxW = Y_D * Y_H * Y_W; + const int nc = blockIdx.x / K; + const int block = blockIdx.x % K; + const T* X_ptr = X + nc * X_HxW; + T* Y_ptr = Y + nc * Y_HxW; + const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS; + if (y < Y_HxW) { + const int yy = y / Y_W; + const int yw = y % Y_W; + const int yh = yy % Y_H; + const int yd = yy / Y_H; + const int xd = yd * stride_d; + const int xh = yh * stride_h; + const int xw = yw * stride_w; + const int p = max(xd - pad_p, 0); + const int a = min(xd - pad_p + kernel_d, X_D); + const int t = max(xh - pad_t, 0); + const int b = min(xh - pad_t + 
kernel_h, X_H); + const int l = max(xw - pad_l, 0); + const int r = min(xw - pad_l + kernel_w, X_W); + T val = std::numeric_limits::lowest(); + for (int i = p; i < a; ++i) { + for (int j = t; j < b; ++j) { + for (int k = l; k < r; ++k) { + val = max(val, X_ptr[(i * X_H + j) * X_W + k]); } } } - top_data[index] = maxval; + Y_ptr[y] = val; } } template -__global__ void MaxPool3DForwardNHWC( - const int nthreads, - const T* bottom_data, - const int height, - const int width, - const int depth, - const int channels, - const int pooled_height, - const int pooled_width, - const int pooled_depth, +__global__ void MaxPool3DForwardNHWCCUDAKernel( + const int C, + const int X_D, + const int X_H, + const int X_W, + const int Y_D, + const int Y_H, + const int Y_W, + const int kernel_d, const int kernel_h, const int kernel_w, - const int kernel_d, + const int stride_d, const int stride_h, const int stride_w, - const int stride_d, + const int pad_p, const int pad_t, const int pad_l, - const int pad_f, - T* top_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int n = index; - int c = n % channels; - n /= channels; - int dstart = (n % pooled_depth) * stride_d - pad_f; - n /= pooled_depth; - int wstart = (n % pooled_width) * stride_w - pad_l; - n /= pooled_width; - int hstart = (n % pooled_height) * stride_h - pad_t; - n /= pooled_height; - int hend = min(hstart + kernel_h, height); - int wend = min(wstart + kernel_w, width); - int dend = min(dstart + kernel_d, depth); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - dstart = max(dstart, 0); - T maxval = -FLT_MAX; - const T* bdata_offset = bottom_data + n * height * width * depth * channels; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - for (int d = dstart; d < dend; ++d) { - int idx = ((h * width + w) * depth + d) * channels + c; - if (bdata_offset[idx] > maxval) { - maxval = bdata_offset[idx]; - } + const T* X, + T* Y) { + const int X_HxW = X_D * X_H * X_W; + const int Y_HxW = Y_D * Y_H * Y_W; + const int n = blockIdx.x / Y_HxW; + const int y = blockIdx.x % Y_HxW; + const int yy = y / Y_W; + const int yw = y % Y_W; + const int yh = yy % Y_H; + const int yd = yy / Y_H; + const int xd = yd * stride_d; + const int xh = yh * stride_h; + const int xw = yw * stride_w; + const int p = max(xd - pad_p, 0); + const int a = min(xd - pad_p + kernel_d, X_D); + const int t = max(xh - pad_t, 0); + const int b = min(xh - pad_t + kernel_h, X_H); + const int l = max(xw - pad_l, 0); + const int r = min(xw - pad_l + kernel_w, X_W); + const T* X_ptr = X + n * X_HxW * C; + T* Y_ptr = Y + n * Y_HxW * C; + for (int c = threadIdx.x; c < C; c += blockDim.x) { + T val = std::numeric_limits::lowest(); + for (int i = p; i < a; ++i) { + for (int j = t; j < b; ++j) { + for (int k = l; k < r; ++k) { + val = max(val, X_ptr[((i * X_H + j) * X_W + k) * C + c]); } } } - top_data[index] = maxval; + Y_ptr[y * C + c] = val; } } @@ -1417,158 +1434,177 @@ __global__ void MaxPool3DBackwardNHWC( } } } + } // namespace template <> -bool PoolOp::RunOnDeviceWithOrderNCHW() { - auto& X = Input(0); - auto* Y = Output(0); - ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); - int output_size = Y->size(); - switch (kernel_.size()) { - case 1: - MaxPool1DForwardNCHW - <<>>( - output_size, - X.data(), - X.dim32(1), - X.dim32(2), - Y->dim32(2), - kernel_h(), - stride_h(), - pad_t(), - Y->template mutable_data()); - break; - case 2: - MaxPool2DForwardNCHW - <<>>( - output_size, - X.data(), - X.dim32(1), - X.dim32(2), - X.dim32(3), - Y->dim32(2), - Y->dim32(3), - 
kernel_h(), - kernel_w(), - stride_h(), - stride_w(), - pad_t(), - pad_l(), - Y->template mutable_data()); - break; - case 3: - MaxPool3DForwardNCHW - <<>>( - output_size, - X.data(), - X.dim32(1), - X.dim32(2), - X.dim32(3), - X.dim32(4), - Y->dim32(2), - Y->dim32(3), - Y->dim32(4), - kernel_h(), - kernel_w(), - kernel_[2], - stride_h(), - stride_w(), - stride_[2], - pad_t(), - pad_l(), - pads_[2], - Y->template mutable_data()); - break; - default: - CAFFE_THROW("Unsupported pooling size : ", kernel_.size()); +template +bool MaxPoolFunctor::GlobalPoolingForward( + const int N, + const int C, + const int HxW, + const T* X, + T* Y, + CUDAContext* context) const { + if (kOrder == StorageOrder::NCHW) { + const std::array dims = {N * C, HxW}; + const int axis = 1; + math::ReduceMax( + 2, dims.data(), 1, &axis, 1.0f, X, Y, context); + } else { + const std::array dims = {N, HxW, C}; + const int axis = 1; + math::ReduceMax( + 3, dims.data(), 1, &axis, 1.0f, X, Y, context); } return true; } template <> -bool PoolOp::RunOnDeviceWithOrderNHWC() { - auto& X = Input(0); - auto* Y = Output(0); - ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(X.ndim() - 1)); - int output_size = Y->size(); - switch (kernel_.size()) { - case 1: - MaxPool1DForwardNHWC - <<>>( - output_size, - X.data(), - X.dim32(1), - X.dim32(2), - Y->dim32(1), - kernel_h(), - stride_h(), - pad_t(), - Y->template mutable_data()); - break; - case 2: - MaxPool2DForwardNHWC - <<>>( - output_size, - X.data(), - X.dim32(1), - X.dim32(2), - X.dim32(3), - Y->dim32(1), - Y->dim32(2), - kernel_h(), - kernel_w(), - stride_h(), - stride_w(), - pad_t(), - pad_l(), - Y->template mutable_data()); - break; - case 3: - MaxPool3DForwardNHWC - <<>>( - output_size, - X.data(), - X.dim32(1), - X.dim32(2), - X.dim32(3), - X.dim32(4), - Y->dim32(1), - Y->dim32(2), - Y->dim32(3), - kernel_h(), - kernel_w(), - kernel_[2], - stride_h(), - stride_w(), - stride_[2], - pad_t(), - pad_l(), - pads_[2], - Y->template mutable_data()); - break; - default: - CAFFE_THROW("Unsupported pooling size : ", kernel_.size()); +template <> +bool MaxPoolFunctor::Forward( + const int N, + const int C, + const std::vector& X_dims, + const std::vector& Y_dims, + const std::vector& kernel, + const std::vector& /* dilation */, + const std::vector& stride, + const std::vector& pads, + const float* X, + float* Y, + CUDAContext* context) const { + // Split each image into K segments, each CUDA block handles one segment. 
+ const int ndim = X_dims.size(); + const int Y_HxW = std::accumulate( + Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies()); + const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS; + switch (ndim) { + case 1: { + MaxPool1DForwardNCHWCUDAKernel + <<cuda_stream()>>>( + K, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y); + return true; + } + case 2: { + MaxPool2DForwardNCHWCUDAKernel + <<cuda_stream()>>>( + K, + X_dims[0], + X_dims[1], + Y_dims[0], + Y_dims[1], + kernel[0], + kernel[1], + stride[0], + stride[1], + pads[0], + pads[1], + X, + Y); + return true; + } + case 3: { + MaxPool3DForwardNCHWCUDAKernel + <<cuda_stream()>>>( + K, + X_dims[0], + X_dims[1], + X_dims[2], + Y_dims[0], + Y_dims[1], + Y_dims[2], + kernel[0], + kernel[1], + kernel[2], + stride[0], + stride[1], + stride[2], + pads[0], + pads[1], + pads[2], + X, + Y); + return true; + } + default: { + CAFFE_THROW("Unsupported pooling dim: ", ndim); + return false; + } + } +} + +template <> +template <> +bool MaxPoolFunctor::Forward( + const int N, + const int C, + const std::vector& X_dims, + const std::vector& Y_dims, + const std::vector& kernel, + const std::vector& /* dilation */, + const std::vector& stride, + const std::vector& pads, + const float* X, + float* Y, + CUDAContext* context) const { + // Each CUDA block handles one point, one thread per channel. + const int ndim = X_dims.size(); + const int Y_HxW = std::accumulate( + Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies()); + switch (ndim) { + case 1: { + MaxPool1DForwardNHWCCUDAKernel + <<cuda_stream()>>>( + C, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y); + return true; + } + case 2: { + MaxPool2DForwardNHWCCUDAKernel + <<cuda_stream()>>>( + C, + X_dims[0], + X_dims[1], + Y_dims[0], + Y_dims[1], + kernel[0], + kernel[1], + stride[0], + stride[1], + pads[0], + pads[1], + X, + Y); + return true; + } + case 3: { + MaxPool3DForwardNHWCCUDAKernel + <<cuda_stream()>>>( + C, + X_dims[0], + X_dims[1], + X_dims[2], + Y_dims[0], + Y_dims[1], + Y_dims[2], + kernel[0], + kernel[1], + kernel[2], + stride[0], + stride[1], + stride[2], + pads[0], + pads[1], + pads[2], + X, + Y); + return true; + } + default: { + CAFFE_THROW("Unsupported pooling dim: ", ndim); + return false; + } } - return true; } template <> @@ -1777,23 +1813,32 @@ REGISTER_CUDA_OPERATOR( AveragePool3DGradient, PoolGradientOp); -REGISTER_CUDA_OPERATOR(MaxPool, PoolOp); +REGISTER_CUDA_OPERATOR( + MaxPool, + PoolOp>); REGISTER_CUDA_OPERATOR( MaxPoolGradient, PoolGradientOp); -REGISTER_CUDA_OPERATOR(MaxPool1D, PoolOp); +REGISTER_CUDA_OPERATOR( + MaxPool1D, + PoolOp>); REGISTER_CUDA_OPERATOR( MaxPool1DGradient, PoolGradientOp); -REGISTER_CUDA_OPERATOR(MaxPool2D, PoolOp); +REGISTER_CUDA_OPERATOR( + MaxPool2D, + PoolOp>); REGISTER_CUDA_OPERATOR( MaxPool2DGradient, PoolGradientOp); -REGISTER_CUDA_OPERATOR(MaxPool3D, PoolOp); +REGISTER_CUDA_OPERATOR( + MaxPool3D, + PoolOp>); REGISTER_CUDA_OPERATOR( MaxPool3DGradient, PoolGradientOp); + } // namespace caffe2 diff --git a/caffe2/operators/pool_op.h b/caffe2/operators/pool_op.h index a9758f4..4880a50 100644 --- a/caffe2/operators/pool_op.h +++ b/caffe2/operators/pool_op.h @@ -40,6 +40,13 @@ class PoolOp final : public ConvPoolOpBase { const int N = X.dim32(0); const int C = X.dim32(1); ConvPoolOpBase::SetOutputSize(X, Y, C); + const T* X_data = X.template data(); + T* Y_data = Y->template mutable_data(); + if (global_pooling_) { + const int HxW = X.numel() / (N * C); + return functor_.template GlobalPoolingForward( + N, C, 
HxW, X_data, Y_data, &context_); + } const std::vector X_HW_dims = GetDims(X); const std::vector Y_HW_dims = GetDims(*Y); return functor_.template Forward( @@ -63,6 +70,13 @@ class PoolOp final : public ConvPoolOpBase { const int N = X.dim32(0); const int C = X.dim32(ndim - 1); ConvPoolOpBase::SetOutputSize(X, Y, C); + const T* X_data = X.template data(); + T* Y_data = Y->template mutable_data(); + if (global_pooling_) { + const int HxW = X.numel() / (N * C); + return functor_.template GlobalPoolingForward( + N, C, HxW, X_data, Y_data, &context_); + } const std::vector X_HW_dims = GetDims(X); const std::vector Y_HW_dims = GetDims(*Y); return functor_.template Forward( @@ -156,6 +170,15 @@ struct AveragePoolFunctor { op.template GetSingleArgument("count_include_pad", false)) {} template + bool GlobalPoolingForward( + int N, + int C, + int HxW, + const T* X, + T* Y, + Context* context) const; + + template bool Forward( int N, int C, @@ -193,6 +216,15 @@ struct MaxPoolFunctor { explicit MaxPoolFunctor(const OperatorBase& /* op */) {} template + bool GlobalPoolingForward( + int N, + int C, + int HxW, + const T* X, + T* Y, + Context* context) const; + + template bool Forward( int N, int C, -- 2.7.4
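
Note on the global-pooling fast path: the GlobalPoolingForward specializations added above are, semantically, one reduction per (n, c) pair over the HxW spatial elements. The standalone C++ sketch below (illustrative helper names, not caffe2 APIs) mirrors the {N * C, HxW} / axis = 1 reductions that the NCHW specializations dispatch to math::ReduceMean and math::ReduceMax:

// Standalone sketch (not part of the patch): reference semantics of
// GlobalPoolingForward for the NCHW layout, average and max variants.
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

// Y[n * C + c] = mean over i in [0, HxW) of X[(n * C + c) * HxW + i].
void GlobalAveragePoolNCHW(
    int N, int C, int HxW, const float* X, float* Y) {
  for (int nc = 0; nc < N * C; ++nc) {
    float sum = 0.0f;
    for (int i = 0; i < HxW; ++i) {
      sum += X[nc * HxW + i];
    }
    Y[nc] = sum / static_cast<float>(HxW);
  }
}

// Y[n * C + c] = max over i in [0, HxW) of X[(n * C + c) * HxW + i].
void GlobalMaxPoolNCHW(int N, int C, int HxW, const float* X, float* Y) {
  for (int nc = 0; nc < N * C; ++nc) {
    float val = std::numeric_limits<float>::lowest();
    for (int i = 0; i < HxW; ++i) {
      val = std::max(val, X[nc * HxW + i]);
    }
    Y[nc] = val;
  }
}

int main() {
  // One image (N = 1), two channels, 2x2 spatial extent.
  const std::vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f,      // channel 0
                                -1.0f, -2.0f, -3.0f, -4.0f}; // channel 1
  std::vector<float> avg(2), mx(2);
  GlobalAveragePoolNCHW(1, 2, 4, X.data(), avg.data());
  GlobalMaxPoolNCHW(1, 2, 4, X.data(), mx.data());
  std::cout << avg[0] << " " << avg[1] << "\n"; // 2.5 -2.5
  std::cout << mx[0] << " " << mx[1] << "\n";   // 4 -1
  return 0;
}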
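The NCHW forward kernels follow the launch geometry described by the "split each image into K segments" comment: a grid of N * C * K blocks of CAFFE_CUDA_NUM_THREADS threads each, with K = ceil(Y_HxW / CAFFE_CUDA_NUM_THREADS), so each block covers one segment of one (n, c) plane. The host-side sketch below (a stand-in thread count of 128 is assumed in place of CAFFE_CUDA_NUM_THREADS) replays that index arithmetic and checks that every output element is covered exactly once:

// Host-side sketch (not part of the patch) of the NCHW launch geometry.
#include <cassert>
#include <vector>

int main() {
  const int kThreads = 128; // stand-in for CAFFE_CUDA_NUM_THREADS
  const int N = 2, C = 3, Y_HxW = 300;
  const int K = (Y_HxW + kThreads - 1) / kThreads;
  std::vector<int> hits(N * C * Y_HxW, 0);
  // Simulate every (blockIdx.x, threadIdx.x) pair of the launch.
  for (int blockIdx_x = 0; blockIdx_x < N * C * K; ++blockIdx_x) {
    const int nc = blockIdx_x / K;    // which (n, c) plane
    const int block = blockIdx_x % K; // which segment of that plane
    for (int threadIdx_x = 0; threadIdx_x < kThreads; ++threadIdx_x) {
      const int y = threadIdx_x + block * kThreads;
      if (y < Y_HxW) { // the tail segment is only partially full
        ++hits[nc * Y_HxW + y];
      }
    }
  }
  // Every output element is written exactly once.
  for (int h : hits) {
    assert(h == 1);
  }
  return 0;
}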