auto input_shape = slice->get_input_shape(0);
auto output_shape = slice->get_output_shape(0);
- // MKLDNN: "Crop supports only 2d, 4d and 5d blobs."
- if (input_shape.size() != 2 && input_shape.size() != 4 && input_shape.size() != 5) {
- return false;
- }
auto begin = begin_node->cast_vector<int64_t>();
auto end = end_node->cast_vector<int64_t>();
new_ops.push_back(data_node);
}
+ auto data_node_shape = data_node->get_output_shape(0);
+ // MKLDNN: "Crop supports only 2d, 4d and 5d blobs."
+ if (data_node_shape.size() != 2 && data_node_shape.size() != 4 && data_node_shape.size() != 5) {
+ return false;
+ }
+
// Crop
data_node = std::make_shared<ngraph::op::CropIE> (data_node, axes, dim, offset);
data_node->set_friendly_name(slice->get_friendly_name());
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
+}
+
+// in this test the Crop will get 3D input which is not supported so the transformation will not be applied
+TEST(TransformationTests, ConvertStridedSliceToCropNegative2) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{128, 1});
+ auto slice_begin = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+ auto slice_end = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+ auto slice_stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
+
+ std::vector<int64_t> begin_mask = {0, 1, 1};
+ std::vector<int64_t> end_mask = {0, 1, 1};
+ std::vector<int64_t> new_axis_mask = {1, 0, 0};
+ std::vector<int64_t> shrink_axis_mask = {0, 0, 0};
+ std::vector<int64_t> ellipsis_mask = {0, 0, 0};
+
+ auto sslice = std::make_shared<ngraph::opset1::StridedSlice>(input, slice_begin, slice_end, slice_stride,
+ begin_mask, end_mask,
+ new_axis_mask, shrink_axis_mask, ellipsis_mask);
+ sslice->set_friendly_name("strided_slice");
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
+ ngraph::pass::InitNodeInfo().run_on_function(f);
+ ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{128, 1});
+ auto slice_begin = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+ auto slice_end = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 0, 0});
+ auto slice_stride = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
+
+ std::vector<int64_t> begin_mask = {0, 1, 1};
+ std::vector<int64_t> end_mask = {0, 1, 1};
+ std::vector<int64_t> new_axis_mask = {1, 0, 0};
+ std::vector<int64_t> shrink_axis_mask = {0, 0, 0};
+ std::vector<int64_t> ellipsis_mask = {0, 0, 0};
+
+ auto sslice = std::make_shared<ngraph::opset1::StridedSlice>(input, slice_begin, slice_end, slice_stride,
+ begin_mask, end_mask,
+ new_axis_mask, shrink_axis_mask, ellipsis_mask);
+ sslice->set_friendly_name("strided_slice");
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
}
\ No newline at end of file