From 7cead20209b3a4290eb059de90b21a9a2df2837d Mon Sep 17 00:00:00 2001 From: Tomasz Socha Date: Thu, 15 Oct 2020 19:04:43 +0200 Subject: [PATCH] [ONNX] Replace global poolings with reduce operations (#2650) --- .../include/onnx_import/utils/pooling_factory.hpp | 12 -------- .../onnx_import/src/op/global_average_pool.cpp | 36 ++++++++++++++++++++-- .../onnx_import/src/op/global_max_pool.cpp | 36 ++++++++++++++++++++-- .../onnx_import/src/utils/pooling_factory.cpp | 20 ------------ 4 files changed, 66 insertions(+), 38 deletions(-) diff --git a/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp b/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp index c4636d4..e882d85 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp @@ -86,18 +86,6 @@ namespace ngraph virtual ~LocalPoolingFactory() = default; }; - /// - /// \brief Factory class which generates sub-graphs for ONNX 'global' pooling - /// operators. - /// \note In a 'global' pooling operation, the kernel shape is calculated - /// based on spatial dims - class GlobalPoolingFactory : public PoolingFactory - { - public: - explicit GlobalPoolingFactory(const Node& node); - virtual ~GlobalPoolingFactory() = default; - }; - } // namespace pooling } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/op/global_average_pool.cpp b/ngraph/frontend/onnx_import/src/op/global_average_pool.cpp index 7ea27bc..30b6d4b 100644 --- a/ngraph/frontend/onnx_import/src/op/global_average_pool.cpp +++ b/ngraph/frontend/onnx_import/src/op/global_average_pool.cpp @@ -14,10 +14,12 @@ // limitations under the License. //***************************************************************************** +#include +#include + #include "global_average_pool.hpp" #include "ngraph/node.hpp" -#include "ngraph/op/avg_pool.hpp" -#include "onnx_import/utils/pooling_factory.hpp" +#include "onnx_import/default_opset.hpp" namespace ngraph { @@ -29,7 +31,35 @@ namespace ngraph { OutputVector global_average_pool(const Node& node) { - return pooling::GlobalPoolingFactory(node).make_avg_pool(); + auto data = node.get_ng_inputs()[0]; + auto data_rank = data.get_partial_shape().rank(); + + NGRAPH_CHECK(data_rank.is_static(), + "The input data tensor's rank has to be known (static)"); + + auto data_rank_value = data_rank.get_length(); + + NGRAPH_CHECK(data_rank_value > 2, + "The input data tensor's rank has to be greater than 2." + "Provided data rank is: ", + data_rank_value); + + // Generate axes for reduce operation which contain all spatial dims indexes. + // Examples: + // Input shape: [N, C, H, W] + // Input spatial dimensions are H and W + // Expected spatial dims indexes: [2, 3] + // + // Input shape: [N, C, H, W, D] + // Input spatial dimensions are H, W and D + // Expected spatial dims indexes: [2, 3, 4] + uint64_t data_spatial_rank = data_rank_value - 2; + auto reduce_axes_vector = std::vector(data_spatial_rank); + std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2); + auto reduce_axes = default_opset::Constant::create( + element::i64, Shape{data_spatial_rank}, reduce_axes_vector); + + return {std::make_shared(data, reduce_axes, true)}; } } // namespace set_1 diff --git a/ngraph/frontend/onnx_import/src/op/global_max_pool.cpp b/ngraph/frontend/onnx_import/src/op/global_max_pool.cpp index 79dee54..53af9d6 100644 --- a/ngraph/frontend/onnx_import/src/op/global_max_pool.cpp +++ b/ngraph/frontend/onnx_import/src/op/global_max_pool.cpp @@ -14,10 +14,12 @@ // limitations under the License. //***************************************************************************** +#include +#include + #include "global_max_pool.hpp" #include "ngraph/node.hpp" -#include "ngraph/op/max_pool.hpp" -#include "onnx_import/utils/pooling_factory.hpp" +#include "onnx_import/default_opset.hpp" namespace ngraph { @@ -29,7 +31,35 @@ namespace ngraph { OutputVector global_max_pool(const Node& node) { - return pooling::GlobalPoolingFactory(node).make_max_pool(); + auto data = node.get_ng_inputs()[0]; + auto data_rank = data.get_partial_shape().rank(); + + NGRAPH_CHECK(data_rank.is_static(), + "The input data tensor's rank has to be known (static)"); + + auto data_rank_value = data_rank.get_length(); + + NGRAPH_CHECK(data_rank_value > 2, + "The input data tensor's rank has to be greater than 2." + "Provided data rank is: ", + data_rank_value); + + // Generate axes for reduce operation which contain all spatial dims indexes. + // Examples: + // Input shape: [N, C, H, W] + // Input spatial dimensions are H and W + // Expected spatial dims indexes: [2, 3] + // + // Input shape: [N, C, H, W, D] + // Input spatial dimensions are H, W and D + // Expected spatial dims indexes: [2, 3, 4] + uint64_t data_spatial_rank = data_rank_value - 2; + auto reduce_axes_vector = std::vector(data_spatial_rank); + std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2); + auto reduce_axes = default_opset::Constant::create( + element::i64, Shape{data_spatial_rank}, reduce_axes_vector); + + return {std::make_shared(data, reduce_axes, true)}; } } // namespace set_1 diff --git a/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp b/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp index 766120d..09b83ae 100644 --- a/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp +++ b/ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp @@ -73,26 +73,6 @@ namespace ngraph // Kernel shape is required m_kernel_shape = node.get_attribute_value>("kernel_shape"); } - - GlobalPoolingFactory::GlobalPoolingFactory(const Node& node) - : PoolingFactory(node) - { - const auto data_shape = node.get_ng_inputs().at(0).get_partial_shape(); - const auto data_rank = data_shape.rank(); - CHECK_VALID_NODE( - node, data_rank.is_static(), "Data rank must be static for global pooling ops"); - Shape kernel_shape; - for (auto i = 2; i < data_rank.get_length(); ++i) - { - CHECK_VALID_NODE(node, - data_shape[i].is_static(), - "All spatial dimensions must be known for global pooling ops"); - kernel_shape.emplace_back(data_shape[i].get_length()); - } - - // Set shape to all but {N,C} axes. - m_kernel_shape = kernel_shape; - } } // namespace pooling } // namespace onnx_import } // namespace ngraph -- 2.7.4