Updated StridedSlice to StridedSliceIE conversion to support dynamic shapes (#621)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Wed, 27 May 2020 22:14:12 +0000 (01:14 +0300)
committerGitHub <noreply@github.com>
Wed, 27 May 2020 22:14:12 +0000 (01:14 +0300)
* Updated ConvertStridedSliceToStridedSliceIE transformation to support dynamic shapes

* Fixed stridesluce to crop transform not to fail with dynamic shapes

inference-engine/src/transformations/include/ngraph_ops/strided_slice_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_strided_slice_to_strided_slice_ie.hpp
inference-engine/src/transformations/src/ngraph_ops/strided_slice_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_strided_slice_to_strided_slice_ie.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_strided_slice_to_crop_test.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_stridedslice_to_stridedslice_ie_test.cpp [new file with mode: 0644]

index 22f7a3e..cb83c4d 100644 (file)
@@ -27,12 +27,11 @@ public:
                  const std::vector<int64_t>& end_mask,
                  const std::vector<int64_t>& new_axis_mask,
                  const std::vector<int64_t>& shrink_axis_mask,
-                 const std::vector<int64_t>& ellipsis_mask,
-                 const Shape& output_shape);
+                 const std::vector<int64_t>& ellipsis_mask);
 
     void validate_and_infer_types() override;
 
-    std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector & new_args) const override;
 
     const std::vector<int64_t>& get_begin_mask() const { return m_begin_mask; }
     const std::vector<int64_t>& get_end_mask() const { return m_end_mask; }
@@ -46,7 +45,6 @@ protected:
     const std::vector<int64_t> m_new_axis_mask;
     const std::vector<int64_t> m_shrink_axis_mask;
     const std::vector<int64_t> m_ellipsis_mask;
-    Shape m_output_shape;
 };
 
 }  // namespace op
index 3c5b8d8..df2fd64 100644 (file)
@@ -19,6 +19,13 @@ class INFERENCE_ENGINE_API_CLASS(ConvertStridedSliceToStridedSliceIE);
 }  // namespace pass
 }  // namespace ngraph
 
+/*
+ * Description:
+ *     This transformation converts opset1::StridedSlice to legacy StridedSliceIE
+ *     StridedSliceIE takes begin, end and strides inputs ony in i32 precision.
+ *     Inputs with precision != i32 are converted with Convert operation.
+ */
+
 class ngraph::pass::ConvertStridedSliceToStridedSliceIE: public ngraph::pass::GraphRewrite {
 public:
     ConvertStridedSliceToStridedSliceIE() : GraphRewrite() {
index cb81d53..085d965 100644 (file)
@@ -7,6 +7,8 @@
 #include <algorithm>
 #include <vector>
 #include <memory>
+#include <ngraph/ops.hpp>
+#include <ngraph/opsets/opset1.hpp>
 
 using namespace std;
 using namespace ngraph;
@@ -17,27 +19,44 @@ op::StridedSliceIE::StridedSliceIE(const Output <Node> &data, const Output <Node
                                    const Output <Node> &strides, const std::vector<int64_t> &begin_mask,
                                    const std::vector<int64_t> &end_mask, const std::vector<int64_t> &new_axis_mask,
                                    const std::vector<int64_t> &shrink_axis_mask,
-                                   const std::vector<int64_t> &ellipsis_mask,
-                                   const Shape& output_shape)
-                                   : Op({data, begin, end, strides}),
-                                   m_begin_mask(begin_mask),
-                                   m_end_mask(end_mask),
-                                   m_new_axis_mask(new_axis_mask),
-                                   m_shrink_axis_mask(shrink_axis_mask),
-                                   m_ellipsis_mask(ellipsis_mask),
-                                   m_output_shape(output_shape) {
+                                   const std::vector<int64_t> &ellipsis_mask)
+    : Op({data, begin, end, strides})
+    , m_begin_mask(begin_mask)
+    , m_end_mask(end_mask)
+    , m_new_axis_mask(new_axis_mask)
+    , m_shrink_axis_mask(shrink_axis_mask)
+    , m_ellipsis_mask(ellipsis_mask) {
     constructor_validate_and_infer_types();
 }
 
-std::shared_ptr<Node> op::StridedSliceIE::copy_with_new_args(const NodeVector& new_args) const {
-    if (new_args.size() != 4) {
-        throw ngraph_error("Incorrect number of new arguments");
-    }
-
-    return make_shared<StridedSliceIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_begin_mask,
-            m_end_mask, m_new_axis_mask, m_shrink_axis_mask, m_ellipsis_mask, m_output_shape);
+std::shared_ptr<Node> op::StridedSliceIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
+    check_new_args_count(this, new_args);
+    return std::make_shared<op::StridedSliceIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_begin_mask,
+            m_end_mask, m_new_axis_mask, m_shrink_axis_mask, m_ellipsis_mask);
 }
 
 void op::StridedSliceIE::validate_and_infer_types() {
-    set_output_type(0, get_input_element_type(0), PartialShape(m_output_shape));
+    const auto& begin_mask_et = input_value(1).get_element_type();
+    const auto& end_mask_et = input_value(2).get_element_type();
+    const auto& strides_et = input_value(3).get_element_type();
+
+    NODE_VALIDATION_CHECK(this,
+                          begin_mask_et.is_integral_number(),
+                          "Begin mask must have i32 type, but its: ",
+                          begin_mask_et);
+
+    NODE_VALIDATION_CHECK(this,
+                          end_mask_et == element::i32,
+                          "End mask must have i32 type, but its: ",
+                          end_mask_et);
+
+    NODE_VALIDATION_CHECK(this,
+                          strides_et.is_integral_number(),
+                          "Strides must have i32 type, but its: ",
+                          strides_et);
+
+    // Calculate output shape via opset1::StridedSlice operation
+    auto slice = std::make_shared<opset1::StridedSlice>(input_value(0), input_value(1), input_value(2), input_value(3),
+            m_begin_mask, m_end_mask, m_new_axis_mask, m_shrink_axis_mask, m_ellipsis_mask);
+    set_output_type(0, slice->output(0).get_element_type(), slice->output(0).get_partial_shape());
 }
index 2ff6131..647e54f 100644 (file)
@@ -33,8 +33,6 @@ void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
         auto end_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(slice->get_argument(2));
         auto stride_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(slice->get_argument(3));
 
-        auto output_shape = slice->get_output_shape(0);
-
         auto partial_input_shape = slice->get_input_partial_shape(0);
 
         if (!begin_node || !end_node || !stride_node || partial_input_shape.is_dynamic()) {
@@ -42,6 +40,7 @@ void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
         }
 
         auto input_shape = slice->get_input_shape(0);
+        auto output_shape = slice->get_output_shape(0);
         // MKLDNN: "Crop supports only 2d, 4d and 5d blobs."
         if (input_shape.size() != 2 && input_shape.size() != 4 && input_shape.size() != 5) {
             return false;
index 938efdf..48f0a9d 100644 (file)
 #include <ngraph/rt_info.hpp>
 
 void ngraph::pass::ConvertStridedSliceToStridedSliceIE::convert_strided_slice_to_strided_slice_ie() {
-    auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto m_begin = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
-    auto m_end = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
-    auto m_stride = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
-    std::vector<int64_t> begin_mask = {0, 0, 0, 0};
-    std::vector<int64_t> end_mask = {0, 0, 0, 0};
-    auto m_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, m_begin, m_end, m_stride, begin_mask, end_mask);
+    auto slice = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset1::StridedSlice>());
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
-        auto strided_slice = std::dynamic_pointer_cast<ngraph::opset1::StridedSlice> (m.get_match_root());
-        if (!strided_slice) {
+        auto slice = std::dynamic_pointer_cast<opset1::StridedSlice> (m.get_match_root());
+        if (!slice) {
             return false;
         }
 
-        auto data_node = strided_slice->get_argument(0);
-        auto begin_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(strided_slice->get_argument(1));
-        auto end_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(strided_slice->get_argument(2));
-        auto stride_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(strided_slice->get_argument(3));
-
-        auto output_shape = strided_slice->get_output_shape(0);
+        auto data_node = slice->input_value(0);
+        auto begin_node = std::dynamic_pointer_cast<opset1::Constant>(slice->input_value(1).get_node_shared_ptr());
+        auto end_node = std::dynamic_pointer_cast<opset1::Constant>(slice->input_value(2).get_node_shared_ptr());
+        auto stride_node = std::dynamic_pointer_cast<opset1::Constant>(slice->input_value(3).get_node_shared_ptr());
 
         if (!begin_node || !end_node || !stride_node) {
             return false;
         }
 
-        auto shrink_axis_mask = strided_slice->get_shrink_axis_mask();
-        auto new_axis_mask = strided_slice->get_new_axis_mask();
-        auto ellipsis_mask = strided_slice->get_ellipsis_mask();
-        auto begin_mask = strided_slice->get_begin_mask();
-        auto end_mask = strided_slice->get_end_mask();
-
-        auto converted_begin = std::make_shared<ngraph::opset1::Convert>(begin_node, ngraph::element::Type_t::i32);
-        auto converted_end = std::make_shared<ngraph::opset1::Convert>(end_node, ngraph::element::Type_t::i32);
-        auto converted_stride = std::make_shared<ngraph::opset1::Convert>(stride_node, ngraph::element::Type_t::i32);
-
-        auto strided_slice_ie = std::make_shared<ngraph::op::StridedSliceIE>(data_node,
-                                                                             converted_begin, converted_end,
-                                                                             converted_stride,
-                                                                             begin_mask, end_mask, new_axis_mask, shrink_axis_mask, ellipsis_mask,
-                                                                             output_shape);
-        strided_slice_ie->set_friendly_name(strided_slice->get_friendly_name());
-
-        ngraph::copy_runtime_info(strided_slice, {converted_begin, converted_end, converted_stride, strided_slice_ie});
-        ngraph::replace_node(strided_slice, strided_slice_ie);
+        auto converted_begin = std::make_shared<opset1::Convert>(begin_node, element::i32);
+        auto converted_end = std::make_shared<opset1::Convert>(end_node, element::i32);
+        auto converted_stride = std::make_shared<opset1::Convert>(stride_node, element::i32);
+
+        auto slice_ie = std::make_shared<ngraph::op::StridedSliceIE>(data_node,
+                                                                     converted_begin,
+                                                                     converted_end,
+                                                                     converted_stride,
+                                                                     slice->get_begin_mask(),
+                                                                     slice->get_end_mask(),
+                                                                     slice->get_new_axis_mask(),
+                                                                     slice->get_shrink_axis_mask(),
+                                                                     slice->get_ellipsis_mask());
+        slice_ie->set_friendly_name(slice->get_friendly_name());
+
+        ngraph::copy_runtime_info(slice, {converted_begin, converted_end, converted_stride, slice_ie});
+        ngraph::replace_node(slice, slice_ie);
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(m_slice, "ConvertStridedSliceToStridedSliceIE");
+    auto m = std::make_shared<ngraph::pattern::Matcher>(slice, "ConvertStridedSliceToStridedSliceIE");
     this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
 }
\ No newline at end of file
index 5f832af..95189df 100644 (file)
@@ -131,3 +131,52 @@ TEST(TransformationTests, ConvertStridedSliceToCropTests2) {
                              (reshape_node->get_friendly_name() == "strided_slice");
     ASSERT_TRUE(names_are_correct) << "Transformation ConvertStridedSliceToCrop should keep output names.\n";
 }
+
+TEST(TransformationTests, ConvertStridedSliceToCropNegative) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input        = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
+        auto slice_begin  = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 1, 0, 0});
+        auto slice_end    = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 2, 0, 0});
+        auto slice_stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
+
+        std::vector<int64_t> begin_mask       = {1, 0, 1, 1};
+        std::vector<int64_t> end_mask         = {1, 0, 1, 1};
+        std::vector<int64_t> new_axis_mask    = {0, 0, 0, 0};
+        std::vector<int64_t> shrink_axis_mask = {0, 1, 0, 0};
+        std::vector<int64_t> ellipsis_mask    = {0, 0, 0, 0};
+
+        auto sslice = std::make_shared<ngraph::opset1::StridedSlice>(input, slice_begin, slice_end, slice_stride,
+                                                                     begin_mask, end_mask,
+                                                                     new_axis_mask, shrink_axis_mask, ellipsis_mask);
+        sslice->set_friendly_name("strided_slice");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input        = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
+        auto slice_begin  = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 1, 0, 0});
+        auto slice_end    = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 2, 0, 0});
+        auto slice_stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 1, 1, 1});
+
+        std::vector<int64_t> begin_mask       = {1, 0, 1, 1};
+        std::vector<int64_t> end_mask         = {1, 0, 1, 1};
+        std::vector<int64_t> new_axis_mask    = {0, 0, 0, 0};
+        std::vector<int64_t> shrink_axis_mask = {0, 1, 0, 0};
+        std::vector<int64_t> ellipsis_mask    = {0, 0, 0, 0};
+
+        auto sslice = std::make_shared<ngraph::opset1::StridedSlice>(input, slice_begin, slice_end, slice_stride,
+                                                                     begin_mask, end_mask,
+                                                                     new_axis_mask, shrink_axis_mask, ellipsis_mask);
+        sslice->set_friendly_name("strided_slice");
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_stridedslice_to_stridedslice_ie_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_stridedslice_to_stridedslice_ie_test.cpp
new file mode 100644 (file)
index 0000000..41b5834
--- /dev/null
@@ -0,0 +1,97 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <transformations/convert_opset1_to_legacy/convert_strided_slice_to_strided_slice_ie.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/constant_folding.hpp>
+#include <ngraph_ops/strided_slice_ie.hpp>
+
+#include "ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, ConvertStridedSliceToStridedSliceIEStatic) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3, 4});
+        auto begin = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
+        auto end = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
+        auto stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
+
+        std::vector<int64_t> begin_mask{0, 0, 0, 0};
+        std::vector<int64_t> end_mask{1, 1, 1, 1};
+
+        auto ss = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        ngraph::pass::ConvertStridedSliceToStridedSliceIE().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+        ngraph::pass::ConstantFolding().run_on_function(f);
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3, 4});
+        auto begin = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 0, 0, 0});
+        auto end = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {-1, -1, -1, -1});
+        auto stride = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {1});
+
+        std::vector<int64_t> begin_mask{0, 0, 0, 0}, end_mask{1, 1, 1, 1}, new_axis_mask{}, shrink_axis_mask{}, ellipsis_mask{};
+
+        auto ss = std::make_shared<ngraph::op::StridedSliceIE>(data, begin, end, stride,
+                begin_mask, end_mask, new_axis_mask, shrink_axis_mask, ellipsis_mask);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, ConvertStridedSliceToStridedSliceIEDynamic) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
+        auto begin = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
+        auto end = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1, -1, -1, -1});
+        auto stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1});
+
+        std::vector<int64_t> begin_mask{0, 0, 0, 0};
+        std::vector<int64_t> end_mask{1, 1, 1, 1};
+
+        auto ss = std::make_shared<ngraph::opset1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        ngraph::pass::ConvertStridedSliceToStridedSliceIE().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+        ngraph::pass::ConstantFolding().run_on_function(f);
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(4));
+        auto begin = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 0, 0, 0});
+        auto end = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {-1, -1, -1, -1});
+        auto stride = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {1});
+
+        std::vector<int64_t> begin_mask{0, 0, 0, 0}, end_mask{1, 1, 1, 1}, new_axis_mask{}, shrink_axis_mask{}, ellipsis_mask{};
+
+        auto ss = std::make_shared<ngraph::op::StridedSliceIE>(data, begin, end, stride,
+                begin_mask, end_mask, new_axis_mask, shrink_axis_mask, ellipsis_mask);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}