Fixed StridedSlice to Crop transformation (#836)
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Tue, 9 Jun 2020 17:00:43 +0000 (20:00 +0300)
committerGitHub <noreply@github.com>
Tue, 9 Jun 2020 17:00:43 +0000 (20:00 +0300)
* Fixed StridedSlice to Crop transformation to not apply when rank of data is changed

* Added unit test for StridedSlice to Crop transformation

Co-authored-by: Evgeny Lazarev <elazarev.nnov@gmail.com>
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_strided_slice_to_crop_test.cpp

index 647e54f..96e2915 100644 (file)
@@ -41,10 +41,6 @@ void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
 
         auto input_shape = slice->get_input_shape(0);
         auto output_shape = slice->get_output_shape(0);
-        // MKLDNN: "Crop supports only 2d, 4d and 5d blobs."
-        if (input_shape.size() != 2 && input_shape.size() != 4 && input_shape.size() != 5) {
-            return false;
-        }
 
         auto begin = begin_node->cast_vector<int64_t>();
         auto end = end_node->cast_vector<int64_t>();
@@ -201,6 +197,12 @@ void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
             new_ops.push_back(data_node);
         }
 
+        auto data_node_shape = data_node->get_output_shape(0);
+        // MKLDNN: "Crop supports only 2d, 4d and 5d blobs."
+        if (data_node_shape.size() != 2 && data_node_shape.size() != 4 && data_node_shape.size() != 5) {
+            return false;
+        }
+
         // Crop
         data_node = std::make_shared<ngraph::op::CropIE> (data_node, axes, dim, offset);
         data_node->set_friendly_name(slice->get_friendly_name());
index 95189df..8e39546 100644 (file)
@@ -179,4 +179,54 @@ TEST(TransformationTests, ConvertStridedSliceToCropNegative) {
 
     auto res = compare_functions(f, f_ref);
     ASSERT_TRUE(res.first) << res.second;
+}
+
+// in this test the Crop will get 3D input which is not supported so the transformation will not be applied
+TEST(TransformationTests, ConvertStridedSliceToCropNegative2) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input        = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{128, 1});
+        auto slice_begin  = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+        auto slice_end    = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+        auto slice_stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
+
+        std::vector<int64_t> begin_mask       = {0, 1, 1};
+        std::vector<int64_t> end_mask         = {0, 1, 1};
+        std::vector<int64_t> new_axis_mask    = {1, 0, 0};
+        std::vector<int64_t> shrink_axis_mask = {0, 0, 0};
+        std::vector<int64_t> ellipsis_mask    = {0, 0, 0};
+
+        auto sslice = std::make_shared<ngraph::opset1::StridedSlice>(input, slice_begin, slice_end, slice_stride,
+                                                                     begin_mask, end_mask,
+                                                                     new_axis_mask, shrink_axis_mask, ellipsis_mask);
+        sslice->set_friendly_name("strided_slice");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
+        ngraph::pass::InitNodeInfo().run_on_function(f);
+        ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input        = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{128, 1});
+        auto slice_begin  = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+        auto slice_end    = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+        auto slice_stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
+
+        std::vector<int64_t> begin_mask       = {0, 1, 1};
+        std::vector<int64_t> end_mask         = {0, 1, 1};
+        std::vector<int64_t> new_axis_mask    = {1, 0, 0};
+        std::vector<int64_t> shrink_axis_mask = {0, 0, 0};
+        std::vector<int64_t> ellipsis_mask    = {0, 0, 0};
+
+        auto sslice = std::make_shared<ngraph::opset1::StridedSlice>(input, slice_begin, slice_end, slice_stride,
+                                                                     begin_mask, end_mask,
+                                                                     new_axis_mask, shrink_axis_mask, ellipsis_mask);
+        sslice->set_friendly_name("strided_slice");
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
 }
\ No newline at end of file