virtual ~LocalPoolingFactory() = default;
};
- ///
- /// \brief Factory class which generates sub-graphs for ONNX 'global' pooling
- /// operators.
- /// \note In a 'global' pooling operation, the kernel shape is calculated
- /// based on spatial dims
- class GlobalPoolingFactory : public PoolingFactory
- {
- public:
- explicit GlobalPoolingFactory(const Node& node);
- virtual ~GlobalPoolingFactory() = default;
- };
-
} // namespace pooling
} // namespace onnx_import
} // namespace ngraph
// limitations under the License.
//*****************************************************************************
+#include <numeric>
+#include <vector>
+
#include "global_average_pool.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/op/avg_pool.hpp"
-#include "onnx_import/utils/pooling_factory.hpp"
+#include "onnx_import/default_opset.hpp"
namespace ngraph
{
{
OutputVector global_average_pool(const Node& node)
{
- return pooling::GlobalPoolingFactory(node).make_avg_pool();
+ auto data = node.get_ng_inputs()[0];
+ auto data_rank = data.get_partial_shape().rank();
+
+ NGRAPH_CHECK(data_rank.is_static(),
+ "The input data tensor's rank has to be known (static)");
+
+ auto data_rank_value = data_rank.get_length();
+
+ NGRAPH_CHECK(data_rank_value > 2,
+ "The input data tensor's rank has to be greater than 2."
+ "Provided data rank is: ",
+ data_rank_value);
+
+ // Generate axes for reduce operation which contain all spatial dims indexes.
+ // Examples:
+ // Input shape: [N, C, H, W]
+ // Input spatial dimensions are H and W
+ // Expected spatial dims indexes: [2, 3]
+ //
+ // Input shape: [N, C, H, W, D]
+ // Input spatial dimensions are H, W and D
+ // Expected spatial dims indexes: [2, 3, 4]
+ uint64_t data_spatial_rank = data_rank_value - 2;
+ auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
+ std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
+ auto reduce_axes = default_opset::Constant::create(
+ element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
+
+ return {std::make_shared<default_opset::ReduceMean>(data, reduce_axes, true)};
}
} // namespace set_1
// limitations under the License.
//*****************************************************************************
+#include <numeric>
+#include <vector>
+
#include "global_max_pool.hpp"
#include "ngraph/node.hpp"
-#include "ngraph/op/max_pool.hpp"
-#include "onnx_import/utils/pooling_factory.hpp"
+#include "onnx_import/default_opset.hpp"
namespace ngraph
{
{
OutputVector global_max_pool(const Node& node)
{
- return pooling::GlobalPoolingFactory(node).make_max_pool();
+ auto data = node.get_ng_inputs()[0];
+ auto data_rank = data.get_partial_shape().rank();
+
+ NGRAPH_CHECK(data_rank.is_static(),
+ "The input data tensor's rank has to be known (static)");
+
+ auto data_rank_value = data_rank.get_length();
+
+ NGRAPH_CHECK(data_rank_value > 2,
+ "The input data tensor's rank has to be greater than 2."
+ "Provided data rank is: ",
+ data_rank_value);
+
+ // Generate axes for reduce operation which contain all spatial dims indexes.
+ // Examples:
+ // Input shape: [N, C, H, W]
+ // Input spatial dimensions are H and W
+ // Expected spatial dims indexes: [2, 3]
+ //
+ // Input shape: [N, C, H, W, D]
+ // Input spatial dimensions are H, W and D
+ // Expected spatial dims indexes: [2, 3, 4]
+ uint64_t data_spatial_rank = data_rank_value - 2;
+ auto reduce_axes_vector = std::vector<std::int64_t>(data_spatial_rank);
+ std::iota(reduce_axes_vector.begin(), reduce_axes_vector.end(), 2);
+ auto reduce_axes = default_opset::Constant::create(
+ element::i64, Shape{data_spatial_rank}, reduce_axes_vector);
+
+ return {std::make_shared<default_opset::ReduceMax>(data, reduce_axes, true)};
}
} // namespace set_1
// Kernel shape is required
m_kernel_shape = node.get_attribute_value<std::vector<std::size_t>>("kernel_shape");
}
-
- GlobalPoolingFactory::GlobalPoolingFactory(const Node& node)
- : PoolingFactory(node)
- {
- const auto data_shape = node.get_ng_inputs().at(0).get_partial_shape();
- const auto data_rank = data_shape.rank();
- CHECK_VALID_NODE(
- node, data_rank.is_static(), "Data rank must be static for global pooling ops");
- Shape kernel_shape;
- for (auto i = 2; i < data_rank.get_length(); ++i)
- {
- CHECK_VALID_NODE(node,
- data_shape[i].is_static(),
- "All spatial dimensions must be known for global pooling ops");
- kernel_shape.emplace_back(data_shape[i].get_length());
- }
-
- // Set shape to all but {N,C} axes.
- m_kernel_shape = kernel_shape;
- }
} // namespace pooling
} // namespace onnx_import
} // namespace ngraph