// TODO(ataei): reduce the apparent redundancy of all the code below.
#include "caffe2/operators/pool_op.h"
-#include <cfloat>
+#include <array>
#include <functional>
+#include <limits>
#include <numeric>
#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
namespace caffe2 {
} // namespace
template <>
+template <typename T, StorageOrder kOrder>
+bool AveragePoolFunctor<CUDAContext>::GlobalPoolingForward(
+ const int N,
+ const int C,
+ const int HxW,
+ const T* X,
+ T* Y,
+ CUDAContext* context) const {
+ if (kOrder == StorageOrder::NCHW) {
+ const std::array<int, 2> dims = {N * C, HxW};
+ const int axis = 1;
+ math::ReduceMean<float, CUDAContext>(
+ 2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+ } else {
+ const std::array<int, 3> dims = {N, HxW, C};
+ const int axis = 1;
+ math::ReduceMean<float, CUDAContext>(
+ 3, dims.data(), 1, &axis, 1.0f, X, Y, context);
+ }
+ return true;
+}
+
+template <>
template <>
bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
const int N,
const float* X,
float* Y,
CUDAContext* context) const {
+ // Split each image into K segments, each CUDA block handles one segment.
const int ndim = X_dims.size();
const int Y_HxW = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
const float* X,
float* Y,
CUDAContext* context) const {
+ // Each CUDA block handles one point, one thread per channel.
const int ndim = X_dims.size();
const int Y_HxW = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
namespace {
template <typename T>
-__global__ void MaxPool1DForwardNCHW(
- const int nthreads,
- const T* bottom_data,
- const int channels,
- const int height,
- const int pooled_height,
- const int kernel_h,
- const int stride_h,
- const int pad_t,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int ph = index % pooled_height;
- int c = (index / pooled_height) % channels;
- int n = index / pooled_height / channels;
- int hstart = ph * stride_h - pad_t;
- int hend = min(hstart + kernel_h, height);
- hstart = max(hstart, 0);
- T maxval = -FLT_MAX;
- const T* bdata_offset = bottom_data + n * channels * height;
- for (int h = hstart; h < hend; ++h) {
- int idx = c * height + h;
- if (bdata_offset[idx] > maxval) {
- maxval = bdata_offset[idx];
- }
+__global__ void MaxPool1DForwardNCHWCUDAKernel(
+ const int K,
+ const int X_size,
+ const int Y_size,
+ const int kernel,
+ const int stride,
+ const int pad,
+ const T* X,
+ T* Y) {
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const T* X_ptr = X + nc * X_size;
+ T* Y_ptr = Y + nc * Y_size;
+ const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+ if (y < Y_size) {
+ const int x = y * stride;
+ const int l = max(x - pad, 0);
+ const int r = min(x - pad + kernel, X_size);
+ T val = std::numeric_limits<T>::lowest();
+ for (int i = l; i < r; ++i) {
+ val = max(val, X_ptr[i]);
}
- top_data[index] = maxval;
+ Y_ptr[y] = val;
}
}
template <typename T>
-__global__ void MaxPool2DForwardNCHW(
- const int nthreads,
- const T* bottom_data,
- const int channels,
- const int height,
- const int width,
- const int pooled_height,
- const int pooled_width,
- const int kernel_h,
- const int kernel_w,
- const int stride_h,
- const int stride_w,
- const int pad_t,
- const int pad_l,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height;
- int c = (index / pooled_width / pooled_height) % channels;
- int n = index / pooled_width / pooled_height / channels;
- int hstart = ph * stride_h - pad_t;
- int wstart = pw * stride_w - pad_l;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- T maxval = -FLT_MAX;
- const T* bdata_offset = bottom_data + n * channels * height * width;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- int idx = c * height * width + h * width + w;
- if (bdata_offset[idx] > maxval) {
- maxval = bdata_offset[idx];
- }
- }
+__global__ void MaxPool1DForwardNHWCCUDAKernel(
+ const int C,
+ const int X_size,
+ const int Y_size,
+ const int kernel,
+ const int stride,
+ const int pad,
+ const T* X,
+ T* Y) {
+ const int n = blockIdx.x / Y_size;
+ const int y = blockIdx.x % Y_size;
+ const int x = y * stride;
+ const int l = max(x - pad, 0);
+ const int r = min(x - pad + kernel, X_size);
+ const T* X_ptr = X + n * X_size * C;
+ T* Y_ptr = Y + n * Y_size * C;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ T val = std::numeric_limits<T>::lowest();
+ for (int i = l; i < r; ++i) {
+ val = max(val, X_ptr[i * C + c]);
}
- top_data[index] = maxval;
+ Y_ptr[y * C + c] = val;
}
}
template <typename T>
-__global__ void MaxPool3DForwardNCHW(
- const int nthreads,
- const T* bottom_data,
- const int channels,
- const int height,
- const int width,
- const int depth,
- const int pooled_height,
- const int pooled_width,
- const int pooled_depth,
+__global__ void MaxPool2DForwardNCHWCUDAKernel(
+ const int K,
+ const int X_H,
+ const int X_W,
+ const int Y_H,
+ const int Y_W,
const int kernel_h,
const int kernel_w,
- const int kernel_d,
const int stride_h,
const int stride_w,
- const int stride_d,
const int pad_t,
const int pad_l,
- const int pad_f,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int pd = index % pooled_depth;
- int pw = (index / pooled_depth) % pooled_width;
- int ph = (index / pooled_depth / pooled_width) % pooled_height;
- int c = (index / pooled_depth / pooled_width / pooled_height) % channels;
- int n = index / pooled_depth / pooled_width / pooled_height / channels;
- int hstart = ph * stride_h - pad_t;
- int wstart = pw * stride_w - pad_l;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- int dstart = pd * stride_d - pad_f;
- int dend = min(dstart + kernel_d, depth);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- dstart = max(dstart, 0);
- T maxval = -FLT_MAX;
- const T* bdata_offset = bottom_data + n * channels * height * width * depth;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- for (int d = dstart; d < dend; ++d) {
- int idx = ((c * height + h) * width + w) * depth + d;
- if (bdata_offset[idx] > maxval) {
- maxval = bdata_offset[idx];
- }
- }
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_H * X_W;
+ const int Y_HxW = Y_H * Y_W;
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const T* X_ptr = X + nc * X_HxW;
+ T* Y_ptr = Y + nc * Y_HxW;
+ const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+ if (y < Y_HxW) {
+ const int yh = y / Y_W;
+ const int yw = y % Y_W;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ T val = std::numeric_limits<T>::lowest();
+ for (int i = t; i < b; ++i) {
+ for (int j = l; j < r; ++j) {
+ val = max(val, X_ptr[i * X_W + j]);
}
}
- top_data[index] = maxval;
+ Y_ptr[y] = val;
}
}
template <typename T>
-__global__ void MaxPool1DForwardNHWC(
- const int nthreads,
- const T* bottom_data,
- const int height,
- const int channels,
- const int pooled_height,
+__global__ void MaxPool2DForwardNHWCCUDAKernel(
+ const int C,
+ const int X_H,
+ const int X_W,
+ const int Y_H,
+ const int Y_W,
const int kernel_h,
+ const int kernel_w,
const int stride_h,
+ const int stride_w,
const int pad_t,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int n = index;
- int c = n % channels;
- n /= channels;
- int hstart = (n % pooled_height) * stride_h - pad_t;
- n /= pooled_height;
- int hend = min(hstart + kernel_h, height);
- hstart = max(hstart, 0);
- T maxval = -FLT_MAX;
- const T* bdata_offset = bottom_data + n * height * channels;
- for (int h = hstart; h < hend; ++h) {
- int idx = h * channels + c;
- if (bdata_offset[idx] > maxval) {
- maxval = bdata_offset[idx];
+ const int pad_l,
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_H * X_W;
+ const int Y_HxW = Y_H * Y_W;
+ const int n = blockIdx.x / Y_HxW;
+ const int y = blockIdx.x % Y_HxW;
+ const int yh = y / Y_W;
+ const int yw = y % Y_W;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ const T* X_ptr = X + n * X_HxW * C;
+ T* Y_ptr = Y + n * Y_HxW * C;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ T val = std::numeric_limits<T>::lowest();
+ for (int i = t; i < b; ++i) {
+ for (int j = l; j < r; ++j) {
+ val = max(val, X_ptr[(i * X_W + j) * C + c]);
}
}
- top_data[index] = maxval;
+ Y_ptr[y * C + c] = val;
}
}
template <typename T>
-__global__ void MaxPool2DForwardNHWC(
- const int nthreads,
- const T* bottom_data,
- const int height,
- const int width,
- const int channels,
- const int pooled_height,
- const int pooled_width,
+__global__ void MaxPool3DForwardNCHWCUDAKernel(
+ const int K,
+ const int X_D,
+ const int X_H,
+ const int X_W,
+ const int Y_D,
+ const int Y_H,
+ const int Y_W,
+ const int kernel_d,
const int kernel_h,
const int kernel_w,
+ const int stride_d,
const int stride_h,
const int stride_w,
+ const int pad_p,
const int pad_t,
const int pad_l,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int n = index;
- int c = n % channels;
- n /= channels;
- int wstart = (n % pooled_width) * stride_w - pad_l;
- n /= pooled_width;
- int hstart = (n % pooled_height) * stride_h - pad_t;
- n /= pooled_height;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- T maxval = -FLT_MAX;
- const T* bdata_offset = bottom_data + n * height * width * channels;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- int idx = (h * width + w) * channels + c;
- if (bdata_offset[idx] > maxval) {
- maxval = bdata_offset[idx];
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_D * X_H * X_W;
+ const int Y_HxW = Y_D * Y_H * Y_W;
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const T* X_ptr = X + nc * X_HxW;
+ T* Y_ptr = Y + nc * Y_HxW;
+ const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+ if (y < Y_HxW) {
+ const int yy = y / Y_W;
+ const int yw = y % Y_W;
+ const int yh = yy % Y_H;
+ const int yd = yy / Y_H;
+ const int xd = yd * stride_d;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int p = max(xd - pad_p, 0);
+ const int a = min(xd - pad_p + kernel_d, X_D);
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ T val = std::numeric_limits<T>::lowest();
+ for (int i = p; i < a; ++i) {
+ for (int j = t; j < b; ++j) {
+ for (int k = l; k < r; ++k) {
+ val = max(val, X_ptr[(i * X_H + j) * X_W + k]);
}
}
}
- top_data[index] = maxval;
+ Y_ptr[y] = val;
}
}
template <typename T>
-__global__ void MaxPool3DForwardNHWC(
- const int nthreads,
- const T* bottom_data,
- const int height,
- const int width,
- const int depth,
- const int channels,
- const int pooled_height,
- const int pooled_width,
- const int pooled_depth,
+__global__ void MaxPool3DForwardNHWCCUDAKernel(
+ const int C,
+ const int X_D,
+ const int X_H,
+ const int X_W,
+ const int Y_D,
+ const int Y_H,
+ const int Y_W,
+ const int kernel_d,
const int kernel_h,
const int kernel_w,
- const int kernel_d,
+ const int stride_d,
const int stride_h,
const int stride_w,
- const int stride_d,
+ const int pad_p,
const int pad_t,
const int pad_l,
- const int pad_f,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int n = index;
- int c = n % channels;
- n /= channels;
- int dstart = (n % pooled_depth) * stride_d - pad_f;
- n /= pooled_depth;
- int wstart = (n % pooled_width) * stride_w - pad_l;
- n /= pooled_width;
- int hstart = (n % pooled_height) * stride_h - pad_t;
- n /= pooled_height;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- int dend = min(dstart + kernel_d, depth);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- dstart = max(dstart, 0);
- T maxval = -FLT_MAX;
- const T* bdata_offset = bottom_data + n * height * width * depth * channels;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- for (int d = dstart; d < dend; ++d) {
- int idx = ((h * width + w) * depth + d) * channels + c;
- if (bdata_offset[idx] > maxval) {
- maxval = bdata_offset[idx];
- }
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_D * X_H * X_W;
+ const int Y_HxW = Y_D * Y_H * Y_W;
+ const int n = blockIdx.x / Y_HxW;
+ const int y = blockIdx.x % Y_HxW;
+ const int yy = y / Y_W;
+ const int yw = y % Y_W;
+ const int yh = yy % Y_H;
+ const int yd = yy / Y_H;
+ const int xd = yd * stride_d;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int p = max(xd - pad_p, 0);
+ const int a = min(xd - pad_p + kernel_d, X_D);
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ const T* X_ptr = X + n * X_HxW * C;
+ T* Y_ptr = Y + n * Y_HxW * C;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ T val = std::numeric_limits<T>::lowest();
+ for (int i = p; i < a; ++i) {
+ for (int j = t; j < b; ++j) {
+ for (int k = l; k < r; ++k) {
+ val = max(val, X_ptr[((i * X_H + j) * X_W + k) * C + c]);
}
}
}
- top_data[index] = maxval;
+ Y_ptr[y * C + c] = val;
}
}
}
}
}
+
} // namespace
template <>
-bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
- auto& X = Input(0);
- auto* Y = Output(0);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
- int output_size = Y->size();
- switch (kernel_.size()) {
- case 1:
- MaxPool1DForwardNCHW<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(1),
- X.dim32(2),
- Y->dim32(2),
- kernel_h(),
- stride_h(),
- pad_t(),
- Y->template mutable_data<float>());
- break;
- case 2:
- MaxPool2DForwardNCHW<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- Y->dim32(2),
- Y->dim32(3),
- kernel_h(),
- kernel_w(),
- stride_h(),
- stride_w(),
- pad_t(),
- pad_l(),
- Y->template mutable_data<float>());
- break;
- case 3:
- MaxPool3DForwardNCHW<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- X.dim32(4),
- Y->dim32(2),
- Y->dim32(3),
- Y->dim32(4),
- kernel_h(),
- kernel_w(),
- kernel_[2],
- stride_h(),
- stride_w(),
- stride_[2],
- pad_t(),
- pad_l(),
- pads_[2],
- Y->template mutable_data<float>());
- break;
- default:
- CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
+template <typename T, StorageOrder kOrder>
+bool MaxPoolFunctor<CUDAContext>::GlobalPoolingForward(
+ const int N,
+ const int C,
+ const int HxW,
+ const T* X,
+ T* Y,
+ CUDAContext* context) const {
+ if (kOrder == StorageOrder::NCHW) {
+ const std::array<int, 2> dims = {N * C, HxW};
+ const int axis = 1;
+ math::ReduceMax<float, CUDAContext>(
+ 2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+ } else {
+ const std::array<int, 3> dims = {N, HxW, C};
+ const int axis = 1;
+ math::ReduceMax<float, CUDAContext>(
+ 3, dims.data(), 1, &axis, 1.0f, X, Y, context);
}
return true;
}
template <>
-bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNHWC() {
- auto& X = Input(0);
- auto* Y = Output(0);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(X.ndim() - 1));
- int output_size = Y->size();
- switch (kernel_.size()) {
- case 1:
- MaxPool1DForwardNHWC<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(1),
- X.dim32(2),
- Y->dim32(1),
- kernel_h(),
- stride_h(),
- pad_t(),
- Y->template mutable_data<float>());
- break;
- case 2:
- MaxPool2DForwardNHWC<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- Y->dim32(1),
- Y->dim32(2),
- kernel_h(),
- kernel_w(),
- stride_h(),
- stride_w(),
- pad_t(),
- pad_l(),
- Y->template mutable_data<float>());
- break;
- case 3:
- MaxPool3DForwardNHWC<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- X.dim32(4),
- Y->dim32(1),
- Y->dim32(2),
- Y->dim32(3),
- kernel_h(),
- kernel_w(),
- kernel_[2],
- stride_h(),
- stride_w(),
- stride_[2],
- pad_t(),
- pad_l(),
- pads_[2],
- Y->template mutable_data<float>());
- break;
- default:
- CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
+template <>
+bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
+ const int N,
+ const int C,
+ const std::vector<int>& X_dims,
+ const std::vector<int>& Y_dims,
+ const std::vector<int>& kernel,
+ const std::vector<int>& /* dilation */,
+ const std::vector<int>& stride,
+ const std::vector<int>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ // Split each image into K segments, each CUDA block handles one segment.
+ const int ndim = X_dims.size();
+ const int Y_HxW = std::accumulate(
+ Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
+ const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+ switch (ndim) {
+ case 1: {
+ MaxPool1DForwardNCHWCUDAKernel<float>
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ K, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y);
+ return true;
+ }
+ case 2: {
+ MaxPool2DForwardNCHWCUDAKernel<float>
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ K,
+ X_dims[0],
+ X_dims[1],
+ Y_dims[0],
+ Y_dims[1],
+ kernel[0],
+ kernel[1],
+ stride[0],
+ stride[1],
+ pads[0],
+ pads[1],
+ X,
+ Y);
+ return true;
+ }
+ case 3: {
+ MaxPool3DForwardNCHWCUDAKernel<float>
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ K,
+ X_dims[0],
+ X_dims[1],
+ X_dims[2],
+ Y_dims[0],
+ Y_dims[1],
+ Y_dims[2],
+ kernel[0],
+ kernel[1],
+ kernel[2],
+ stride[0],
+ stride[1],
+ stride[2],
+ pads[0],
+ pads[1],
+ pads[2],
+ X,
+ Y);
+ return true;
+ }
+ default: {
+ CAFFE_THROW("Unsupported pooling dim: ", ndim);
+ return false;
+ }
+ }
+}
+
+template <>
+template <>
+bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
+ const int N,
+ const int C,
+ const std::vector<int>& X_dims,
+ const std::vector<int>& Y_dims,
+ const std::vector<int>& kernel,
+ const std::vector<int>& /* dilation */,
+ const std::vector<int>& stride,
+ const std::vector<int>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ // Each CUDA block handles one point, one thread per channel.
+ const int ndim = X_dims.size();
+ const int Y_HxW = std::accumulate(
+ Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
+ switch (ndim) {
+ case 1: {
+ MaxPool1DForwardNHWCCUDAKernel<float>
+ <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ C, X_dims[0], Y_dims[0], kernel[0], stride[0], pads[0], X, Y);
+ return true;
+ }
+ case 2: {
+ MaxPool2DForwardNHWCCUDAKernel<float>
+ <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ C,
+ X_dims[0],
+ X_dims[1],
+ Y_dims[0],
+ Y_dims[1],
+ kernel[0],
+ kernel[1],
+ stride[0],
+ stride[1],
+ pads[0],
+ pads[1],
+ X,
+ Y);
+ return true;
+ }
+ case 3: {
+ MaxPool3DForwardNHWCCUDAKernel<float>
+ <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ C,
+ X_dims[0],
+ X_dims[1],
+ X_dims[2],
+ Y_dims[0],
+ Y_dims[1],
+ Y_dims[2],
+ kernel[0],
+ kernel[1],
+ kernel[2],
+ stride[0],
+ stride[1],
+ stride[2],
+ pads[0],
+ pads[1],
+ pads[2],
+ X,
+ Y);
+ return true;
+ }
+ default: {
+ CAFFE_THROW("Unsupported pooling dim: ", ndim);
+ return false;
+ }
}
- return true;
}
template <>
AveragePool3DGradient,
PoolGradientOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(MaxPool, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+ MaxPool,
+ PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
MaxPoolGradient,
PoolGradientOp<float, CUDAContext, MaxPool>);
-REGISTER_CUDA_OPERATOR(MaxPool1D, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+ MaxPool1D,
+ PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
MaxPool1DGradient,
PoolGradientOp<float, CUDAContext, MaxPool>);
-REGISTER_CUDA_OPERATOR(MaxPool2D, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+ MaxPool2D,
+ PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
MaxPool2DGradient,
PoolGradientOp<float, CUDAContext, MaxPool>);
-REGISTER_CUDA_OPERATOR(MaxPool3D, PoolOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+ MaxPool3D,
+ PoolOp<float, CUDAContext, MaxPoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
MaxPool3DGradient,
PoolGradientOp<float, CUDAContext, MaxPool>);
+
} // namespace caffe2