Add ceil_mode for Max and Avg pooling (#2965)
authorBartosz Sledz <bartosz.sledz@intel.com>
Mon, 16 Nov 2020 14:16:24 +0000 (15:16 +0100)
committerGitHub <noreply@github.com>
Mon, 16 Nov 2020 14:16:24 +0000 (15:16 +0100)
ngraph/frontend/onnx_import/include/onnx_import/utils/convpool.hpp
ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp
ngraph/frontend/onnx_import/src/utils/convpool.cpp
ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp
ngraph/python/tests/test_onnx/test_backend.py

index 9cd2af1..9b5a93d 100644 (file)
@@ -53,6 +53,13 @@ namespace ngraph
             ///             (height, width, depth).
             Strides get_dilations(const Node& node, const std::size_t kernel_rank = 0UL);
 
+            /// \brief      Gets the 'ceil_mode' (rounding type) attribute value.
+            ///
+            /// \param[in]  node  The ONNX node we query for attribute.
+            ///
+            /// \return     The nGraph RoundingType object representing 'ceil_mode' attribute value.
+            ngraph::op::RoundingType get_rounding_type(const Node& node);
+
             /// \brief Get padding values for the operation described by an ONNX node.
             /// \details Values are taken from the `pads` attribute.
             ///
index e882d85..230de87 100644 (file)
@@ -73,6 +73,7 @@ namespace ngraph
                 Shape m_padding_below;
                 Shape m_padding_above;
                 ngraph::op::PadType m_auto_pad;
+                ngraph::op::RoundingType m_rounding_type;
             };
 
             ///
index b503c45..f932e51 100644 (file)
@@ -99,6 +99,12 @@ namespace ngraph
                 return detail::get_attribute_value(node, "dilations", kernel_rank);
             }
 
+            ngraph::op::RoundingType get_rounding_type(const Node& node)
+            {
+                return static_cast<ngraph::op::RoundingType>(
+                    node.get_attribute_value<std::int64_t>("ceil_mode", 0));
+            }
+
             ngraph::op::PadType get_auto_pad(const Node& node)
             {
                 // Default value means use explicitly provided padding values.
index 09b83ae..17a7a28 100644 (file)
@@ -34,6 +34,7 @@ namespace ngraph
                 , m_strides{convpool::get_strides(node)}
                 , m_dilations{convpool::get_dilations(node)}
                 , m_auto_pad{convpool::get_auto_pad(node)}
+                , m_rounding_type{convpool::get_rounding_type(node)}
             {
                 const auto paddings = convpool::get_pads(node);
                 const CoordinateDiff& padding_above{paddings.second};
@@ -52,7 +53,7 @@ namespace ngraph
                                                                  m_padding_above,
                                                                  m_kernel_shape,
                                                                  !count_include_pad,
-                                                                 ngraph::op::RoundingType::FLOOR,
+                                                                 m_rounding_type,
                                                                  m_auto_pad)};
             }
 
@@ -63,7 +64,7 @@ namespace ngraph
                                                                  m_padding_below,
                                                                  m_padding_above,
                                                                  m_kernel_shape,
-                                                                 ngraph::op::RoundingType::FLOOR,
+                                                                 m_rounding_type,
                                                                  m_auto_pad)};
             }
 
index e7588d6..a727eb8 100644 (file)
@@ -194,9 +194,7 @@ tests_expected_to_fail = [
         "OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu",
         "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
     (xfail_issue_33616,
-        "OnnxBackendNodeModelTest.test_maxpool_2d_ceil_cpu",
-        "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu",
-        "OnnxBackendNodeModelTest.test_averagepool_2d_ceil_cpu"),
+        "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"),
     (xfail_issue_38086,
         "OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_expanded_cpu",
         "OnnxBackendNodeModelTest.test_dynamicquantizelinear_expanded_cpu",