#include "caffe2/core/context_gpu.h"
namespace caffe2 {
+
namespace {
struct AveragePool {
explicit MaxPool(const OperatorBase& /* op */) {}
};
-} // namespace
-
-namespace {
template <typename T>
-__global__ void Average1DPoolForwardNCHW(
- const int nthreads,
- const T* bottom_data,
- const int num,
- const int channels,
- const int height,
- const int pooled_height,
- const int kernel_h,
- const int stride_h,
- const int pad_t,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int n = index;
- int ph = n % pooled_height;
- n /= pooled_height;
- int c = n % channels;
- n /= channels;
- int hstart = ph * stride_h - pad_t;
- int hend = min(hstart + kernel_h, height);
- hstart = max(hstart, 0);
- top_data[index] = 0;
- int bottom_offset = (n * channels + c) * height;
- for (int h = hstart; h < hend; ++h) {
- top_data[index] += bottom_data[bottom_offset + h];
+__global__ void AveragePool1DForwardNCHWCUDAKernel(
+ const int K,
+ const int X_size,
+ const int Y_size,
+ const int kernel,
+ const int stride,
+ const int pad,
+ const bool count_include_pad,
+ const T* X,
+ T* Y) {
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const T* X_ptr = X + nc * X_size;
+ T* Y_ptr = Y + nc * Y_size;
+ const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+ if (y < Y_size) {
+ const int x = y * stride;
+ const int l = max(x - pad, 0);
+ const int r = min(x - pad + kernel, X_size);
+ const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
+ T sum = 0;
+ for (int i = l; i < r; ++i) {
+ sum += X_ptr[i];
}
- top_data[index] /= (hend - hstart);
+ Y_ptr[y] = sum * scale;
}
}
template <typename T>
-__global__ void Average2DPoolForwardNCHW(
- const int nthreads,
- const T* bottom_data,
- const int num,
- const int channels,
- const int height,
- const int width,
- const int pooled_height,
- const int pooled_width,
- const int kernel_h,
- const int kernel_w,
- const int stride_h,
- const int stride_w,
- const int pad_t,
- const int pad_l,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int n = index;
- int pw = n % pooled_width;
- n /= pooled_width;
- int ph = n % pooled_height;
- n /= pooled_height;
- int c = n % channels;
- n /= channels;
- int hstart = ph * stride_h - pad_t;
- int wstart = pw * stride_w - pad_l;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- top_data[index] = 0;
- int bottom_offset = (n * channels + c) * height * width;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- top_data[index] += bottom_data[bottom_offset + h * width + w];
- }
+__global__ void AveragePool1DForwardNHWCCUDAKernel(
+ const int C,
+ const int X_size,
+ const int Y_size,
+ const int kernel,
+ const int stride,
+ const int pad,
+ const bool count_include_pad,
+ const T* X,
+ T* Y) {
+ const int n = blockIdx.x / Y_size;
+ const int y = blockIdx.x % Y_size;
+ const int x = y * stride;
+ const int l = max(x - pad, 0);
+ const int r = min(x - pad + kernel, X_size);
+ const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
+ const T* X_ptr = X + n * X_size * C;
+ T* Y_ptr = Y + n * Y_size * C;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ T sum = 0;
+ for (int i = l; i < r; ++i) {
+ sum += X_ptr[i * C + c];
}
- top_data[index] /= (hend - hstart) * (wend - wstart);
+ Y_ptr[y * C + c] = sum * scale;
}
}
template <typename T>
-__global__ void Average3DPoolForwardNCHW(
- const int nthreads,
- const T* bottom_data,
- const int num,
- const int channels,
- const int height,
- const int width,
- const int depth,
- const int pooled_height,
- const int pooled_width,
- const int pooled_depth,
+__global__ void AveragePool2DForwardNCHWCUDAKernel(
+ const int K,
+ const int X_H,
+ const int X_W,
+ const int Y_H,
+ const int Y_W,
const int kernel_h,
const int kernel_w,
- const int kernel_d,
const int stride_h,
const int stride_w,
- const int stride_d,
const int pad_t,
const int pad_l,
- const int pad_f,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int n = index;
- int pd = n % pooled_depth;
- n /= pooled_depth;
- int pw = n % pooled_width;
- n /= pooled_width;
- int ph = n % pooled_height;
- n /= pooled_height;
- int c = n % channels;
- n /= channels;
- int hstart = ph * stride_h - pad_t;
- int wstart = pw * stride_w - pad_l;
- int dstart = pd * stride_d - pad_f;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- int dend = min(dstart + kernel_d, depth);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- dstart = max(dstart, 0);
- top_data[index] = 0;
- int bottom_offset = (n * channels + c) * height * width * depth;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- for (int d = dstart; d < dend; ++d) {
- const int input_index =
- bottom_offset + h * width * depth + w * depth + d;
- top_data[index] += bottom_data[input_index];
- }
+ const bool count_include_pad,
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_H * X_W;
+ const int Y_HxW = Y_H * Y_W;
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const T* X_ptr = X + nc * X_HxW;
+ T* Y_ptr = Y + nc * Y_HxW;
+ const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+ if (y < Y_HxW) {
+ const int yh = y / Y_W;
+ const int yw = y % Y_W;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ const T scale = T(1) /
+ static_cast<T>(count_include_pad ? kernel_h * kernel_w
+ : (b - t) * (r - l));
+ T sum = 0;
+ for (int i = t; i < b; ++i) {
+ for (int j = l; j < r; ++j) {
+ sum += X_ptr[i * X_W + j];
}
}
- top_data[index] /= (hend - hstart) * (wend - wstart) * (dend - dstart);
+ Y_ptr[y] = sum * scale;
}
}
template <typename T>
-__global__ void Average1DPoolForwardNHWC(
- const int nthreads,
- const T* bottom_data,
- const int num,
- const int height,
- const int channels,
- const int pooled_height,
+__global__ void AveragePool2DForwardNHWCCUDAKernel(
+ const int C,
+ const int X_H,
+ const int X_W,
+ const int Y_H,
+ const int Y_W,
const int kernel_h,
+ const int kernel_w,
const int stride_h,
+ const int stride_w,
const int pad_t,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int c = index % channels;
- int ph = (index / channels) % pooled_height;
- int n = index / channels / pooled_height;
- int hstart = ph * stride_h - pad_t;
- int hend = min(hstart + kernel_h, height);
- hstart = max(hstart, 0);
- T output = 0;
- int bottom_offset = n * height * channels + c;
- for (int h = hstart; h < hend; ++h) {
- output += bottom_data[bottom_offset + h * channels];
+ const int pad_l,
+ const bool count_include_pad,
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_H * X_W;
+ const int Y_HxW = Y_H * Y_W;
+ const int n = blockIdx.x / Y_HxW;
+ const int y = blockIdx.x % Y_HxW;
+ const int yh = y / Y_W;
+ const int yw = y % Y_W;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ const T scale = T(1) /
+ static_cast<T>(count_include_pad ? kernel_h * kernel_w
+ : (b - t) * (r - l));
+ const T* X_ptr = X + n * X_HxW * C;
+ T* Y_ptr = Y + n * Y_HxW * C;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ T sum = 0;
+ for (int i = t; i < b; ++i) {
+ for (int j = l; j < r; ++j) {
+ sum += X_ptr[(i * X_W + j) * C + c];
+ }
}
- int pool_size = (hend - hstart);
- top_data[index] = output / pool_size;
+ Y_ptr[y * C + c] = sum * scale;
}
}
template <typename T>
-__global__ void Average2DPoolForwardNHWC(
- const int nthreads,
- const T* bottom_data,
- const int num,
- const int height,
- const int width,
- const int channels,
- const int pooled_height,
- const int pooled_width,
+__global__ void AveragePool3DForwardNCHWCUDAKernel(
+ const int K,
+ const int X_D,
+ const int X_H,
+ const int X_W,
+ const int Y_D,
+ const int Y_H,
+ const int Y_W,
+ const int kernel_d,
const int kernel_h,
const int kernel_w,
+ const int stride_d,
const int stride_h,
const int stride_w,
+ const int pad_p,
const int pad_t,
const int pad_l,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int c = index % channels;
- int pw = (index / channels) % pooled_width;
- int ph = (index / channels / pooled_width) % pooled_height;
- int n = index / channels / pooled_width / pooled_height;
- int hstart = ph * stride_h - pad_t;
- int wstart = pw * stride_w - pad_l;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- T output = 0;
- int bottom_offset = n * height * width * channels + c;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- output += bottom_data[bottom_offset + (h * width + w) * channels];
+ const bool count_include_pad,
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_D * X_H * X_W;
+ const int Y_HxW = Y_D * Y_H * Y_W;
+ const int nc = blockIdx.x / K;
+ const int block = blockIdx.x % K;
+ const T* X_ptr = X + nc * X_HxW;
+ T* Y_ptr = Y + nc * Y_HxW;
+ const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+ if (y < Y_HxW) {
+ const int yy = y / Y_W;
+ const int yw = y % Y_W;
+ const int yh = yy % Y_H;
+ const int yd = yy / Y_H;
+ const int xd = yd * stride_d;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int p = max(xd - pad_p, 0);
+ const int a = min(xd - pad_p + kernel_h, X_D);
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ const T scale = T(1) /
+ static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
+ : (a - p) * (b - t) * (r - l));
+ T sum = 0;
+ for (int i = p; i < a; ++i) {
+ for (int j = t; j < b; ++j) {
+ for (int k = l; k < r; ++k) {
+ sum += X_ptr[(i * X_H + j) * X_W + k];
+ }
}
}
- int pool_size = (hend - hstart) * (wend - wstart);
- top_data[index] = output / pool_size;
+ Y_ptr[y] = sum * scale;
}
}
template <typename T>
-__global__ void Average3DPoolForwardNHWC(
- const int nthreads,
- const T* bottom_data,
- const int num,
- const int height,
- const int width,
- const int depth,
- const int channels,
- const int pooled_height,
- const int pooled_width,
- const int pooled_depth,
+__global__ void AveragePool3DForwardNHWCCUDAKernel(
+ const int C,
+ const int X_D,
+ const int X_H,
+ const int X_W,
+ const int Y_D,
+ const int Y_H,
+ const int Y_W,
+ const int kernel_d,
const int kernel_h,
const int kernel_w,
- const int kernel_d,
+ const int stride_d,
const int stride_h,
const int stride_w,
- const int stride_d,
+ const int pad_p,
const int pad_t,
const int pad_l,
- const int pad_f,
- T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int c = index % channels;
- int pd = (index / channels) % pooled_depth;
- int pw = (index / channels / pooled_depth) % pooled_width;
- int ph = (index / channels / pooled_depth / pooled_width) % pooled_height;
- int n = index / channels / pooled_depth / pooled_width / pooled_height;
- int hstart = ph * stride_h - pad_t;
- int wstart = pw * stride_w - pad_l;
- int dstart = pd * stride_d - pad_f;
- int hend = min(hstart + kernel_h, height);
- int wend = min(wstart + kernel_w, width);
- int dend = min(dstart + kernel_d, depth);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- dstart = max(dstart, 0);
- T output = 0;
- int bottom_offset = n * height * width * depth * channels + c;
- for (int h = hstart; h < hend; ++h) {
- for (int w = wstart; w < wend; ++w) {
- for (int d = dstart; d < dend; ++d) {
- const int bottom_index =
- bottom_offset + (h * depth * width + w * depth + d) * channels;
- output += bottom_data[bottom_index];
+ const bool count_include_pad,
+ const T* X,
+ T* Y) {
+ const int X_HxW = X_D * X_H * X_W;
+ const int Y_HxW = Y_D * Y_H * Y_W;
+ const int n = blockIdx.x / Y_HxW;
+ const int y = blockIdx.x % Y_HxW;
+ const int yy = y / Y_W;
+ const int yw = y % Y_W;
+ const int yh = yy % Y_H;
+ const int yd = yy / Y_H;
+ const int xd = yd * stride_d;
+ const int xh = yh * stride_h;
+ const int xw = yw * stride_w;
+ const int p = max(xd - pad_p, 0);
+ const int a = min(xd - pad_p + kernel_h, X_D);
+ const int t = max(xh - pad_t, 0);
+ const int b = min(xh - pad_t + kernel_h, X_H);
+ const int l = max(xw - pad_l, 0);
+ const int r = min(xw - pad_l + kernel_w, X_W);
+ const T scale = T(1) /
+ static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
+ : (a - p) * (b - t) * (r - l));
+ const T* X_ptr = X + n * X_HxW * C;
+ T* Y_ptr = Y + n * Y_HxW * C;
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ T sum = 0;
+ for (int i = p; i < a; ++i) {
+ for (int j = t; j < b; ++j) {
+ for (int k = l; k < r; ++k) {
+ sum += X_ptr[((i * X_H + j) * X_W + k) * C + c];
}
}
}
- int pool_size = (hend - hstart) * (wend - wstart) * (dend - dstart);
- top_data[index] = output / pool_size;
+ Y_ptr[y * C + c] = sum * scale;
}
}
}
}
-} // namespace
+} // namespace
template <>
-bool PoolOp<float, CUDAContext, AveragePool>::RunOnDeviceWithOrderNCHW() {
- auto& X = Input(0);
- auto* Y = Output(0);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
- int output_size = Y->size();
- switch (kernel_.size()) {
- case 1:
- Average1DPoolForwardNCHW<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(0),
- X.dim32(1),
- X.dim32(2),
- Y->dim32(2),
- kernel_h(),
- stride_h(),
- pad_t(),
- Y->template mutable_data<float>());
- break;
- case 2:
- Average2DPoolForwardNCHW<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(0),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- Y->dim32(2),
- Y->dim32(3),
- kernel_h(),
- kernel_w(),
- stride_h(),
- stride_w(),
- pad_t(),
- pad_l(),
- Y->template mutable_data<float>());
- break;
- case 3:
- Average3DPoolForwardNCHW<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(0),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- X.dim32(4),
- Y->dim32(2),
- Y->dim32(3),
- Y->dim32(4),
- kernel_h(),
- kernel_w(),
- kernel_[2],
- stride_h(),
- stride_w(),
- stride_[2],
- pad_t(),
- pad_l(),
- pads_[2],
- Y->template mutable_data<float>());
- break;
- default:
- CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
- }
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW, 1>(
+ const int N,
+ const int C,
+ const std::array<int, 1>& X_dims,
+ const std::array<int, 1>& Y_dims,
+ const std::array<int, 1>& kernel,
+ const std::array<int, 1>& /* dilation */,
+ const std::array<int, 1>& stride,
+ const std::array<int, 2>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ const int K =
+ (Y_dims[0] + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+ AveragePool1DForwardNCHWCUDAKernel<float>
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ K,
+ X_dims[0],
+ Y_dims[0],
+ kernel[0],
+ stride[0],
+ pads[0],
+ count_include_pad,
+ X,
+ Y);
return true;
}
template <>
-bool PoolOp<float, CUDAContext, AveragePool>::RunOnDeviceWithOrderNHWC() {
- auto& X = Input(0);
- auto* Y = Output(0);
- ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(X.ndim() - 1));
- int output_size = Y->size();
- switch (kernel_.size()) {
- case 1:
- Average1DPoolForwardNHWC<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(0),
- X.dim32(1),
- X.dim32(2),
- Y->dim32(1),
- kernel_h(),
- stride_h(),
- pad_t(),
- Y->template mutable_data<float>());
- break;
- case 2:
- Average2DPoolForwardNHWC<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(0),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- Y->dim32(1),
- Y->dim32(2),
- kernel_h(),
- kernel_w(),
- stride_h(),
- stride_w(),
- pad_t(),
- pad_l(),
- Y->template mutable_data<float>());
- break;
- case 3:
- Average3DPoolForwardNHWC<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- X.dim32(0),
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- X.dim32(4),
- Y->dim32(1),
- Y->dim32(2),
- Y->dim32(3),
- kernel_h(),
- kernel_w(),
- kernel_[2],
- stride_h(),
- stride_w(),
- stride_[2],
- pad_t(),
- pad_l(),
- pads_[2],
- Y->template mutable_data<float>());
- break;
- default:
- CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
- }
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC, 1>(
+ const int N,
+ const int C,
+ const std::array<int, 1>& X_dims,
+ const std::array<int, 1>& Y_dims,
+ const std::array<int, 1>& kernel,
+ const std::array<int, 1>& /* dilation */,
+ const std::array<int, 1>& stride,
+ const std::array<int, 2>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ AveragePool1DForwardNHWCCUDAKernel<float>
+ <<<N * Y_dims[0], CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ C,
+ X_dims[0],
+ Y_dims[0],
+ kernel[0],
+ stride[0],
+ pads[0],
+ count_include_pad,
+ X,
+ Y);
+ return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW, 2>(
+ const int N,
+ const int C,
+ const std::array<int, 2>& X_dims,
+ const std::array<int, 2>& Y_dims,
+ const std::array<int, 2>& kernel,
+ const std::array<int, 2>& /* dilation */,
+ const std::array<int, 2>& stride,
+ const std::array<int, 4>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ const int Y_HxW = Y_dims[0] * Y_dims[1];
+ const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+ AveragePool2DForwardNCHWCUDAKernel<float>
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ K,
+ X_dims[0],
+ X_dims[1],
+ Y_dims[0],
+ Y_dims[1],
+ kernel[0],
+ kernel[1],
+ stride[0],
+ stride[1],
+ pads[0],
+ pads[1],
+ count_include_pad,
+ X,
+ Y);
+ return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC, 2>(
+ const int N,
+ const int C,
+ const std::array<int, 2>& X_dims,
+ const std::array<int, 2>& Y_dims,
+ const std::array<int, 2>& kernel,
+ const std::array<int, 2>& /* dilation */,
+ const std::array<int, 2>& stride,
+ const std::array<int, 4>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ const int Y_HxW = Y_dims[0] * Y_dims[1];
+ AveragePool2DForwardNHWCCUDAKernel<float>
+ <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ C,
+ X_dims[0],
+ X_dims[1],
+ Y_dims[0],
+ Y_dims[1],
+ kernel[0],
+ kernel[1],
+ stride[0],
+ stride[1],
+ pads[0],
+ pads[1],
+ count_include_pad,
+ X,
+ Y);
+ return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW, 3>(
+ const int N,
+ const int C,
+ const std::array<int, 3>& X_dims,
+ const std::array<int, 3>& Y_dims,
+ const std::array<int, 3>& kernel,
+ const std::array<int, 3>& /* dilation */,
+ const std::array<int, 3>& stride,
+ const std::array<int, 6>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ const int Y_HxW = Y_dims[0] * Y_dims[1] * Y_dims[2];
+ const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+ AveragePool3DForwardNCHWCUDAKernel<float>
+ <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ K,
+ X_dims[0],
+ X_dims[1],
+ X_dims[2],
+ Y_dims[0],
+ Y_dims[1],
+ Y_dims[2],
+ kernel[0],
+ kernel[1],
+ kernel[2],
+ stride[0],
+ stride[1],
+ stride[2],
+ pads[0],
+ pads[1],
+ pads[2],
+ count_include_pad,
+ X,
+ Y);
+ return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC, 3>(
+ const int N,
+ const int C,
+ const std::array<int, 3>& X_dims,
+ const std::array<int, 3>& Y_dims,
+ const std::array<int, 3>& kernel,
+ const std::array<int, 3>& /* dilation */,
+ const std::array<int, 3>& stride,
+ const std::array<int, 6>& pads,
+ const float* X,
+ float* Y,
+ CUDAContext* context) const {
+ const int Y_HxW = Y_dims[0] * Y_dims[1] * Y_dims[2];
+ AveragePool3DForwardNHWCCUDAKernel<float>
+ <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ C,
+ X_dims[0],
+ X_dims[1],
+ X_dims[2],
+ Y_dims[0],
+ Y_dims[1],
+ Y_dims[2],
+ kernel[0],
+ kernel[1],
+ kernel[2],
+ stride[0],
+ stride[1],
+ stride[2],
+ pads[0],
+ pads[1],
+ pads[2],
+ count_include_pad,
+ X,
+ Y);
return true;
}
return true;
}
-
namespace {
template <typename T>
const int phend = min(h / stride_h + 1, pooled_height);
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwend = min(w / stride_w + 1, pooled_width);
- const int top_offset =
- n * pooled_height * pooled_width * channels + c;
+ const int top_offset = n * pooled_height * pooled_width * channels + c;
bottom_diff[index] = 0;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
}
}
}
-} // namespace
+} // namespace
template <>
bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
return true;
}
-REGISTER_CUDA_OPERATOR(AveragePool, PoolOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(AveragePoolGradient,
- PoolGradientOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+ AveragePool,
+ PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
+REGISTER_CUDA_OPERATOR(
+ AveragePoolGradient,
+ PoolGradientOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(AveragePool1D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+ AveragePool1D,
+ PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
AveragePool1DGradient,
PoolGradientOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(AveragePool2D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+ AveragePool2D,
+ PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
AveragePool2DGradient,
PoolGradientOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(AveragePool3D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+ AveragePool3D,
+ PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
REGISTER_CUDA_OPERATOR(
AveragePool3DGradient,
PoolGradientOp<float, CUDAContext, AveragePool>);
REGISTER_CUDA_OPERATOR(MaxPool, PoolOp<float, CUDAContext, MaxPool>);
-REGISTER_CUDA_OPERATOR(MaxPoolGradient,
- PoolGradientOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+ MaxPoolGradient,
+ PoolGradientOp<float, CUDAContext, MaxPool>);
REGISTER_CUDA_OPERATOR(MaxPool1D, PoolOp<float, CUDAContext, MaxPool>);
REGISTER_CUDA_OPERATOR(
REGISTER_CUDA_OPERATOR(
MaxPool3DGradient,
PoolGradientOp<float, CUDAContext, MaxPool>);
-} // namespace caffe2
+} // namespace caffe2