From: Xiaomeng Yang
Date: Tue, 8 Jan 2019 05:33:44 +0000 (-0800)
Subject: Add count_include_pad arg for AveragePoolOp on GPU (#15787)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1984
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4650d70e930fd90484e648bd7a7c6b7ac2bafb72;p=platform%2Fupstream%2Fpytorch.git

Add count_include_pad arg for AveragePoolOp on GPU (#15787)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15787

Add count_include_pad arg for AveragePoolOp on GPU

Reviewed By: houseroad

Differential Revision: D13589185

fbshipit-source-id: 235a84cfcd2033ee796c13e338fc3d03e832b5b1
---

diff --git a/caffe2/operators/pool_gradient_op.cc b/caffe2/operators/pool_gradient_op.cc
index c34335e..e33b408 100644
--- a/caffe2/operators/pool_gradient_op.cc
+++ b/caffe2/operators/pool_gradient_op.cc
@@ -6,9 +6,6 @@
 
 namespace caffe2 {
 
-using std::max;
-using std::min;
-
 namespace {
 
 template <typename T>
diff --git a/caffe2/operators/pool_op.cu b/caffe2/operators/pool_op.cu
index 55c59cb..25e2b2c 100644
--- a/caffe2/operators/pool_op.cu
+++ b/caffe2/operators/pool_op.cu
@@ -6,6 +6,7 @@
 #include "caffe2/core/context_gpu.h"
 
 namespace caffe2 {
+
 namespace {
 
 struct AveragePool {
@@ -16,256 +17,262 @@ struct MaxPool {
   explicit MaxPool(const OperatorBase& /* op */) {}
 };
 
-} // namespace
-
-namespace {
 template <typename T>
-__global__ void Average1DPoolForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int channels,
-    const int height,
-    const int pooled_height,
-    const int kernel_h,
-    const int stride_h,
-    const int pad_t,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int ph = n % pooled_height;
-    n /= pooled_height;
-    int c = n % channels;
-    n /= channels;
-    int hstart = ph * stride_h - pad_t;
-    int hend = min(hstart + kernel_h, height);
-    hstart = max(hstart, 0);
-    top_data[index] = 0;
-    int bottom_offset = (n * channels + c) * height;
-    for (int h = hstart; h < hend; ++h) {
-      top_data[index] += bottom_data[bottom_offset + h];
+__global__ void AveragePool1DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_size;
+  T* Y_ptr = Y + nc * Y_size;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_size) {
+    const int x = y * stride;
+    const int l = max(x - pad, 0);
+    const int r = min(x - pad + kernel, X_size);
+    const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+      sum += X_ptr[i];
     }
-    top_data[index] /= (hend - hstart);
+    Y_ptr[y] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average2DPoolForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int kernel_h,
-    const int kernel_w,
-    const int stride_h,
-    const int stride_w,
-    const int pad_t,
-    const int pad_l,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int pw = n % pooled_width;
-    n /= pooled_width;
-    int ph = n % pooled_height;
-    n /= pooled_height;
-    int c = n % channels;
-    n /= channels;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    top_data[index] = 0;
-    int bottom_offset = (n * channels + c) * height * width;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        top_data[index] += bottom_data[bottom_offset + h * width + w];
-      }
+__global__ void AveragePool1DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_size,
+    const int Y_size,
+    const int kernel,
+    const int stride,
+    const int pad,
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int n = blockIdx.x / Y_size;
+  const int y = blockIdx.x % Y_size;
+  const int x = y * stride;
+  const int l = max(x - pad, 0);
+  const int r = min(x - pad + kernel, X_size);
+  const T scale = T(1) / static_cast<T>(count_include_pad ? kernel : r - l);
+  const T* X_ptr = X + n * X_size * C;
+  T* Y_ptr = Y + n * Y_size * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = l; i < r; ++i) {
+      sum += X_ptr[i * C + c];
     }
-    top_data[index] /= (hend - hstart) * (wend - wstart);
+    Y_ptr[y * C + c] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average3DPoolForwardNCHW(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int channels,
-    const int height,
-    const int width,
-    const int depth,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void AveragePool2DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index;
-    int pd = n % pooled_depth;
-    n /= pooled_depth;
-    int pw = n % pooled_width;
-    n /= pooled_width;
-    int ph = n % pooled_height;
-    n /= pooled_height;
-    int c = n % channels;
-    n /= channels;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int dstart = pd * stride_d - pad_f;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    int dend = min(dstart + kernel_d, depth);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dstart = max(dstart, 0);
-    top_data[index] = 0;
-    int bottom_offset = (n * channels + c) * height * width * depth;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        for (int d = dstart; d < dend; ++d) {
-          const int input_index =
-              bottom_offset + h * width * depth + w * depth + d;
-          top_data[index] += bottom_data[input_index];
-        }
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_HxW;
+  T* Y_ptr = Y + nc * Y_HxW;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_HxW) {
+    const int yh = y / Y_W;
+    const int yw = y % Y_W;
+    const int xh = yh * stride_h;
+    const int xw = yw * stride_w;
+    const int t = max(xh - pad_t, 0);
+    const int b = min(xh - pad_t + kernel_h, X_H);
+    const int l = max(xw - pad_l, 0);
+    const int r = min(xw - pad_l + kernel_w, X_W);
+    const T scale = T(1) /
+        static_cast<T>(count_include_pad ? kernel_h * kernel_w
+                                         : (b - t) * (r - l));
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        sum += X_ptr[i * X_W + j];
       }
     }
-    top_data[index] /= (hend - hstart) * (wend - wstart) * (dend - dstart);
+    Y_ptr[y] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average1DPoolForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int height,
-    const int channels,
-    const int pooled_height,
+__global__ void AveragePool2DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_H,
+    const int X_W,
+    const int Y_H,
+    const int Y_W,
     const int kernel_h,
+    const int kernel_w,
     const int stride_h,
+    const int stride_w,
     const int pad_t,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int c = index % channels;
-    int ph = (index / channels) % pooled_height;
-    int n = index / channels / pooled_height;
-    int hstart = ph * stride_h - pad_t;
-    int hend = min(hstart + kernel_h, height);
-    hstart = max(hstart, 0);
-    T output = 0;
-    int bottom_offset = n * height * channels + c;
-    for (int h = hstart; h < hend; ++h) {
-      output += bottom_data[bottom_offset + h * channels];
+    const int pad_l,
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_H * X_W;
+  const int Y_HxW = Y_H * Y_W;
+  const int n = blockIdx.x / Y_HxW;
+  const int y = blockIdx.x % Y_HxW;
+  const int yh = y / Y_W;
+  const int yw = y % Y_W;
+  const int xh = yh * stride_h;
+  const int xw = yw * stride_w;
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  const int l = max(xw - pad_l, 0);
+  const int r = min(xw - pad_l + kernel_w, X_W);
+  const T scale = T(1) /
+      static_cast<T>(count_include_pad ? kernel_h * kernel_w
+                                       : (b - t) * (r - l));
+  const T* X_ptr = X + n * X_HxW * C;
+  T* Y_ptr = Y + n * Y_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = t; i < b; ++i) {
+      for (int j = l; j < r; ++j) {
+        sum += X_ptr[(i * X_W + j) * C + c];
+      }
     }
-    int pool_size = (hend - hstart);
-    top_data[index] = output / pool_size;
+    Y_ptr[y * C + c] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average2DPoolForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int height,
-    const int width,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
+__global__ void AveragePool3DForwardNCHWCUDAKernel(
+    const int K,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
    const int kernel_h,
     const int kernel_w,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int c = index % channels;
-    int pw = (index / channels) % pooled_width;
-    int ph = (index / channels / pooled_width) % pooled_height;
-    int n = index / channels / pooled_width / pooled_height;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    T output = 0;
-    int bottom_offset = n * height * width * channels + c;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        output += bottom_data[bottom_offset + (h * width + w) * channels];
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int nc = blockIdx.x / K;
+  const int block = blockIdx.x % K;
+  const T* X_ptr = X + nc * X_HxW;
+  T* Y_ptr = Y + nc * Y_HxW;
+  const int y = threadIdx.x + block * CAFFE_CUDA_NUM_THREADS;
+  if (y < Y_HxW) {
+    const int yy = y / Y_W;
+    const int yw = y % Y_W;
+    const int yh = yy % Y_H;
+    const int yd = yy / Y_H;
+    const int xd = yd * stride_d;
+    const int xh = yh * stride_h;
+    const int xw = yw * stride_w;
+    const int p = max(xd - pad_p, 0);
+    const int a = min(xd - pad_p + kernel_d, X_D);
+    const int t = max(xh - pad_t, 0);
+    const int b = min(xh - pad_t + kernel_h, X_H);
+    const int l = max(xw - pad_l, 0);
+    const int r = min(xw - pad_l + kernel_w, X_W);
+    const T scale = T(1) /
+        static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
+                                         : (a - p) * (b - t) * (r - l));
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          sum += X_ptr[(i * X_H + j) * X_W + k];
+        }
       }
     }
-    int pool_size = (hend - hstart) * (wend - wstart);
-    top_data[index] = output / pool_size;
+    Y_ptr[y] = sum * scale;
   }
 }
 
 template <typename T>
-__global__ void Average3DPoolForwardNHWC(
-    const int nthreads,
-    const T* bottom_data,
-    const int num,
-    const int height,
-    const int width,
-    const int depth,
-    const int channels,
-    const int pooled_height,
-    const int pooled_width,
-    const int pooled_depth,
+__global__ void AveragePool3DForwardNHWCCUDAKernel(
+    const int C,
+    const int X_D,
+    const int X_H,
+    const int X_W,
+    const int Y_D,
+    const int Y_H,
+    const int Y_W,
+    const int kernel_d,
     const int kernel_h,
     const int kernel_w,
-    const int kernel_d,
+    const int stride_d,
     const int stride_h,
     const int stride_w,
-    const int stride_d,
+    const int pad_p,
     const int pad_t,
     const int pad_l,
-    const int pad_f,
-    T* top_data) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int c = index % channels;
-    int pd = (index / channels) % pooled_depth;
-    int pw = (index / channels / pooled_depth) % pooled_width;
-    int ph = (index / channels / pooled_depth / pooled_width) % pooled_height;
-    int n = index / channels / pooled_depth / pooled_width / pooled_height;
-    int hstart = ph * stride_h - pad_t;
-    int wstart = pw * stride_w - pad_l;
-    int dstart = pd * stride_d - pad_f;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    int dend = min(dstart + kernel_d, depth);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dstart = max(dstart, 0);
-    T output = 0;
-    int bottom_offset = n * height * width * depth * channels + c;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        for (int d = dstart; d < dend; ++d) {
-          const int bottom_index =
-              bottom_offset + (h * depth * width + w * depth + d) * channels;
-          output += bottom_data[bottom_index];
+    const bool count_include_pad,
+    const T* X,
+    T* Y) {
+  const int X_HxW = X_D * X_H * X_W;
+  const int Y_HxW = Y_D * Y_H * Y_W;
+  const int n = blockIdx.x / Y_HxW;
+  const int y = blockIdx.x % Y_HxW;
+  const int yy = y / Y_W;
+  const int yw = y % Y_W;
+  const int yh = yy % Y_H;
+  const int yd = yy / Y_H;
+  const int xd = yd * stride_d;
+  const int xh = yh * stride_h;
+  const int xw = yw * stride_w;
+  const int p = max(xd - pad_p, 0);
+  const int a = min(xd - pad_p + kernel_d, X_D);
+  const int t = max(xh - pad_t, 0);
+  const int b = min(xh - pad_t + kernel_h, X_H);
+  const int l = max(xw - pad_l, 0);
+  const int r = min(xw - pad_l + kernel_w, X_W);
+  const T scale = T(1) /
+      static_cast<T>(count_include_pad ? kernel_d * kernel_h * kernel_w
+                                       : (a - p) * (b - t) * (r - l));
+  const T* X_ptr = X + n * X_HxW * C;
+  T* Y_ptr = Y + n * Y_HxW * C;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    T sum = 0;
+    for (int i = p; i < a; ++i) {
+      for (int j = t; j < b; ++j) {
+        for (int k = l; k < r; ++k) {
+          sum += X_ptr[((i * X_H + j) * X_W + k) * C + c];
+        }
       }
     }
-    int pool_size = (hend - hstart) * (wend - wstart) * (dend - dstart);
-    top_data[index] = output / pool_size;
+    Y_ptr[y * C + c] = sum * scale;
   }
 }
 
@@ -562,163 +569,211 @@ __global__ void Ave3DPoolBackwardNHWC(
   }
 }
 
-} // namespace
+}  // namespace
 
 template <>
-bool PoolOp<float, CUDAContext, AveragePool>::RunOnDeviceWithOrderNCHW() {
-  auto& X = Input(0);
-  auto* Y = Output(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
-  int output_size = Y->size();
-  switch (kernel_.size()) {
-    case 1:
-      Average1DPoolForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              Y->dim32(2),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              Y->template mutable_data<float>());
-      break;
-    case 2:
-      Average2DPoolForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              Y->dim32(2),
-              Y->dim32(3),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              Y->template mutable_data<float>());
-      break;
-    case 3:
-      Average3DPoolForwardNCHW<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              Y->dim32(2),
-              Y->dim32(3),
-              Y->dim32(4),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              Y->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
-  }
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const std::array<int, 1>& X_dims,
+    const std::array<int, 1>& Y_dims,
+    const std::array<int, 1>& kernel,
+    const std::array<int, 1>& /* dilation */,
+    const std::array<int, 1>& stride,
+    const std::array<int, 1>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int K =
+      (Y_dims[0] + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  AveragePool1DForwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K,
+          X_dims[0],
+          Y_dims[0],
+          kernel[0],
+          stride[0],
+          pads[0],
+          count_include_pad,
+          X,
+          Y);
   return true;
 }
 
 template <>
-bool PoolOp<float, CUDAContext, AveragePool>::RunOnDeviceWithOrderNHWC() {
-  auto& X = Input(0);
-  auto* Y = Output(0);
-  ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(X.ndim() - 1));
-  int output_size = Y->size();
-  switch (kernel_.size()) {
-    case 1:
-      Average1DPoolForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              Y->dim32(1),
-              kernel_h(),
-              stride_h(),
-              pad_t(),
-              Y->template mutable_data<float>());
-      break;
-    case 2:
-      Average2DPoolForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              Y->dim32(1),
-              Y->dim32(2),
-              kernel_h(),
-              kernel_w(),
-              stride_h(),
-              stride_w(),
-              pad_t(),
-              pad_l(),
-              Y->template mutable_data<float>());
-      break;
-    case 3:
-      Average3DPoolForwardNHWC<float>
-          <<<CAFFE_GET_BLOCKS(output_size),
-             CAFFE_CUDA_NUM_THREADS,
-             0,
-             context_.cuda_stream()>>>(
-              output_size,
-              X.data<float>(),
-              X.dim32(0),
-              X.dim32(1),
-              X.dim32(2),
-              X.dim32(3),
-              X.dim32(4),
-              Y->dim32(1),
-              Y->dim32(2),
-              Y->dim32(3),
-              kernel_h(),
-              kernel_w(),
-              kernel_[2],
-              stride_h(),
-              stride_w(),
-              stride_[2],
-              pad_t(),
-              pad_l(),
-              pads_[2],
-              Y->template mutable_data<float>());
-      break;
-    default:
-      CAFFE_THROW("Unsupported pooling size : ", kernel_.size());
-  }
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const std::array<int, 1>& X_dims,
+    const std::array<int, 1>& Y_dims,
+    const std::array<int, 1>& kernel,
+    const std::array<int, 1>& /* dilation */,
+    const std::array<int, 1>& stride,
+    const std::array<int, 1>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  AveragePool1DForwardNHWCCUDAKernel<float>
+      <<<N * Y_dims[0], CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C,
+          X_dims[0],
+          Y_dims[0],
+          kernel[0],
+          stride[0],
+          pads[0],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const std::array<int, 2>& X_dims,
+    const std::array<int, 2>& Y_dims,
+    const std::array<int, 2>& kernel,
+    const std::array<int, 2>& /* dilation */,
+    const std::array<int, 2>& stride,
+    const std::array<int, 2>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1];
+  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  AveragePool2DForwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K,
+          X_dims[0],
+          X_dims[1],
+          Y_dims[0],
+          Y_dims[1],
+          kernel[0],
+          kernel[1],
+          stride[0],
+          stride[1],
+          pads[0],
+          pads[1],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const std::array<int, 2>& X_dims,
+    const std::array<int, 2>& Y_dims,
+    const std::array<int, 2>& kernel,
+    const std::array<int, 2>& /* dilation */,
+    const std::array<int, 2>& stride,
+    const std::array<int, 2>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1];
+  AveragePool2DForwardNHWCCUDAKernel<float>
+      <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C,
+          X_dims[0],
+          X_dims[1],
+          Y_dims[0],
+          Y_dims[1],
+          kernel[0],
+          kernel[1],
+          stride[0],
+          stride[1],
+          pads[0],
+          pads[1],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const std::array<int, 3>& X_dims,
+    const std::array<int, 3>& Y_dims,
+    const std::array<int, 3>& kernel,
+    const std::array<int, 3>& /* dilation */,
+    const std::array<int, 3>& stride,
+    const std::array<int, 3>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1] * Y_dims[2];
+  const int K = (Y_HxW + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+  AveragePool3DForwardNCHWCUDAKernel<float>
+      <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          K,
+          X_dims[0],
+          X_dims[1],
+          X_dims[2],
+          Y_dims[0],
+          Y_dims[1],
+          Y_dims[2],
+          kernel[0],
+          kernel[1],
+          kernel[2],
+          stride[0],
+          stride[1],
+          stride[2],
+          pads[0],
+          pads[1],
+          pads[2],
+          count_include_pad,
+          X,
+          Y);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const std::array<int, 3>& X_dims,
+    const std::array<int, 3>& Y_dims,
+    const std::array<int, 3>& kernel,
+    const std::array<int, 3>& /* dilation */,
+    const std::array<int, 3>& stride,
+    const std::array<int, 3>& pads,
+    const float* X,
+    float* Y,
+    CUDAContext* context) const {
+  const int Y_HxW = Y_dims[0] * Y_dims[1] * Y_dims[2];
+  AveragePool3DForwardNHWCCUDAKernel<float>
+      <<<N * Y_HxW, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+          C,
+          X_dims[0],
+          X_dims[1],
+          X_dims[2],
+          Y_dims[0],
+          Y_dims[1],
+          Y_dims[2],
+          kernel[0],
+          kernel[1],
+          kernel[2],
+          stride[0],
+          stride[1],
+          stride[2],
+          pads[0],
+          pads[1],
+          pads[2],
+          count_include_pad,
+          X,
+          Y);
   return true;
 }
@@ -889,7 +944,6 @@ bool PoolGradientOp<float, CUDAContext, AveragePool>::
   return true;
 }
 
-
 namespace {
 
 template <typename T>
@@ -1335,8 +1389,7 @@ __global__ void MaxPool2DBackwardNHWC(
     const int phend = min(h / stride_h + 1, pooled_height);
     const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
     const int pwend = min(w / stride_w + 1, pooled_width);
-    const int top_offset =
-        n * pooled_height * pooled_width * channels + c;
+    const int top_offset = n * pooled_height * pooled_width * channels + c;
     bottom_diff[index] = 0;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
@@ -1403,7 +1456,7 @@ __global__ void MaxPool3DBackwardNHWC(
     }
   }
 }
-} // namespace
+}  // namespace
 
 template <>
 bool PoolOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNCHW() {
@@ -1735,28 +1788,38 @@ bool PoolGradientOp<float, CUDAContext, MaxPool>::RunOnDeviceWithOrderNHWC() {
   return true;
 }
 
-REGISTER_CUDA_OPERATOR(AveragePool, PoolOp<float, CUDAContext, AveragePool>);
-REGISTER_CUDA_OPERATOR(AveragePoolGradient,
-                       PoolGradientOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
+REGISTER_CUDA_OPERATOR(
+    AveragePoolGradient,
+    PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(AveragePool1D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool1D,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool1DGradient,
     PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(AveragePool2D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool2D,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool2DGradient,
     PoolGradientOp<float, CUDAContext, AveragePool>);
 
-REGISTER_CUDA_OPERATOR(AveragePool3D, PoolOp<float, CUDAContext, AveragePool>);
+REGISTER_CUDA_OPERATOR(
+    AveragePool3D,
+    PoolOp<float, CUDAContext, AveragePoolFunctor<CUDAContext>>);
 REGISTER_CUDA_OPERATOR(
     AveragePool3DGradient,
    PoolGradientOp<float, CUDAContext, AveragePool>);
 
 REGISTER_CUDA_OPERATOR(MaxPool, PoolOp<float, CUDAContext, MaxPool>);
-REGISTER_CUDA_OPERATOR(MaxPoolGradient,
-                       PoolGradientOp<float, CUDAContext, MaxPool>);
+REGISTER_CUDA_OPERATOR(
+    MaxPoolGradient,
+    PoolGradientOp<float, CUDAContext, MaxPool>);
 
 REGISTER_CUDA_OPERATOR(MaxPool1D, PoolOp<float, CUDAContext, MaxPool>);
 REGISTER_CUDA_OPERATOR(
     MaxPool1DGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
@@ -1772,4 +1835,4 @@ REGISTER_CUDA_OPERATOR(MaxPool3D, PoolOp<float, CUDAContext, MaxPool>);
 REGISTER_CUDA_OPERATOR(
     MaxPool3DGradient,
     PoolGradientOp<float, CUDAContext, MaxPool>);
-} // namespace caffe2
+}  // namespace caffe2
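
---

For reference, every kernel in this patch reduces count_include_pad to a single per-output divisor: the full kernel volume when the flag is true, or only the overlap between the pooling window and the unpadded input when it is false. The standalone C++ sketch below mirrors the 1D divisor logic on the host; the helper name AveragePool1D and the sample values are illustrative assumptions and are not part of the commit.

// Standalone sketch (assumed, not from the commit): host-side mirror of the
// divisor choice made by AveragePool1DForwardNCHWCUDAKernel above.
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<float> AveragePool1D(
    const std::vector<float>& X,
    const int kernel,
    const int stride,
    const int pad,
    const bool count_include_pad) {
  const int X_size = static_cast<int>(X.size());
  const int Y_size = (X_size + 2 * pad - kernel) / stride + 1;
  std::vector<float> Y(Y_size);
  for (int y = 0; y < Y_size; ++y) {
    const int x = y * stride;
    const int l = std::max(x - pad, 0);
    const int r = std::min(x - pad + kernel, X_size);
    float sum = 0.0f;
    for (int i = l; i < r; ++i) {
      sum += X[i];
    }
    // count_include_pad == true divides by the full kernel size even when
    // the window overlaps the zero padding; false divides by r - l, the
    // number of real input elements covered.
    Y[y] = sum / static_cast<float>(count_include_pad ? kernel : r - l);
  }
  return Y;
}

int main() {
  const std::vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f};
  // kernel = 3, stride = 1, pad = 1: the border windows cover only two
  // input elements, so the two modes disagree there.
  const std::vector<float> a = AveragePool1D(X, 3, 1, 1, true);
  const std::vector<float> b = AveragePool1D(X, 3, 1, 1, false);
  std::printf("%g vs %g\n", a.front(), b.front()); // prints "1 vs 1.5"
  return 0;
}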