#include <algorithm>
#include <memory>
#include <vector>
+#include <ngraph/ops.hpp>
+#include <ngraph/opsets/opset1.hpp>
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
constexpr NodeTypeInfo op::GatherIE::type_info;
-op::GatherIE::GatherIE(const Output<Node>& params, const Output<Node>& indices, int64_t axis, const Shape & output_shape)
+op::GatherIE::GatherIE(const Output<Node>& params, const Output<Node>& indices, int64_t axis)
: Op({params, indices})
- , m_axis(axis)
- , m_output_shape(output_shape) {
+ , m_axis(axis) {
constructor_validate_and_infer_types();
}
-shared_ptr<Node> op::GatherIE::copy_with_new_args(const NodeVector& new_args) const {
+shared_ptr<Node> op::GatherIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
check_new_args_count(this, new_args);
- return make_shared<GatherIE>(new_args.at(0), new_args.at(1), m_axis, m_output_shape);
+ return make_shared<GatherIE>(new_args.at(0), new_args.at(1), m_axis);
}
void op::GatherIE::validate_and_infer_types() {
- set_output_type(0, get_input_element_type(0), m_output_shape);
+ // Use opset1::Gather to calculate output shape
+ auto gather = std::make_shared<opset1::Gather>(input_value(0), input_value(1), opset1::Constant::create(element::i64, Shape{1}, {m_axis}));
+ set_output_type(0, gather->output(0).get_element_type(), gather->output(0).get_partial_shape());
}
#include <ngraph/rt_info.hpp>
void ngraph::pass::ConvertGatherToGatherIE::convert_gather_to_gather_ie() {
- auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
- auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
- auto input_2 = std::make_shared<pattern::op::Label>(element::i64, Shape{});
- auto gather = std::make_shared<ngraph::opset1::Gather>(input_0, input_1, input_2);
+ auto gather = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::Gather>());
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
auto gather = std::dynamic_pointer_cast<ngraph::opset1::Gather>(m.get_match_root());
return false;
}
- auto axes_node = gather->input(2).get_source_output().get_node_shared_ptr();
- auto axes_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node);
+ auto axes_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(gather->input_value(2).get_node_shared_ptr());
if (!axes_constant) {
return false;
}
- auto axis = axes_constant->get_vector<int64_t>()[0];
+ auto axis = axes_constant->cast_vector<int64_t>()[0];
// vector of new created nGraph operations
NodeVector new_ops;
// if the input with indices is scalar we need to unsqueeze it to 1D so plugins which do not support 0D can
// execute this layer. Then we need to squeeze the axis dimension to restore original shape of gather output
- auto indices = gather->input(1).get_source_output();
- auto gather_output_shape = gather->output(0).get_shape();
+ auto indices = gather->input_value(1);
+ const auto indices_rank = indices.get_partial_shape().rank();
+ if (indices_rank.is_dynamic()) {
+ return false;
+ }
+
bool squeeze_gather_output = false;
- if (indices.get_shape().empty()) {
+ if (indices_rank.get_length() == 0) {
squeeze_gather_output = true;
- gather_output_shape.insert(gather_output_shape.begin() + axis, 1);
- indices = std::make_shared<ngraph::opset1::Unsqueeze>(indices.get_node_shared_ptr(),
- opset1::Constant::create(element::i64, Shape{1}, {0}));
+ indices = std::make_shared<ngraph::opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(indices.get_node_shared_ptr());
}
- auto gather_ie = std::make_shared<ngraph::op::GatherIE>(gather->input(0).get_source_output(),
- indices,
- axis,
- gather_output_shape);
+
+ auto gather_ie = std::make_shared<ngraph::op::GatherIE>(gather->input_value(0), indices, axis);
new_ops.push_back(gather_ie);
if (squeeze_gather_output) {
auto sq = std::make_shared<ngraph::opset1::Squeeze>(gather_ie,
- op::Constant::create(element::i64, Shape{1}, {axis}));
+ opset1::Constant::create(element::i64, Shape{1}, {axis}));
sq->set_friendly_name(gather->get_friendly_name());
new_ops.push_back(sq);
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph_ops/gather_ie.hpp>
+
+#include "ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+TEST(TransformationTests, ConvertGatherToGatherIEStatic1) {
+ std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
+ auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+ auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+ f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+ pass::InitNodeInfo().run_on_function(f);
+ pass::ConvertGatherToGatherIE().run_on_function(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
+ auto gather = std::make_shared<op::GatherIE>(input, indices, 1);
+
+ f_ref = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertGatherToGatherIEStatic2) {
+ std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+ auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+ auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+ f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+ pass::InitNodeInfo().run_on_function(f);
+ pass::ConvertGatherToGatherIE().run_on_function(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+ auto unsqueeze = std::make_shared<opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
+ auto gather = std::make_shared<op::GatherIE>(input, unsqueeze, 1);
+ auto squeeze = std::make_shared<opset1::Squeeze>(gather, opset1::Constant::create(element::i64, Shape{1}, {1}));
+
+ f_ref = std::make_shared<Function>(NodeVector{squeeze}, ParameterVector{input, indices});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertGatherToGatherIEDynamic1) {
+ std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN, DYN, DYN});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+ auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+ auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+ f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+ pass::InitNodeInfo().run_on_function(f);
+ pass::ConvertGatherToGatherIE().run_on_function(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN, DYN, DYN});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+ auto gather = std::make_shared<op::GatherIE>(input, indices, 1);
+
+ f_ref = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertGatherToGatherIEDynamic2) {
+ std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+ auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+ auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+ f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+ pass::InitNodeInfo().run_on_function(f);
+ pass::ConvertGatherToGatherIE().run_on_function(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+ auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+ auto unsqueeze = std::make_shared<opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
+ auto gather = std::make_shared<op::GatherIE>(input, unsqueeze, 1);
+ auto squeeze = std::make_shared<opset1::Squeeze>(gather, opset1::Constant::create(element::i64, Shape{1}, {1}));
+
+ f_ref = std::make_shared<Function>(NodeVector{squeeze}, ParameterVector{input, indices});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file