[ONNX] Replace global poolings with reduce operations (#2650)
authorTomasz Socha <tomasz.socha@intel.com>
Thu, 15 Oct 2020 17:04:43 +0000 (19:04 +0200)
committerGitHub <noreply@github.com>
Thu, 15 Oct 2020 17:04:43 +0000 (19:04 +0200)
ngraph/frontend/onnx_import/include/onnx_import/utils/pooling_factory.hpp
ngraph/frontend/onnx_import/src/op/global_average_pool.cpp
ngraph/frontend/onnx_import/src/op/global_max_pool.cpp
ngraph/frontend/onnx_import/src/utils/pooling_factory.cpp

index c4636d4..e882d85 100644 (file)
@@ -86,18 +86,6 @@ namespace ngraph
                 virtual ~LocalPoolingFactory() = default;
             };
 
-            ///
-            /// \brief      Factory class which generates sub-graphs for ONNX 'global' pooling
-            ///             operators.
-            /// \note       In a 'global' pooling operation, the kernel shape is calculated
-            ///             based on spatial dims
-            class GlobalPoolingFactory : public PoolingFactory
-            {
-            public:
-                explicit GlobalPoolingFactory(const Node& node);
-                virtual ~GlobalPoolingFactory() = default;
-            };
-
         } // namespace pooling
     }     // namespace onnx_import
 } // namespace ngraph
index 7ea27bc..30b6d4b 100644 (file)
 // limitations under the License.
 //*****************************************************************************
 
+#include <numeric>
+#include <vector>
+
 #include "global_average_pool.hpp"
 #include "ngraph/node.hpp"
-#include "ngraph/op/avg_pool.hpp"
-#include "onnx_import/utils/pooling_factory.hpp"
+#include "onnx_import/default_opset.hpp"
 
 namespace ngraph
 {
@@ -29,7 +31,35 @@ namespace ngraph
             {
                 OutputVector global_average_pool(const Node& node)
                 {
-                    return pooling::GlobalPoolingFactory(node).make_avg_pool();
+                    auto data = node.get_ng_inputs()[0];
+                    auto data_rank = data.get_partial_shape().rank();
+
+                    NGRAPH_CHECK(data_rank.is_static(),
+                                 "The input data tensor's rank has to be known (static)");
+
+                    auto data_rank_value = data_rank.get_length();
+
+                    NGRAPH_CHECK(data_rank_value > 2,
+                                 "The input data tensor's rank has to be greater than 2."
+                                 "Provided data rank is: ",
+                                 data_rank_value);
+
+                    // Generate axes for reduce operation which contain all spatial dims indexes.
+                    // Examples:
+                    // Input shape: [N, C, H, W]
+                    // Input spatial dimensions are H and W
+                    // Expected spatial dims indexes: [2, 3]
+                    //
+                    // Input shape: [N, C, H, W, D]
+                    // Input spatial dimensions are H, W and D
+                    // Expected spatial dims indexes: [2, 3, 4]
+                    uint64_t data_spatial_rank = data_rank_value - 2;
+                    auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
+                    std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
+                    auto reduce_axes = default_opset::Constant::create(
+                        element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
+
+                    return {std::make_shared<default_opset::ReduceMean>(data, reduce_axes, true)};
                 }
 
             } // namespace set_1
index 79dee54..53af9d6 100644 (file)
 // limitations under the License.
 //*****************************************************************************
 
+#include <numeric>
+#include <vector>
+
 #include "global_max_pool.hpp"
 #include "ngraph/node.hpp"
-#include "ngraph/op/max_pool.hpp"
-#include "onnx_import/utils/pooling_factory.hpp"
+#include "onnx_import/default_opset.hpp"
 
 namespace ngraph
 {
@@ -29,7 +31,35 @@ namespace ngraph
             {
                 OutputVector global_max_pool(const Node& node)
                 {
-                    return pooling::GlobalPoolingFactory(node).make_max_pool();
+                    auto data = node.get_ng_inputs()[0];
+                    auto data_rank = data.get_partial_shape().rank();
+
+                    NGRAPH_CHECK(data_rank.is_static(),
+                                 "The input data tensor's rank has to be known (static)");
+
+                    auto data_rank_value = data_rank.get_length();
+
+                    NGRAPH_CHECK(data_rank_value > 2,
+                                 "The input data tensor's rank has to be greater than 2."
+                                 "Provided data rank is: ",
+                                 data_rank_value);
+
+                    // Generate axes for reduce operation which contain all spatial dims indexes.
+                    // Examples:
+                    // Input shape: [N, C, H, W]
+                    // Input spatial dimensions are H and W
+                    // Expected spatial dims indexes: [2, 3]
+                    //
+                    // Input shape: [N, C, H, W, D]
+                    // Input spatial dimensions are H, W and D
+                    // Expected spatial dims indexes: [2, 3, 4]
+                    uint64_t data_spatial_rank = data_rank_value - 2;
+                    auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
+                    std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
+                    auto reduce_axes = default_opset::Constant::create(
+                        element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
+
+                    return {std::make_shared<default_opset::ReduceMax>(data, reduce_axes, true)};
                 }
 
             } // namespace set_1
index 766120d..09b83ae 100644 (file)
@@ -73,26 +73,6 @@ namespace ngraph
                 // Kernel shape is required
                 m_kernel_shape = node.get_attribute_value<std::vector<std::size_t>>("kernel_shape");
             }
-
-            GlobalPoolingFactory::GlobalPoolingFactory(const Node& node)
-                : PoolingFactory(node)
-            {
-                const auto data_shape = node.get_ng_inputs().at(0).get_partial_shape();
-                const auto data_rank = data_shape.rank();
-                CHECK_VALID_NODE(
-                    node, data_rank.is_static(), "Data rank must be static for global pooling ops");
-                Shape kernel_shape;
-                for (auto i = 2; i < data_rank.get_length(); ++i)
-                {
-                    CHECK_VALID_NODE(node,
-                                     data_shape[i].is_static(),
-                                     "All spatial dimensions must be known for global pooling ops");
-                    kernel_shape.emplace_back(data_shape[i].get_length());
-                }
-
-                // Set shape to all but {N,C} axes.
-                m_kernel_shape = kernel_shape;
-            }
         } // namespace pooling
     }     // namespace onnx_import
 } // namespace ngraph