Updated ConvertGatherToGatherIE transformation to support dynamic shapes (#611)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Tue, 26 May 2020 21:38:04 +0000 (00:38 +0300)
committerGitHub <noreply@github.com>
Tue, 26 May 2020 21:38:04 +0000 (00:38 +0300)
inference-engine/src/transformations/include/ngraph_ops/gather_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.hpp
inference-engine/src/transformations/src/ngraph_ops/gather_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_gather_to_gather_ie.cpp [new file with mode: 0644]

index f91c92a..aea5636 100644 (file)
@@ -20,17 +20,16 @@ public:
     const NodeTypeInfo& get_type_info() const override { return type_info; }
     GatherIE() = default;
 
-    GatherIE(const Output<Node>& params, const Output<Node>& indices, int64_t axis, const Shape & output_shape);
+    GatherIE(const Output<Node>& params, const Output<Node>& indices, int64_t axis);
 
     void validate_and_infer_types() override;
 
     int64_t get_axis() const { return m_axis; }
     void set_axis(int64_t axis) { m_axis = axis; }
-    std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 
 protected:
     int64_t m_axis;
-    Shape m_output_shape;
 };
 
 }  // namespace op
index a89506a..05dce84 100644 (file)
@@ -27,6 +27,13 @@ class INFERENCE_ENGINE_API_CLASS(ConvertGatherToGatherIE);
 }  // namespace pass
 }  // namespace ngraph
 
+/*
+ * Description:
+ *     This transformation converts opset1::Gather to legacy GatherIE
+ *     GatherIE takes axes as value and if indices input has empty shape (scalar)
+ *     we unsqueeze indices input and squeeze GatherIE output.
+ */
+
 class ngraph::pass::ConvertGatherToGatherIE : public ngraph::pass::GraphRewrite {
 public:
     ConvertGatherToGatherIE() : GraphRewrite() {
index 0fd4a33..f9659ea 100644 (file)
@@ -7,6 +7,8 @@
 #include <algorithm>
 #include <memory>
 #include <vector>
+#include <ngraph/ops.hpp>
+#include <ngraph/opsets/opset1.hpp>
 
 #include "ngraph/util.hpp"
 #include "ngraph/validation_util.hpp"
@@ -16,18 +18,19 @@ using namespace ngraph;
 
 constexpr NodeTypeInfo op::GatherIE::type_info;
 
-op::GatherIE::GatherIE(const Output<Node>& params, const Output<Node>& indices, int64_t axis, const Shape & output_shape)
+op::GatherIE::GatherIE(const Output<Node>& params, const Output<Node>& indices, int64_t axis)
         : Op({params, indices})
-        , m_axis(axis)
-        , m_output_shape(output_shape) {
+        , m_axis(axis) {
     constructor_validate_and_infer_types();
 }
 
-shared_ptr<Node> op::GatherIE::copy_with_new_args(const NodeVector& new_args) const {
+shared_ptr<Node> op::GatherIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
     check_new_args_count(this, new_args);
-    return make_shared<GatherIE>(new_args.at(0), new_args.at(1), m_axis, m_output_shape);
+    return make_shared<GatherIE>(new_args.at(0), new_args.at(1), m_axis);
 }
 
 void op::GatherIE::validate_and_infer_types() {
-    set_output_type(0, get_input_element_type(0), m_output_shape);
+    // Use opset1::Gather to calculate output shape
+    auto gather = std::make_shared<opset1::Gather>(input_value(0), input_value(1), opset1::Constant::create(element::i64, Shape{1}, {m_axis}));
+    set_output_type(0, gather->output(0).get_element_type(), gather->output(0).get_partial_shape());
 }
index 1982a92..9ee227a 100644 (file)
 #include <ngraph/rt_info.hpp>
 
 void ngraph::pass::ConvertGatherToGatherIE::convert_gather_to_gather_ie() {
-    auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
-    auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
-    auto input_2 = std::make_shared<pattern::op::Label>(element::i64, Shape{});
-    auto gather = std::make_shared<ngraph::opset1::Gather>(input_0, input_1, input_2);
+    auto gather = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::Gather>());
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
         auto gather = std::dynamic_pointer_cast<ngraph::opset1::Gather>(m.get_match_root());
@@ -22,37 +19,36 @@ void ngraph::pass::ConvertGatherToGatherIE::convert_gather_to_gather_ie() {
             return false;
         }
 
-        auto axes_node = gather->input(2).get_source_output().get_node_shared_ptr();
-        auto axes_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node);
+        auto axes_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(gather->input_value(2).get_node_shared_ptr());
         if (!axes_constant) {
             return false;
         }
-        auto axis = axes_constant->get_vector<int64_t>()[0];
+        auto axis = axes_constant->cast_vector<int64_t>()[0];
 
         // vector of new created nGraph operations
         NodeVector new_ops;
 
         // if the input with indices is scalar we need to unsqueeze it to 1D so plugins which do not support 0D can
         // execute this layer. Then we need to squeeze the axis dimension to restore original shape of gather output
-        auto indices = gather->input(1).get_source_output();
-        auto gather_output_shape = gather->output(0).get_shape();
+        auto indices = gather->input_value(1);
+        const auto indices_rank = indices.get_partial_shape().rank();
+        if (indices_rank.is_dynamic()) {
+            return false;
+        }
+
         bool squeeze_gather_output = false;
-        if (indices.get_shape().empty()) {
+        if (indices_rank.get_length() == 0) {
             squeeze_gather_output = true;
-            gather_output_shape.insert(gather_output_shape.begin() + axis, 1);
-            indices = std::make_shared<ngraph::opset1::Unsqueeze>(indices.get_node_shared_ptr(),
-                                                                  opset1::Constant::create(element::i64, Shape{1}, {0}));
+            indices = std::make_shared<ngraph::opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
             new_ops.push_back(indices.get_node_shared_ptr());
         }
-        auto gather_ie = std::make_shared<ngraph::op::GatherIE>(gather->input(0).get_source_output(),
-                                                                indices,
-                                                                axis,
-                                                                gather_output_shape);
+
+        auto gather_ie = std::make_shared<ngraph::op::GatherIE>(gather->input_value(0), indices, axis);
         new_ops.push_back(gather_ie);
 
         if (squeeze_gather_output) {
             auto sq = std::make_shared<ngraph::opset1::Squeeze>(gather_ie,
-                                                                op::Constant::create(element::i64, Shape{1}, {axis}));
+                                                                opset1::Constant::create(element::i64, Shape{1}, {axis}));
             sq->set_friendly_name(gather->get_friendly_name());
             new_ops.push_back(sq);
 
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_gather_to_gather_ie.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_gather_to_gather_ie.cpp
new file mode 100644 (file)
index 0000000..e284ebf
--- /dev/null
@@ -0,0 +1,133 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph_ops/gather_ie.hpp>
+
+#include "ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+TEST(TransformationTests, ConvertGatherToGatherIEStatic1) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
+        auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+        auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+        f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+        pass::InitNodeInfo().run_on_function(f);
+        pass::ConvertGatherToGatherIE().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{15, 4, 20, 28});
+        auto gather = std::make_shared<op::GatherIE>(input, indices, 1);
+
+        f_ref = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertGatherToGatherIEStatic2) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+        auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+        auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+        f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+        pass::InitNodeInfo().run_on_function(f);
+        pass::ConvertGatherToGatherIE().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{6, 12, 10, 24});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+        auto unsqueeze = std::make_shared<opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
+        auto gather = std::make_shared<op::GatherIE>(input, unsqueeze, 1);
+        auto squeeze = std::make_shared<opset1::Squeeze>(gather, opset1::Constant::create(element::i64, Shape{1}, {1}));
+
+        f_ref = std::make_shared<Function>(NodeVector{squeeze}, ParameterVector{input, indices});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertGatherToGatherIEDynamic1) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN, DYN, DYN});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+        auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+        auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+        f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+        pass::InitNodeInfo().run_on_function(f);
+        pass::ConvertGatherToGatherIE().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN, DYN, DYN});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+        auto gather = std::make_shared<op::GatherIE>(input, indices, 1);
+
+        f_ref = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertGatherToGatherIEDynamic2) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+        auto axis_const = opset1::Constant::create(element::i64, Shape{}, {1});
+        auto gather = std::make_shared<opset1::Gather>(input, indices, axis_const);
+
+        f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
+
+        pass::InitNodeInfo().run_on_function(f);
+        pass::ConvertGatherToGatherIE().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{DYN, DYN});
+        auto indices = std::make_shared<opset1::Parameter>(element::f32, Shape{});
+        auto unsqueeze = std::make_shared<opset1::Unsqueeze>(indices, opset1::Constant::create(element::i64, Shape{1}, {0}));
+        auto gather = std::make_shared<op::GatherIE>(input, unsqueeze, 1);
+        auto squeeze = std::make_shared<opset1::Squeeze>(gather, opset1::Constant::create(element::i64, Shape{1}, {1}));
+
+        f_ref = std::make_shared<Function>(NodeVector{squeeze}, ParameterVector{input, indices});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file