Add count_include_pad for avg_pool on CuDNN (#16100)
authorXiaomeng Yang <yangxm@fb.com>
Thu, 17 Jan 2019 10:07:04 +0000 (02:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 10:10:12 +0000 (02:10 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16100

Add count_include_pad for avg_pool on CuDNN

Reviewed By: houseroad

Differential Revision: D13707959

fbshipit-source-id: 261f5d116066fef75cf9a5787dfbc5d12b5b9f9b

caffe2/operators/pool_op.cc
caffe2/operators/pool_op.cu
caffe2/operators/pool_op.h
caffe2/operators/pool_op_cudnn.cc [new file with mode: 0644]
caffe2/operators/pool_op_cudnn.cu [deleted file]
caffe2/python/operator_test/pooling_test.py

index 8f5d522..314e7be 100644 (file)
@@ -997,14 +997,24 @@ std::function<void(OpSchema&)> AveragePoolDocGenerator(const char* dim) {
         "X",
         "*(type: Tensor`<float>`)* Input data tensor of shape NCHW or NHWC.");
     schema.Output(0, "Y", "*(type: Tensor`<float>`)* Output data tensor.");
-    /*
-    schema.Arg("kernel", "*(type: int)* Size of the window to take an average
-    over."); schema.Arg("stride", "*(type: int)* Stride of the window.");
-    schema.Arg("pad", "*(type: int)* Implicit zero padding to be added on both
-    sides."); schema.Arg("dilation", "*(type: int)* Parameter that controls
-    the stride of elements in the window."); schema.Arg("order", "*(type:
-    string; default: 'NCHW')* Order of the blob dimensions.");
-    */
+    // schema.Arg(
+    //     "kernel", "*(type: int)* Size of the window to take an average
+    //     over.");
+    // schema.Arg("stride", "*(type: int)* Stride of the window.");
+    // schema.Arg(
+    //     "pad",
+    //     "*(type: int)* Implicit zero padding to be added on both sides.");
+    // schema.Arg(
+    //     "dilation",
+    //     "*(type: int)* Parameter that controls the stride of elements in the
+    //     " "window.");
+    // schema.Arg(
+    //     "order",
+    //     "*(type: string; default: 'NCHW')* Order of the blob dimensions.");
+    // schema.Arg(
+    //     "count_include_pad",
+    //     "*(type: bool; default: False)* When True, will include the "
+    //     "zero-padding in the averaging.");
   };
 }
 
index 83f2759..b9f6050 100644 (file)
@@ -689,25 +689,53 @@ __global__ void AveragePool3DBackwardNHWCCUDAKernel(
 } // namespace
 
 template <>
-template <typename T, StorageOrder kOrder>
-bool AveragePoolFunctor<CUDAContext>::GlobalPoolingForward(
-    const int N,
-    const int C,
-    const int HxW,
-    const T* X,
-    T* Y,
-    CUDAContext* context) const {
-  if (kOrder == StorageOrder::NCHW) {
-    const std::array<int, 2> dims = {N * C, HxW};
-    const int axis = 1;
-    math::ReduceMean<float, CUDAContext>(
-        2, dims.data(), 1, &axis, 1.0f, X, Y, context);
-  } else {
-    const std::array<int, 3> dims = {N, HxW, C};
-    const int axis = 1;
-    math::ReduceMean<float, CUDAContext>(
-        3, dims.data(), 1, &axis, 1.0f, X, Y, context);
+template <>
+bool AveragePoolFunctor<CUDAContext>::
+    GlobalPoolingForward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CUDAContext* context) const {
+  const std::array<int, 2> dims = {N * C, HxW};
+  const int axis = 1;
+  math::ReduceMean<float, CUDAContext>(
+      2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  return true;
+}
+
+template <>
+template <>
+bool AveragePoolFunctor<CUDAContext>::
+    GlobalPoolingForward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CUDAContext* context) const {
+  if (ones.numel() != HxW) {
+    ones.Resize(HxW);
+    math::Set<float, CUDAContext>(
+        HxW, 1.0f, ones.mutable_data<float>(), context);
   }
+  math::GemmStridedBatched<float, CUDAContext>(
+      CblasTrans,
+      CblasNoTrans,
+      N,
+      C,
+      1,
+      HxW,
+      1.0f / static_cast<float>(HxW),
+      X,
+      HxW * C,
+      ones.data<float>(),
+      0,
+      0.0f,
+      Y,
+      C,
+      context);
   return true;
 }
 
@@ -1719,25 +1747,36 @@ __global__ void MaxPool3DBackwardNHWCCUDAKernel(
 } // namespace
 
 template <>
-template <typename T, StorageOrder kOrder>
-bool MaxPoolFunctor<CUDAContext>::GlobalPoolingForward(
-    const int N,
-    const int C,
-    const int HxW,
-    const T* X,
-    T* Y,
-    CUDAContext* context) const {
-  if (kOrder == StorageOrder::NCHW) {
-    const std::array<int, 2> dims = {N * C, HxW};
-    const int axis = 1;
-    math::ReduceMax<float, CUDAContext>(
-        2, dims.data(), 1, &axis, 1.0f, X, Y, context);
-  } else {
-    const std::array<int, 3> dims = {N, HxW, C};
-    const int axis = 1;
-    math::ReduceMax<float, CUDAContext>(
-        3, dims.data(), 1, &axis, 1.0f, X, Y, context);
-  }
+template <>
+bool MaxPoolFunctor<CUDAContext>::
+    GlobalPoolingForward<float, StorageOrder::NCHW>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CUDAContext* context) const {
+  const std::array<int, 2> dims = {N * C, HxW};
+  const int axis = 1;
+  math::ReduceMax<float, CUDAContext>(
+      2, dims.data(), 1, &axis, 1.0f, X, Y, context);
+  return true;
+}
+
+template <>
+template <>
+bool MaxPoolFunctor<CUDAContext>::
+    GlobalPoolingForward<float, StorageOrder::NHWC>(
+        const int N,
+        const int C,
+        const int HxW,
+        const float* X,
+        float* Y,
+        CUDAContext* context) const {
+  const std::array<int, 3> dims = {N, HxW, C};
+  const int axis = 1;
+  math::ReduceMax<float, CUDAContext>(
+      3, dims.data(), 1, &axis, 1.0f, X, Y, context);
   return true;
 }
 
index 8c9db86..909fe12 100644 (file)
@@ -238,6 +238,7 @@ struct AveragePoolFunctor {
       Context* context) const;
 
   const bool count_include_pad;
+  Tensor ones{Context::GetDeviceType()};
 };
 
 template <class Context>
diff --git a/caffe2/operators/pool_op_cudnn.cc b/caffe2/operators/pool_op_cudnn.cc
new file mode 100644 (file)
index 0000000..1ed723c
--- /dev/null
@@ -0,0 +1,594 @@
+#include "caffe2/operators/pool_op.h"
+
+#include <algorithm>
+#include <array>
+#include <type_traits>
+#include <vector>
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/core/cudnn_wrappers.h"
+
+namespace caffe2 {
+
+namespace {
+
+void SetTensorDescriptor(
+    const cudnnDataType_t data_type,
+    const StorageOrder order,
+    const std::vector<std::int64_t>& dims,
+    cudnnTensorDescriptor_t* desc) {
+  const int ndim = dims.size();
+  const int N = dims[0];
+  const int C = order == StorageOrder::NCHW ? dims[1] : dims[ndim - 1];
+  switch (ndim) {
+    case 4: {
+      const int H = order == StorageOrder::NCHW ? dims[2] : dims[1];
+      const int W = order == StorageOrder::NCHW ? dims[3] : dims[2];
+      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
+          *desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
+      break;
+    }
+    case 5: {
+      const int D = order == StorageOrder::NCHW ? dims[2] : dims[1];
+      const int H = order == StorageOrder::NCHW ? dims[3] : dims[2];
+      const int W = order == StorageOrder::NCHW ? dims[4] : dims[3];
+      const std::array<int, 5> dims_arr = {N, C, D, H, W};
+      const std::array<int, 5> strides_arr = order == StorageOrder::NCHW
+          ? std::array<int, 5>{C * D * H * W, D * H * W, H * W, W, 1}
+          : std::array<int, 5>{D * H * W * C, 1, H * W * C, W * C, C};
+      CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
+          *desc, data_type, 5, dims_arr.data(), strides_arr.data()));
+      break;
+    }
+    default: {
+      CAFFE_THROW("Unsupported tensor dim: ", ndim);
+      break;
+    }
+  }
+}
+
+template <class Functor>
+class CuDNNPoolOp final : public ConvPoolOpBase<CUDAContext> {
+ public:
+  CuDNNPoolOp(const OperatorDef& operator_def, Workspace* ws)
+      : ConvPoolOpBase<CUDAContext>(operator_def, ws),
+        cudnn_wrapper_(&context_),
+        functor_(*this),
+        equal_padding_(std::equal(
+            pads_.cbegin(),
+            pads_.cbegin() + kernel_.size(),
+            pads_.cbegin() + kernel_.size())) {
+    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&X_desc_));
+    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&Y_desc_));
+    CUDNN_ENFORCE(cudnnCreatePoolingDescriptor(&pooling_desc_));
+    if (!global_pooling_ && equal_padding_) {
+      if (kernel_.size() == 2) {
+        CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
+            pooling_desc_,
+            functor_.GetPoolingMode(),
+            CUDNN_NOT_PROPAGATE_NAN,
+            kernel_h(),
+            kernel_w(),
+            pad_t(),
+            pad_l(),
+            stride_h(),
+            stride_w()));
+      } else if (kernel_.size() == 3) {
+        CUDNN_ENFORCE(cudnnSetPoolingNdDescriptor(
+            pooling_desc_,
+            functor_.GetPoolingMode(),
+            CUDNN_NOT_PROPAGATE_NAN,
+            kernel_.size(),
+            kernel_.data(),
+            pads_.data(),
+            stride_.data()));
+      }
+    }
+  }
+
+  ~CuDNNPoolOp() {
+    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(X_desc_));
+    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(Y_desc_));
+    CUDNN_ENFORCE(cudnnDestroyPoolingDescriptor(pooling_desc_));
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    auto* Y = Output(0);
+    const int ndim = X.ndim();
+    const int N = X.dim32(0);
+    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
+    ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, C);
+    const T* X_data = X.template data<T>();
+    T* Y_data = Y->template mutable_data<T>();
+
+    if (global_pooling_) {
+      const int HxW = X.numel() / (N * C);
+      if (order_ == StorageOrder::NCHW) {
+        return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>(
+            N, C, HxW, X_data, Y_data, &context_);
+      } else {
+        return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>(
+            N, C, HxW, X_data, Y_data, &context_);
+      }
+    }
+
+    const std::vector<int> X_HW_dims = GetDims(X);
+    const std::vector<int> Y_HW_dims = GetDims(*Y);
+    if (order_ == StorageOrder::NHWC) {
+      // CuDNN Pooling on NHWC order is very slow, fallback to CUDA
+      // implementation.
+      return functor_.template Forward<T, StorageOrder::NHWC>(
+          N,
+          C,
+          X_HW_dims,
+          Y_HW_dims,
+          kernel_,
+          dilation_,
+          stride_,
+          pads_,
+          X.template data<T>(),
+          Y->template mutable_data<T>(),
+          &context_);
+    } else if (!equal_padding_ || ndim == 3) {
+      return functor_.template Forward<T, StorageOrder::NCHW>(
+          N,
+          C,
+          X_HW_dims,
+          Y_HW_dims,
+          kernel_,
+          dilation_,
+          stride_,
+          pads_,
+          X.template data<T>(),
+          Y->template mutable_data<T>(),
+          &context_);
+    }
+
+    const std::vector<std::int64_t> X_dims = X.sizes().vec();
+    const std::vector<std::int64_t> Y_dims = Y->sizes().vec();
+    if (cached_X_dims_ != X_dims) {
+      constexpr cudnnDataType_t data_type = cudnnTypeWrapper<T>::type;
+      SetTensorDescriptor(data_type, order_, X_dims, &X_desc_);
+      SetTensorDescriptor(data_type, order_, Y_dims, &Y_desc_);
+      cached_X_dims_ = X_dims;
+    }
+    CUDNN_ENFORCE(cudnnPoolingForward(
+        cudnn_wrapper_.inline_cudnn_handle(),
+        pooling_desc_,
+        cudnnTypeWrapper<T>::kOne(),
+        X_desc_,
+        X_data,
+        cudnnTypeWrapper<T>::kZero(),
+        Y_desc_,
+        Y_data));
+
+    return true;
+  }
+
+ private:
+  CuDNNWrapper cudnn_wrapper_;
+  cudnnTensorDescriptor_t X_desc_;
+  cudnnTensorDescriptor_t Y_desc_;
+  cudnnPoolingDescriptor_t pooling_desc_;
+
+  const Functor functor_;
+
+  const bool equal_padding_;
+  std::vector<std::int64_t> cached_X_dims_;
+};
+
+template <class Functor>
+class CuDNNPoolGradientOp final : public ConvPoolOpBase<CUDAContext> {
+ public:
+  CuDNNPoolGradientOp(const OperatorDef& operator_def, Workspace* ws)
+      : ConvPoolOpBase<CUDAContext>(operator_def, ws),
+        cudnn_wrapper_(&context_),
+        functor_(*this),
+        equal_padding_(std::equal(
+            pads_.cbegin(),
+            pads_.cbegin() + kernel_.size(),
+            pads_.cbegin() + kernel_.size())) {
+    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&X_desc_));
+    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&Y_desc_));
+    CUDNN_ENFORCE(cudnnCreatePoolingDescriptor(&pooling_desc_));
+    if (!global_pooling_ && equal_padding_) {
+      if (kernel_.size() == 2) {
+        CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
+            pooling_desc_,
+            functor_.GetPoolingMode(),
+            CUDNN_NOT_PROPAGATE_NAN,
+            kernel_h(),
+            kernel_w(),
+            pad_t(),
+            pad_l(),
+            stride_h(),
+            stride_w()));
+      } else if (kernel_.size() == 3) {
+        CUDNN_ENFORCE(cudnnSetPoolingNdDescriptor(
+            pooling_desc_,
+            functor_.GetPoolingMode(),
+            CUDNN_NOT_PROPAGATE_NAN,
+            kernel_.size(),
+            kernel_.data(),
+            pads_.data(),
+            stride_.data()));
+      }
+    }
+  }
+
+  ~CuDNNPoolGradientOp() {
+    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(X_desc_));
+    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(Y_desc_));
+    CUDNN_ENFORCE(cudnnDestroyPoolingDescriptor(pooling_desc_));
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    const auto& Y = Input(1);
+    const auto& dY = Input(2);
+    auto* dX = Output(0, X.sizes(), at::dtype<T>());
+    const int ndim = X.ndim();
+    const int N = X.dim32(0);
+    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
+    const std::vector<int> X_HW_dims = GetDims(X);
+    const std::vector<int> Y_HW_dims = GetDims(Y);
+    ConvPoolOpBase<CUDAContext>::ComputePads(X_HW_dims);
+    const T* dY_data = dY.template data<T>();
+    const T* X_data = X.template data<T>();
+    const T* Y_data = Y.template data<T>();
+    T* dX_data = dX->template mutable_data<T>();
+
+    if (global_pooling_) {
+      const int HxW = X.numel() / (N * C);
+      if (order_ == StorageOrder::NCHW) {
+        return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>(
+            N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
+      } else {
+        return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>(
+            N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
+      }
+    }
+
+    if (order_ == StorageOrder::NHWC) {
+      // CuDNN Pooling on NHWC order is very slow, fallback to CUDA
+      // implementation.
+      return functor_.template Backward<T, StorageOrder::NHWC>(
+          N,
+          C,
+          X_HW_dims,
+          Y_HW_dims,
+          kernel_,
+          dilation_,
+          stride_,
+          pads_,
+          dY_data,
+          X_data,
+          Y_data,
+          dX_data,
+          &context_);
+    } else if (!equal_padding_ || ndim == 3) {
+      return functor_.template Backward<T, StorageOrder::NCHW>(
+          N,
+          C,
+          X_HW_dims,
+          Y_HW_dims,
+          kernel_,
+          dilation_,
+          stride_,
+          pads_,
+          dY_data,
+          X_data,
+          Y_data,
+          dX_data,
+          &context_);
+    }
+
+    const std::vector<std::int64_t> X_dims = X.sizes().vec();
+    const std::vector<std::int64_t> Y_dims = Y.sizes().vec();
+    if (cached_X_dims_ != X_dims) {
+      constexpr cudnnDataType_t data_type = cudnnTypeWrapper<T>::type;
+      SetTensorDescriptor(data_type, order_, X_dims, &X_desc_);
+      SetTensorDescriptor(data_type, order_, Y_dims, &Y_desc_);
+      cached_X_dims_ = X_dims;
+    }
+    CUDNN_ENFORCE(cudnnPoolingBackward(
+        cudnn_wrapper_.inline_cudnn_handle(),
+        pooling_desc_,
+        cudnnTypeWrapper<T>::kOne(),
+        Y_desc_,
+        Y_data,
+        Y_desc_,
+        dY_data,
+        X_desc_,
+        X_data,
+        cudnnTypeWrapper<T>::kZero(),
+        X_desc_,
+        dX_data));
+
+    return true;
+  }
+
+ private:
+  CuDNNWrapper cudnn_wrapper_;
+  cudnnTensorDescriptor_t X_desc_;
+  cudnnTensorDescriptor_t Y_desc_;
+  cudnnPoolingDescriptor_t pooling_desc_;
+
+  const Functor functor_;
+
+  const bool equal_padding_;
+  std::vector<std::int64_t> cached_X_dims_;
+};
+
+struct CuDNNAveragePoolFunctor {
+  explicit CuDNNAveragePoolFunctor(const OperatorBase& op)
+      : avg_pool_functor(op) {}
+
+  cudnnPoolingMode_t GetPoolingMode() const {
+    return avg_pool_functor.count_include_pad
+        ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
+        : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingForward(
+      const int N,
+      const int C,
+      const int HxW,
+      const T* X,
+      T* Y,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for average_pooling.");
+      return false;
+    } else {
+      return avg_pool_functor.GlobalPoolingForward<T, kOrder>(
+          N, C, HxW, X, Y, context);
+    }
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool Forward(
+      const int N,
+      const int C,
+      const std::vector<int>& X_dims,
+      const std::vector<int>& Y_dims,
+      const std::vector<int>& kernel,
+      const std::vector<int>& dilation,
+      const std::vector<int>& stride,
+      const std::vector<int>& pads,
+      const T* X,
+      T* Y,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for average_pooling.");
+      return false;
+    } else {
+      return avg_pool_functor.Forward<T, kOrder>(
+          N, C, X_dims, Y_dims, kernel, dilation, stride, pads, X, Y, context);
+    }
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingBackward(
+      const int N,
+      const int C,
+      const int HxW,
+      const T* dY,
+      const T* X,
+      const T* Y,
+      T* dX,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for average_pooling.");
+      return false;
+    } else {
+      return avg_pool_functor.GlobalPoolingBackward<T, kOrder>(
+          N, C, HxW, dY, X, Y, dX, context);
+    }
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool Backward(
+      const int N,
+      const int C,
+      const std::vector<int>& X_dims,
+      const std::vector<int>& Y_dims,
+      const std::vector<int>& kernel,
+      const std::vector<int>& dilation,
+      const std::vector<int>& stride,
+      const std::vector<int>& pads,
+      const T* dY,
+      const T* X,
+      const T* Y,
+      T* dX,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for average_pooling.");
+      return false;
+    } else {
+      return avg_pool_functor.Backward<T, kOrder>(
+          N,
+          C,
+          X_dims,
+          Y_dims,
+          kernel,
+          dilation,
+          stride,
+          pads,
+          dY,
+          X,
+          Y,
+          dX,
+          context);
+    }
+  }
+
+  const AveragePoolFunctor<CUDAContext> avg_pool_functor;
+};
+
+struct CuDNNMaxPoolFunctor {
+  explicit CuDNNMaxPoolFunctor(const OperatorBase& op)
+      : max_pool_functor(op),
+        deterministic(op.GetSingleArgument<bool>("deterministic", false)) {}
+
+  cudnnPoolingMode_t GetPoolingMode() const {
+#if CUDNN_VERSION_MIN(6, 0, 0)
+    return deterministic ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX;
+#else
+    return CUDNN_POOLING_MAX;
+#endif
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingForward(
+      const int N,
+      const int C,
+      const int HxW,
+      const T* X,
+      T* Y,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for max_pooling.");
+      return false;
+    } else {
+      return max_pool_functor.GlobalPoolingForward<T, kOrder>(
+          N, C, HxW, X, Y, context);
+    }
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool Forward(
+      const int N,
+      const int C,
+      const std::vector<int>& X_dims,
+      const std::vector<int>& Y_dims,
+      const std::vector<int>& kernel,
+      const std::vector<int>& dilation,
+      const std::vector<int>& stride,
+      const std::vector<int>& pads,
+      const T* X,
+      T* Y,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for max_pooling.");
+      return false;
+    } else {
+      return max_pool_functor.Forward<T, kOrder>(
+          N, C, X_dims, Y_dims, kernel, dilation, stride, pads, X, Y, context);
+    }
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool GlobalPoolingBackward(
+      const int N,
+      const int C,
+      const int HxW,
+      const T* dY,
+      const T* X,
+      const T* Y,
+      T* dX,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for max_pooling.");
+      return false;
+    } else {
+      return max_pool_functor.GlobalPoolingBackward<T, kOrder>(
+          N, C, HxW, dY, X, Y, dX, context);
+    }
+  }
+
+  template <typename T, StorageOrder kOrder>
+  bool Backward(
+      const int N,
+      const int C,
+      const std::vector<int>& X_dims,
+      const std::vector<int>& Y_dims,
+      const std::vector<int>& kernel,
+      const std::vector<int>& dilation,
+      const std::vector<int>& stride,
+      const std::vector<int>& pads,
+      const T* dY,
+      const T* X,
+      const T* Y,
+      T* dX,
+      CUDAContext* context) const {
+    if (std::is_same<T, at::Half>::value) {
+      CAFFE_THROW("Float16 is not supported for max_pooling.");
+      return false;
+    } else {
+      return max_pool_functor.Backward<T, kOrder>(
+          N,
+          C,
+          X_dims,
+          Y_dims,
+          kernel,
+          dilation,
+          stride,
+          pads,
+          dY,
+          X,
+          Y,
+          dX,
+          context);
+    }
+  }
+
+  const MaxPoolFunctor<CUDAContext> max_pool_functor;
+  const bool deterministic;
+};
+
+} // namespace
+
+REGISTER_CUDNN_OPERATOR(AveragePool, CuDNNPoolOp<CuDNNAveragePoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    AveragePoolGradient,
+    CuDNNPoolGradientOp<CuDNNAveragePoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(AveragePool1D, CuDNNPoolOp<CuDNNAveragePoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    AveragePool1DGradient,
+    CuDNNPoolGradientOp<CuDNNAveragePoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(AveragePool2D, CuDNNPoolOp<CuDNNAveragePoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    AveragePool2DGradient,
+    CuDNNPoolGradientOp<CuDNNAveragePoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(AveragePool3D, CuDNNPoolOp<CuDNNAveragePoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    AveragePool3DGradient,
+    CuDNNPoolGradientOp<CuDNNAveragePoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(MaxPool, CuDNNPoolOp<CuDNNMaxPoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    MaxPoolGradient,
+    CuDNNPoolGradientOp<CuDNNMaxPoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(MaxPool1D, CuDNNPoolOp<CuDNNMaxPoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    MaxPool1DGradient,
+    CuDNNPoolGradientOp<CuDNNMaxPoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(MaxPool2D, CuDNNPoolOp<CuDNNMaxPoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    MaxPool2DGradient,
+    CuDNNPoolGradientOp<CuDNNMaxPoolFunctor>);
+
+REGISTER_CUDNN_OPERATOR(MaxPool3D, CuDNNPoolOp<CuDNNMaxPoolFunctor>);
+REGISTER_CUDNN_OPERATOR(
+    MaxPool3DGradient,
+    CuDNNPoolGradientOp<CuDNNMaxPoolFunctor>);
+
+} // namespace caffe2
diff --git a/caffe2/operators/pool_op_cudnn.cu b/caffe2/operators/pool_op_cudnn.cu
deleted file mode 100644 (file)
index b521d34..0000000
+++ /dev/null
@@ -1,534 +0,0 @@
-#include "caffe2/core/context_gpu.h"
-#include "caffe2/core/cudnn_wrappers.h"
-#include "caffe2/operators/conv_pool_op_base.h"
-
-#include <cub/cub.cuh>
-
-namespace caffe2 {
-
-namespace {
-
-// Explicit fast paths for avg and max global pooling due to CuDNN global
-// pooling performance bug which makes pooling extremely slow.
-template <typename T>
-__global__ void
-global_avgpool_kernel_NCHW(const int NC, const int sz, const T* data, T* out) {
-  typedef cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  for (int j = blockIdx.x; j < NC; j += gridDim.x) {
-    T sum(0);
-    for (int k = threadIdx.x; k < sz; k += blockDim.x) {
-      sum += data[j * sz + k];
-    }
-    float totalsum = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) {
-      out[j] = totalsum / sz;
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T>
-__global__ void
-global_avgpool_backward_NCHW(const int NC, const int sz, const T* dx, T* out) {
-  CUDA_1D_KERNEL_LOOP(i, NC * sz) {
-    out[i] = dx[i / sz] / sz;
-  }
-}
-
-template <typename T>
-__global__ void
-global_maxpool_kernel_NCHW(const int NC, const int sz, const T* data, T* out) {
-  typedef cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  for (int j = blockIdx.x; j < NC; j += gridDim.x) {
-    T max(-FLT_MAX);
-    for (int k = threadIdx.x; k < sz; k += blockDim.x) {
-      max = data[j * sz + k] > max ? data[j * sz + k] : max;
-    }
-    float totalmax = BlockReduce(temp_storage).Reduce(max, cub::Max());
-    if (threadIdx.x == 0) {
-      out[j] = totalmax;
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T>
-__global__ void global_maxpool_backward_NCHW(
-    const int NC,
-    const int sz,
-    const T* dx,
-    T* out,
-    const T* x,
-    const T* in) {
-  CUDA_1D_KERNEL_LOOP(i, NC * sz) {
-    if (in[i] == x[i / sz]) {
-      out[i] = dx[i / sz];
-    } else {
-      out[i] = 0.0;
-    }
-  }
-}
-
-template <typename T>
-void setTensorDescriptor(
-    const int size,
-    const StorageOrder order,
-    const int N,
-    const int C,
-    const int H,
-    const int W,
-    const int D,
-    cudnnTensorDescriptor_t& desc) {
-  if (size == 4) {
-    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
-        desc,
-        GetCudnnTensorFormat(order),
-        cudnnTypeWrapper<T>::type,
-        N,
-        C,
-        H,
-        W));
-  } else {
-    vector<int> dims = {N, C, H, W, D};
-    vector<int> strides;
-    order == NCHW
-        ? strides.insert(strides.end(), {C * H * W * D, H * W * D, W * D, D, 1})
-        : strides.insert(
-              strides.end(), {H * W * D * C, 1, W * D * C, D * C, C});
-    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
-        desc,
-        cudnnTypeWrapper<T>::type,
-        size > 3 ? size : 4,
-        dims.data(),
-        strides.data()));
-  }
-}
-
-} // namespace
-
-class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
- public:
-  CuDNNPoolOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CUDAContext>(operator_def, ws),
-        cudnn_wrapper_(&context_) {
-    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bottom_desc_));
-    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
-    CUDNN_ENFORCE(cudnnCreatePoolingDescriptor(&pooling_desc_));
-    OPERATOR_NEEDS_FEATURE(kernel_.size() >=2 && kernel_.size() <=3,
-        "Cudnn pooling only supports 4d and 5d tensor");
-    if (legacy_pad_ != LegacyPadding::CAFFE_LEGACY_POOLING) {
-      for (int i = 0; i < kernel_.size(); ++i) {
-        OPERATOR_NEEDS_FEATURE(
-            pads_[i] == pads_[kernel_.size() + i],
-            "The current padding scheme leads to unequal padding on the left "
-            "and right, which is not supported by cudnn.");
-      }
-    }
-    // Figure out the pooling descriptor.
-    if (operator_def.type().substr(0, 7) == "MaxPool") {
-      bool deterministic =
-          OperatorBase::GetSingleArgument<bool>("deterministic", false);
-#if CUDNN_VERSION_MIN(6, 0, 0)
-      mode_ =
-          deterministic ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX;
-#else
-      mode_ = CUDNN_POOLING_MAX;
-#endif
-    } else if (operator_def.type().substr(0, 11) == "AveragePool") {
-      mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
-    } else {
-      LOG(FATAL) << "Unsupported pooling method: " << operator_def.type();
-    }
-  }
-
-  ~CuDNNPoolOp() {
-    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bottom_desc_));
-    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
-    CUDNN_ENFORCE(cudnnDestroyPoolingDescriptor(pooling_desc_));
-  }
-
-  template <typename T, typename M>
-  bool DoRunWithType() {
-    auto& X = Input(0);
-    auto* Y = Output(0);
-    int N = 0, C = 0, H = 0, W = 0, D = 0;
-    int H_out = 0, W_out = 0, D_out = 0;
-
-    // cuDNN pooling support only 2 and 3 spatial dimensions.
-    CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
-
-    switch (order_) {
-      case StorageOrder::NHWC:
-        N = X.dim32(0);
-        H = X.dim32(1);
-        W = X.ndim() > 3 ? X.dim32(2) : 1;
-        D = X.ndim() > 4 ? X.dim32(3) : 1;
-        C = X.dim32(X.ndim() - 1);
-        ConvPoolOpBase::SetOutputSize(X, Y, C);
-        H_out = Y->dim32(1);
-        W_out = Y->ndim() > 3 ? Y->dim32(2) : 1;
-        D_out = Y->ndim() > 4 ? Y->dim32(3) : 1;
-        break;
-      case StorageOrder::NCHW:
-        N = X.dim32(0);
-        C = X.dim32(1);
-        H = X.dim32(2);
-        W = X.ndim() > 3 ? X.dim32(3) : 1;
-        D = X.ndim() > 4 ? X.dim32(4) : 1;
-        ConvPoolOpBase::SetOutputSize(X, Y, C);
-        H_out = Y->dim32(2);
-        W_out = Y->ndim() > 3 ? Y->dim32(3) : 1;
-        D_out = Y->ndim() > 4 ? Y->dim32(4) : 1;
-        break;
-      default:
-        LOG(FATAL) << "Unknown storage order: " << order_;
-    }
-
-    // Fast path for global pooling, as cudnn is slow. But only
-    // on float, because fp16 not supported for CUB.
-    if (std::is_same<T, float>::value) {
-      if (order_ == StorageOrder::NCHW && global_pooling_) {
-        if (mode_ == CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING) {
-          global_avgpool_kernel_NCHW<float>
-              <<<std::min(N * C, CAFFE_MAXIMUM_NUM_BLOCKS),
-                 CAFFE_CUDA_NUM_THREADS,
-                 0,
-                 context_.cuda_stream()>>>(
-                  N * C,
-                  H * W * D,
-                  X.data<float>(),
-                  Y->template mutable_data<float>());
-          return true;
-        }
-        if (mode_ == CUDNN_POOLING_MAX) {
-          global_maxpool_kernel_NCHW<float>
-              <<<std::min(N * C, CAFFE_MAXIMUM_NUM_BLOCKS),
-                 CAFFE_CUDA_NUM_THREADS,
-                 0,
-                 context_.cuda_stream()>>>(
-                  N * C,
-                  H * W * D,
-                  X.data<float>(),
-                  Y->template mutable_data<float>());
-          return true;
-        }
-      }
-    }
-
-    if (cudnn_input_dims_ != X.sizes()) {
-      // Dimensions changed; we will need to re-initialize things.
-      VLOG(1) << "Changing the cudnn descriptor configurations.";
-      cudnn_input_dims_ = X.sizes().vec();
-      setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
-      setTensorDescriptor<T>(
-          Y->ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
-      for (int i = 0; i < kernel_.size(); ++i) {
-        if (pads_[i] != pads_[kernel_.size() + i]) {
-          CAFFE_ENFORCE(
-              legacy_pad_ == LegacyPadding::CAFFE_LEGACY_POOLING,
-              "Cudnn pooling only supports even padding on both sides, with "
-              "the only exception of the caffe legacy pooling case where we "
-              "try to preserve backward compatibility with Caffe.");
-        }
-      }
-      if (kernel_.size() == 2) {
-        CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
-            pooling_desc_,
-            mode_,
-            CUDNN_NOT_PROPAGATE_NAN,
-            kernel_h(),
-            kernel_w(),
-            pad_t(),
-            pad_l(),
-            stride_h(),
-            stride_w()));
-      } else {
-        CUDNN_ENFORCE(cudnnSetPoolingNdDescriptor(
-            pooling_desc_,
-            mode_,
-            CUDNN_NOT_PROPAGATE_NAN,
-            kernel_.size(),
-            kernel_.data(),
-            pads_.data(),
-            stride_.data()));
-      }
-    }
-    // Carry out the pooling computation.
-    const T* Xdata = X.template data<T>();
-    T* Ydata = Y->template mutable_data<T>();
-    CUDNN_ENFORCE(cudnnPoolingForward(
-        cudnn_wrapper_.inline_cudnn_handle(),
-        pooling_desc_,
-        cudnnTypeWrapper<T>::kOne(),
-        bottom_desc_,
-        Xdata,
-        cudnnTypeWrapper<T>::kZero(),
-        top_desc_,
-        Ydata));
-    return true;
-  }
-
-  bool RunOnDevice() final {
-    auto& X = Input(0);
-    auto* Y = Output(0);
-
-    if (X.IsType<float>()) {
-      return DoRunWithType<float, float>();
-    } else if (X.IsType<at::Half>()) {
-      return DoRunWithType<at::Half, float>();
-    } else {
-      LOG(FATAL) << "Unsupported input types";
-    }
-    return true;
-  }
-
- protected:
-  vector<int64_t> cudnn_input_dims_;
-
-  CuDNNWrapper cudnn_wrapper_;
-  cudnnTensorDescriptor_t bottom_desc_;
-  cudnnTensorDescriptor_t top_desc_;
-  cudnnPoolingDescriptor_t pooling_desc_;
-  cudnnPoolingMode_t mode_;
-
- private:
-};
-
-class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
- public:
-  CuDNNPoolGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : ConvPoolOpBase<CUDAContext>(operator_def, ws),
-        cudnn_wrapper_(&context_) {
-    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bottom_desc_));
-    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
-    CUDNN_ENFORCE(cudnnCreatePoolingDescriptor(&pooling_desc_));
-    // Figure out the pooling descriptor.
-    if (operator_def.type() == "MaxPoolGradient" ||
-        operator_def.type() == "MaxPool1DGradient" ||
-        operator_def.type() == "MaxPool2DGradient" ||
-        operator_def.type() == "MaxPool3DGradient") {
-      bool deterministic =
-          OperatorBase::GetSingleArgument<bool>("deterministic", false);
-#if CUDNN_VERSION_MIN(6, 0, 0)
-      mode_ =
-          deterministic ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX;
-#else
-      mode_ = CUDNN_POOLING_MAX;
-#endif
-    } else if (
-        operator_def.type() == "AveragePoolGradient" ||
-        operator_def.type() == "AveragePool1DGradient" ||
-        operator_def.type() == "AveragePool2DGradient" ||
-        operator_def.type() == "AveragePool3DGradient") {
-      mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
-    } else {
-      LOG(FATAL) << "Unsupported pooling method: " << operator_def.type();
-    }
-  }
-
-  ~CuDNNPoolGradientOp() {
-    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bottom_desc_));
-    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
-    CUDNN_ENFORCE(cudnnDestroyPoolingDescriptor(pooling_desc_));
-  }
-
-  template <typename T, typename M>
-  bool DoRunWithType() {
-    auto& X = Input(0);
-    auto& Y = Input(1);
-    auto& dY = Input(2);
-
-    // cuDNN pooling support only 2 and 3 spatial dimensions.
-    CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
-
-    auto* dX = Output(0, X.sizes(), at::dtype<float>());
-    int N = 0, C = 0, H = 0, W = 0, D = 0;
-    int H_out = 0, W_out = 0, D_out = 0;
-    switch (order_) {
-      case StorageOrder::NHWC:
-        N = X.dim32(0);
-        H = X.dim32(1);
-        W = X.ndim() > 3 ? X.dim32(2) : 1;
-        D = X.ndim() > 4 ? X.dim32(3) : 1;
-        C = X.dim32(X.ndim() - 1);
-        H_out = Y.dim32(1);
-        W_out = Y.ndim() > 3 ? Y.dim32(2) : 1;
-        D_out = Y.ndim() > 4 ? Y.dim32(3) : 1;
-        break;
-      case StorageOrder::NCHW:
-        N = X.dim32(0);
-        C = X.dim32(1);
-        H = X.dim32(2);
-        W = X.ndim() > 3 ? X.dim32(3) : 1;
-        D = X.ndim() > 4 ? X.dim32(4) : 1;
-        H_out = Y.dim32(2);
-        W_out = Y.ndim() > 3 ? Y.dim32(3) : 1;
-        D_out = Y.ndim() > 4 ? Y.dim32(4) : 1;
-        break;
-      default:
-        LOG(FATAL) << "Unknown storage order: " << order_;
-    }
-
-    // Fast path for global pooling, as cudnn is slow. But only
-    // on float, because fp16 not supported for CUB.
-    if (std::is_same<T, float>::value) {
-      if (order_ == StorageOrder::NCHW && global_pooling_) {
-        if (mode_ == CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING) {
-          global_avgpool_backward_NCHW<float>
-              <<<CAFFE_GET_BLOCKS(dX->size()),
-                 CAFFE_CUDA_NUM_THREADS,
-                 0,
-                 context_.cuda_stream()>>>(
-                  N * C,
-                  H * W * D,
-                  dY.data<float>(),
-                  dX->template mutable_data<float>());
-          return true;
-        }
-#if CUDNN_VERSION_MIN(6, 0, 0)
-        if (mode_ == CUDNN_POOLING_MAX ||
-            mode_ == CUDNN_POOLING_MAX_DETERMINISTIC) {
-#else
-        if (mode_ == CUDNN_POOLING_MAX) {
-#endif
-          global_maxpool_backward_NCHW<float>
-              <<<CAFFE_GET_BLOCKS(dX->size()),
-                 CAFFE_CUDA_NUM_THREADS,
-                 0,
-                 context_.cuda_stream()>>>(
-                  N * C,
-                  H * W * D,
-                  dY.data<float>(),
-                  dX->template mutable_data<float>(),
-                  Y.data<float>(),
-                  X.data<float>());
-          return true;
-        }
-      }
-    }
-
-    if (kernel_.size() == 1) {
-      ConvPoolOpBase<CUDAContext>::ComputePads({H});
-    } else if (kernel_.size() == 2) {
-      ConvPoolOpBase<CUDAContext>::ComputePads({H, W});
-    } else if (kernel_.size() == 3) {
-      ConvPoolOpBase<CUDAContext>::ComputePads({H, W, D});
-    } else {
-      CAFFE_THROW("Unsupported kernel size :", kernel_.size());
-    }
-
-    if (cudnn_input_dims_ != X.sizes()) {
-      // Dimensions changed; we will need to re-initialize things.
-      VLOG(1) << "Changing the cudnn descriptor configurations.";
-      cudnn_input_dims_ = X.sizes().vec();
-      setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
-      setTensorDescriptor<T>(
-          Y.ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
-      for (int i = 0; i < kernel_.size(); ++i) {
-        if (pads_[i] != pads_[kernel_.size() + i]) {
-          CAFFE_ENFORCE(
-              legacy_pad_ == LegacyPadding::CAFFE_LEGACY_POOLING,
-              "Cudnn pooling only supports even padding on both sides, with "
-              "the only exception of the caffe legacy pooling case where we "
-              "try to preserve backward compatibility with Caffe.");
-        }
-      }
-      if (kernel_.size() == 2) {
-        CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
-            pooling_desc_,
-            mode_,
-            CUDNN_NOT_PROPAGATE_NAN,
-            kernel_h(),
-            kernel_w(),
-            pad_t(),
-            pad_l(),
-            stride_h(),
-            stride_w()));
-      } else {
-        CUDNN_ENFORCE(cudnnSetPoolingNdDescriptor(
-            pooling_desc_,
-            mode_,
-            CUDNN_NOT_PROPAGATE_NAN,
-            kernel_.size(),
-            kernel_.data(),
-            pads_.data(),
-            stride_.data()));
-      }
-    }
-    // Carry out the pooling computation.
-    const T* Xdata = X.template data<T>();
-    const T* Ydata = Y.template data<T>();
-    const T* dYdata = dY.template data<T>();
-    T* dXdata = dX->template mutable_data<T>();
-
-    CUDNN_ENFORCE(cudnnPoolingBackward(
-        cudnn_wrapper_.inline_cudnn_handle(),
-        pooling_desc_,
-        cudnnTypeWrapper<T>::kOne(),
-        top_desc_,
-        Ydata,
-        top_desc_,
-        dYdata,
-        bottom_desc_,
-        Xdata,
-        cudnnTypeWrapper<T>::kZero(),
-        bottom_desc_,
-        dXdata));
-    return true;
-  }
-
-  bool RunOnDevice() final {
-    auto& X = Input(0);
-    auto& Y = Input(1);
-    auto& dY = Input(2);
-    auto* dX = Output(0);
-    dX->ResizeLike(X);
-
-    if (X.IsType<float>()) {
-      return DoRunWithType<float, float>();
-    } else if (X.IsType<at::Half>()) {
-      return DoRunWithType<at::Half, float>();
-    } else {
-      LOG(FATAL) << "Unsupported input types";
-    }
-    return true;
-  }
-
- protected:
-  vector<int64_t> cudnn_input_dims_;
-
-  CuDNNWrapper cudnn_wrapper_;
-  cudnnTensorDescriptor_t bottom_desc_;
-  cudnnTensorDescriptor_t top_desc_;
-  cudnnPoolingDescriptor_t pooling_desc_;
-  cudnnPoolingMode_t mode_;
-};
-
-namespace {
-REGISTER_CUDNN_OPERATOR(AveragePool, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(AveragePoolGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(AveragePool1D, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(AveragePool1DGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(AveragePool2D, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(AveragePool2DGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(AveragePool3D, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(AveragePool3DGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(MaxPool, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(MaxPoolGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(MaxPool1D, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(MaxPool1DGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(MaxPool2D, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(MaxPool2DGradient, CuDNNPoolGradientOp);
-
-REGISTER_CUDNN_OPERATOR(MaxPool3D, CuDNNPoolOp);
-REGISTER_CUDNN_OPERATOR(MaxPool3DGradient, CuDNNPoolGradientOp);
-} // namespace
-} // namespace caffe2
index 4740d7d..3caaf24 100644 (file)
@@ -398,6 +398,61 @@ class TestPooling(hu.HypothesisTestCase):
         self.assertGradientChecks(
             gc, op, [X], 0, [0], threshold=0.05, stepsize=0.005)
 
+    @given(op_type=st.sampled_from(["AveragePool", "AveragePoolND"]),
+           dim=st.integers(1, 3),
+           N=st.integers(1, 3),
+           C=st.integers(1, 3),
+           D=st.integers(3, 5),
+           H=st.integers(3, 5),
+           W=st.integers(3, 5),
+           kernel=st.integers(1, 3),
+           stride=st.integers(1, 3),
+           pad=st.integers(0, 2),
+           count_include_pad=st.booleans(),
+           order=st.sampled_from(["NCHW", "NHWC"]),
+           engine=st.sampled_from(["", "CUDNN"]),
+           **hu.gcs)
+    def test_avg_pool_count_include_pad(
+            self, op_type, dim, N, C, D, H, W, kernel, stride, pad,
+            count_include_pad, order, engine, gc, dc):
+        assume(pad < kernel)
+        if hiputl.run_in_hip(gc, dc):
+            if dim != 2:
+                assume(engine != "CUDNN")
+            elif engine == "CUDNN":
+                assume(order == "NCHW")
+
+        if op_type.endswith("ND"):
+            op_type = op_type.replace("N", str(dim))
+
+        op = core.CreateOperator(
+            op_type,
+            ["X"],
+            ["Y"],
+            kernels=[kernel] * dim,
+            strides=[stride] * dim,
+            pads=[pad] * dim * 2,
+            count_include_pad=count_include_pad,
+            order=order,
+            engine=engine,
+        )
+
+        if dim == 1:
+            dims = [N, C, W]
+            axes = [0, 2, 1]
+        elif dim == 2:
+            dims = [N, C, H, W]
+            axes = [0, 2, 3, 1]
+        else:
+            dims = [N, C, D, H, W]
+            axes = [0, 2, 3, 4, 1]
+        X = np.random.randn(*dims).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, axes)
+
+        self.assertDeviceChecks(dc, op, [X], [0])
+        self.assertGradientChecks(gc, op, [X], 0, [0])
+
 
 if __name__ == "__main__":
     import unittest