From 749d70bb63fd43ee34d1f0d169742c8326b50f21 Mon Sep 17 00:00:00 2001 From: Bartosz Sledz Date: Mon, 16 Nov 2020 15:16:24 +0100 Subject: [PATCH] Add ceil_mode for Max and Avg pooling (#2965) --- ngraph/frontend/onnx_import/include/onnx_import/utils/convpool.hpp | 7 +++++++ .../onnx_import/include/onnx_import/utils/pooling_factory.hpp | 1 + ngraph/frontend/onnx_import/src/utils/convpool.cpp | 6 ++++++ ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp | 5 +++-- ngraph/python/tests/test_onnx/test_backend.py | 4 +--- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ngraph/frontend/onnx_import/include/onnx_import/utils/convpool.hpp b/ngraph/frontend/onnx_import/include/onnx_import/utils/convpool.hpp index 9cd2af1..9b5a93d 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/utils/convpool.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/utils/convpool.hpp @@ -53,6 +53,13 @@ namespace ngraph /// (height, width, depth). Strides get_dilations(const Node& node, const std::size_t kernel_rank = 0UL); + /// \brief Gets the 'ceil_mode' (rounding type) attribute value. + /// + /// \param[in] node The ONNX node we query for attribute. + /// + /// \return The nGraph RoundingType object representing 'ceil_mode' attribute value. + ngraph::op::RoundingType get_rounding_type(const Node& node); + /// \brief Get padding values for the operation described by an ONNX node. /// \details Values are taken from the `pads` attribute. /// diff --git a/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp b/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp index e882d85..230de87 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp @@ -73,6 +73,7 @@ namespace ngraph Shape m_padding_below; Shape m_padding_above; ngraph::op::PadType m_auto_pad; + ngraph::op::RoundingType m_rounding_type; }; /// diff --git a/ngraph/frontend/onnx_import/src/utils/convpool.cpp b/ngraph/frontend/onnx_import/src/utils/convpool.cpp index b503c45..f932e51 100644 --- a/ngraph/frontend/onnx_import/src/utils/convpool.cpp +++ b/ngraph/frontend/onnx_import/src/utils/convpool.cpp @@ -99,6 +99,12 @@ namespace ngraph return detail::get_attribute_value(node, "dilations", kernel_rank); } + ngraph::op::RoundingType get_rounding_type(const Node& node) + { + return static_cast( + node.get_attribute_value("ceil_mode", 0)); + } + ngraph::op::PadType get_auto_pad(const Node& node) { // Default value means use explicitly provided padding values. diff --git a/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp b/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp index 09b83ae..17a7a28 100644 --- a/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp +++ b/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp @@ -34,6 +34,7 @@ namespace ngraph , m_strides{convpool::get_strides(node)} , m_dilations{convpool::get_dilations(node)} , m_auto_pad{convpool::get_auto_pad(node)} + , m_rounding_type{convpool::get_rounding_type(node)} { const auto paddings = convpool::get_pads(node); const CoordinateDiff& padding_above{paddings.second}; @@ -52,7 +53,7 @@ namespace ngraph m_padding_above, m_kernel_shape, !count_include_pad, - ngraph::op::RoundingType::FLOOR, + m_rounding_type, m_auto_pad)}; } @@ -63,7 +64,7 @@ namespace ngraph m_padding_below, m_padding_above, m_kernel_shape, - ngraph::op::RoundingType::FLOOR, + m_rounding_type, m_auto_pad)}; } diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py index e7588d6..a727eb8 100644 --- a/ngraph/python/tests/test_onnx/test_backend.py +++ b/ngraph/python/tests/test_onnx/test_backend.py @@ -194,9 +194,7 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu", "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"), (xfail_issue_33616, - "OnnxBackendNodeModelTest.test_maxpool_2d_ceil_cpu", - "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu", - "OnnxBackendNodeModelTest.test_averagepool_2d_ceil_cpu"), + "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"), (xfail_issue_38086, "OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_expanded_cpu", "OnnxBackendNodeModelTest.test_dynamicquantizelinear_expanded_cpu", -- 2.7.4