Reshape-able SliceConverter (#2954)
authorYegor Kruglov <yegor.kruglov@intel.com>
Tue, 10 Nov 2020 14:51:59 +0000 (17:51 +0300)
committerGitHub <noreply@github.com>
Tue, 10 Nov 2020 14:51:59 +0000 (17:51 +0300)
* initial commit

* add cast

* data type fix

* added tests

* added test without axes and steps

* remove redundant imports

* discussions resolving

* Add cast to TFSliceToSlice

* layer tests fix

* update unittest

model-optimizer/extensions/front/tf/TFSliceToSlice.py
model-optimizer/extensions/front/tf/TFSliceToSlice_test.py
model-optimizer/extensions/middle/SliceConverter.py
model-optimizer/extensions/middle/SliceConverter_test.py

index 8c03ca7..62e5987 100644 (file)
@@ -16,6 +16,7 @@
 
 import numpy as np
 
+from extensions.ops.Cast import Cast
 from extensions.ops.elementwise import Add, Equal
 from extensions.ops.select import Select
 from mo.front.common.replacement import FrontReplacementOp
@@ -74,4 +75,7 @@ class TFSliceToSliceReplacer(FrontReplacementOp):
         # out of select to end (2nd of slice)
         select_node.out_port(0).connect(slice_node.in_port(2))
 
+        cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
+        select_node.in_port(2).get_connection().insert_node(cast)
+
         node.out_port(0).get_connection().set_source(slice_node.out_port(0))
index 14be81e..2919a71 100644 (file)
@@ -37,6 +37,7 @@ nodes = {
     **regular_op_with_empty_data('equal', {'op': 'Equal', 'type': 'Equal'}),
     **regular_op_with_empty_data('select', {'op': 'Select', 'type': 'Select'}),
     **regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
+    **regular_op_with_empty_data('cast', {'op': 'Cast', 'type': 'Convert'}),
 }
 
 
@@ -68,7 +69,8 @@ class SliceReplacerTest(unittest.TestCase):
 
             *connect_front('equal:0', 'select:0'),
 
-            *connect_front('end_const:0', 'select:2'),
+            *connect_front('end_const:0', 'cast:0'),
+            *connect_front('cast:0', 'select:2'),
             *connect_front('select:0', 'slice:2'),
 
             *connect_front('slice:0', 'output'),
@@ -97,7 +99,8 @@ class SliceReplacerTest(unittest.TestCase):
             *connect_front('int32_max:0', '1:select'),
             *connect_front('minus_one:0', '1:equal'),
             *connect_front('equal:0', '0:select'),
-            *connect_front('end_const:0', '2:select'),
+            *connect_front('end_const:0', '0:cast'),
+            *connect_front('cast:0', '2:select'),
             *connect_front('select:0', '2:slice'),
             *connect_front('slice:0', 'output'),
         ], nodes_with_edges_only=True)
index 5ed91a7..7f2b11a 100644 (file)
 
 import numpy as np
 
+from extensions.ops.Cast import Cast
+from mo.front.caffe.extractors.utils import get_canonical_axis_index
 from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph, rename_nodes
+from mo.graph.port import Port
 from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.concat import Concat
 from mo.ops.const import Const
 from mo.ops.strided_slice import StridedSlice
-from mo.utils.error import Error
 
 
-def convert_negative_indices(indices: np.array, shape: np.array):
-    for ind, value in enumerate(indices):
-        if value < 0:
-            indices[ind] += shape[ind]
+def create_ss_interval_border(graph: Graph, shape, axes, port_to_connect: Port, node_name):
+    shape_mask = np.zeros(len(shape), dtype=np.int64)
+    first_part = shape_mask[:axes[0]]
+    last_part = shape_mask[axes[-1] + 1:]
+
+    cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
+    port_to_connect.get_connection().set_destination(cast.in_port(0))
+    concat = create_op_with_const_inputs(graph, Concat, port_value_dict={0: first_part, 2: last_part},
+                                         op_attrs={'name': node_name + '/Concat', 'axis': 0,
+                                                   'in_ports_count': 3})
+    cast.out_port(0).connect(concat.in_port(1))
+    return concat
 
 
 class ConvertSlice(MiddleReplacementPattern):
@@ -36,80 +48,57 @@ class ConvertSlice(MiddleReplacementPattern):
     """
 
     enabled = True
-    op = "Slice"
     force_clean_up = True
+    op = "Slice"
 
     def run_after(self):
         from extensions.middle.pass_separator import MiddleStart
         return [MiddleStart]
 
-    def pattern(self):
-        return dict(
-            nodes=[
-                ('slice', dict(kind='op', op='Slice'))
-            ],
-            edges=[]
-        )
-
-    def replace_pattern(self, graph: Graph, match: dict):
-        node = match['slice']
-
-        input_shape = node.in_port(0).data.get_shape()
-        output_shape = node.out_port(0).data.get_shape()
-        starts = node.in_port(1).data.get_value()
-        ends = node.in_port(2).data.get_value()
-        if starts is None or ends is None:
-            raise Error('The input with starts or end is not constant for node {}'.format(node.id))
-
-        # the value for 'ends' is usually maximum possible value of int64. This
-        # value must be converted to maximum of int32 because such big values do not fit into the int32 which is
-        # supported by the StridedSlice layer
-        ends = np.clip(ends, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
-        if node.is_in_port_connected(3):
-            axes = node.in_port(3).data.get_value()
-            if axes is None:
-                raise Error('The input with axes is not constant for node {}'.format(node.id))
-        else:
-            axes = int64_array(list(range(starts.size)))
-
-        if node.is_in_port_connected(4):
-            steps = node.in_port(4).data.get_value()
-            if steps is None:
-                raise Error('The input with steps is not constant for node {}'.format(node.id))
-        else:
-            steps = np.ones([starts.size])
-
-        ss_begin_mask = np.zeros(len(input_shape), dtype=np.int32)
-        ss_end_mask = np.zeros(len(input_shape), dtype=np.int32)
-        ss_begin = np.zeros(len(input_shape), dtype=np.int32)
-        ss_end = np.zeros(len(input_shape), dtype=np.int32)
-        ss_step = np.ones(len(input_shape), dtype=np.int32)
-
-        # prepare inputs and attributes for the StridedSlice layer
-        for i, axis in enumerate(axes):
-            if starts[i] != 0:
+    def find_and_replace_pattern(self, graph: Graph):
+        for node in graph.get_op_nodes(op='Slice'):
+            node_name = node.soft_get('name', node.id)
+
+            input_shape = node.in_port(0).data.get_shape()
+            if node.is_in_port_connected(3):
+                axes = node.in_port(3).data.get_value().copy()
+                assert axes is not None, 'The input with axes is not constant for node {}'.format(node_name)
+                for i, val in enumerate(axes):
+                    axes[i] = get_canonical_axis_index(input_shape, val)
+            else:
+                axes = int64_array(range(len(input_shape)))
+
+            ss_begin = create_ss_interval_border(graph, input_shape, axes, node.in_port(1).get_source(), node_name)
+            ss_end = create_ss_interval_border(graph, input_shape, axes, node.in_port(2).get_source(), node_name)
+            rename_nodes([(ss_begin, node_name + '/Begin'), (ss_end, node_name + '/End')])
+
+            if node.is_in_port_connected(4):
+                steps = node.in_port(4).data.get_value()
+                assert steps is not None, 'The input with steps is not constant for node {}'.format(node_name)
+            else:
+                steps = np.ones([axes.size])
+
+            ss_begin_mask = np.zeros(len(input_shape), dtype=np.int64)
+            ss_end_mask = np.zeros(len(input_shape), dtype=np.int64)
+            ss_step = np.ones(len(input_shape), dtype=np.int64)
+
+            for i, axis in enumerate(axes):
                 ss_begin_mask[axis] = 1
-                ss_begin[axis] = starts[i]
-
-            ss_end_mask[axis] = 1
-            ss_end[axis] = ends[i]
-
-            ss_step[axis] = steps[i]
-
-        slice_node_name = node.soft_get('name', node.id)
-
-        begin_node = Const(graph, {'value': ss_begin, 'name': slice_node_name + '/begin'}).create_node()
-        end_node = Const(graph, {'value': ss_end, 'name': slice_node_name + '/end'}).create_node()
-        strides_node = Const(graph, {'value': ss_step, 'name': slice_node_name + '/stride'}).create_node()
-
-        ss = StridedSlice(graph, dict(new_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
-                                      shrink_axis_mask=np.zeros(len(output_shape), dtype=np.int32),
-                                      ellipsis_mask=np.zeros(len(output_shape), dtype=np.int32),
-                                      begin_mask=ss_begin_mask,
-                                      end_mask=ss_end_mask)).create_node()
-        rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])
-        node.in_port(0).get_connection().set_destination(ss.in_port(0))
-        begin_node.out_port(0).connect(ss.in_port(1))
-        end_node.out_port(0).connect(ss.in_port(2))
-        strides_node.out_port(0).connect(ss.in_port(3))
-        node.out_port(0).get_connection().set_source(ss.out_port(0))
+                ss_end_mask[axis] = 1
+                ss_step[axis] = steps[i]
+
+            ss_strides = Const(graph, dict(name=node_name + '/Strides', value=ss_step)).create_node()
+
+            ss = StridedSlice(graph, dict(name='ss', new_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
+                                          shrink_axis_mask=np.zeros(len(input_shape), dtype=np.int64),
+                                          ellipsis_mask=np.zeros(len(input_shape), dtype=np.int64),
+                                          begin_mask=ss_begin_mask,
+                                          end_mask=ss_end_mask, override_output_shape=True)).create_node()
+
+            node.in_port(0).get_connection().set_destination(ss.in_port(0))
+            ss.in_port(1).connect(ss_begin.out_port(0))
+            ss.in_port(2).connect(ss_end.out_port(0))
+            ss.in_port(3).connect(ss_strides.out_port(0))
+            node.out_port(0).get_connection().set_source(ss.out_port(0))
+
+            rename_nodes([(node, node_name + '/ShouldBeDeleted'), (ss, node_name)])
index 92b118d..377cf67 100644 (file)
 import unittest
 
 import numpy as np
+from generator import generate, generator
 
 from extensions.middle.SliceConverter import ConvertSlice
 from mo.front.common.partial_infer.utils import int64_array
-from mo.graph.graph import Node
-from mo.ops.slice import Slice
 from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
-
-nodes_attributes = {
-    # input data
-    'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-    'placeholder_2': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-    'placeholder_3': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-    'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-    # Slice layer
-    'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
-    'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
-    # Output operation
-    'output_op': {'type': 'Const', 'value': None, 'kind': 'op', 'op': 'Const'},
-    'output_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-    'op_output': { 'kind': 'op', 'op': 'Result'},
-    # StridedSlice layer
-    'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None}
-}
+from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
+    regular_op_with_empty_data, result, connect, const, empty_data
 
 
+@generator
 class ConvertSliceTests(unittest.TestCase):
-    nodes_attributes = {
-        # input data
-        'placeholder_1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-        # Slice layer inputs
-        'starts': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-        'starts_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-        'ends': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-        'ends_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-        'strides': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-        'strides_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-        'axes': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-        'axes_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-        'steps': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-        'steps_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
-        # Slice layer
-        'slice': {'type': 'Slice', 'kind': 'op', 'op': 'Slice', 'name': 'slice_node'},
-        'slice_data': {'value': None, 'shape': None, 'kind': 'data'},
-        # Output operation
-        'output_op': {'type': 'Const', 'kind': 'op', 'op': 'Const'},
-        'output_data': {'shape': None, 'kind': 'data', 'data_type': None},
-        'op_output': {'kind': 'op', 'op': 'Result'},
-        # StridedSlice layer
-        'strided_slice': {'kind': 'op', 'op': 'StridedSlice', 'slices': None, 'shrink_axis_mask': None}
-    }
-
-    def test_slice_all_params(self):
-        input_shape = int64_array([5, 10, 20])
-        starts_value = int64_array([4, 2])
-        ends_value = int64_array([15, 8])
-        axes_value = int64_array([2, 1])
-        steps_value = int64_array([1, 1])
-
-        masks_value = np.zeros([len(input_shape)], dtype=np.int64)
-        graph = build_graph(self.nodes_attributes,
-                            [('placeholder_1', 'placeholder_1_data'),
-                             ('placeholder_1_data', 'slice', {'in': 0}),
-                             ('starts', 'starts_data'),
-                             ('starts_data', 'slice', {'in': 1}),
-                             ('ends', 'ends_data'),
-                             ('ends_data', 'slice', {'in': 2}),
-                             ('axes', 'axes_data'),
-                             ('axes_data', 'slice', {'in': 3}),
-                             ('steps', 'steps_data'),
-                             ('steps_data', 'slice', {'in': 4}),
-                             ('slice', 'slice_data'),
-                             ('slice_data', 'output_op'),
-                             ('output_op', 'output_data'),
-                             ('output_data', 'op_output')
-                             ],
-                            {'placeholder_1_data': {'shape': input_shape},
-                             'starts': {'shape': starts_value.shape, 'value': starts_value},
-                             'starts_data': {'shape': starts_value.shape, 'value': starts_value},
-                             'ends': {'shape': ends_value.shape, 'value': ends_value},
-                             'ends_data': {'shape': ends_value.shape, 'value': ends_value},
-                             'steps': {'shape': steps_value.shape, 'value': steps_value},
-                             'steps_data': {'shape': steps_value.shape, 'value': steps_value},
-                             'axes': {'shape': axes_value.shape, 'value': axes_value},
-                             'axes_data': {'shape': axes_value.shape, 'value': axes_value},
-                             }, nodes_with_edges_only=True
-                            )
-        slice_node = Node(graph, 'slice')
-        Slice.infer(slice_node)
-
-        pattern = ConvertSlice()
-        pattern.find_and_replace_pattern(graph)
-
-        ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
-        assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
-
-        graph_ref = build_graph(self.nodes_attributes,
-                                [('placeholder_1', 'placeholder_1_data'),
-                                 ('placeholder_1_data', 'strided_slice', {'in': 0}),
-                                 ('starts', 'starts_data'),
-                                 ('starts_data', 'strided_slice', {'in': 1}),
-                                 ('ends', 'ends_data'),
-                                 ('ends_data', 'strided_slice', {'in': 2}),
-                                 ('strides', 'strides_data'),
-                                 ('strides_data', 'strided_slice', {'in': 3}),
-                                 ('strided_slice', 'slice_data'),
-                                 ('slice_data', 'output_op'),
-                                 ('output_op', 'output_data'),
-                                 ('output_data', 'op_output')
-                                 ],
-                                {'placeholder_1_data': {'shape': input_shape},
-                                 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
-                                                   'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]),
-                                                   'end_mask': int64_array([0, 1, 1])},
-                                 'slice_data': {'shape': int64_array([5, 6, 11])}
-                                 }, nodes_with_edges_only=True
-                                )
-        (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
+    @generate(*[
+        (int64_array([1, 3, 300, 300]), np.array([0, 0]), np.array([150, 150]), np.array([2, 3]), np.array([1, 1]),
+         (int64_array([0, 0]), int64_array([])), (int64_array([0, 0]), int64_array([])), int64_array([1, 1, 1, 1]),
+         int64_array([0, 0, 1, 1]), int64_array([0, 0, 1, 1])),
+
+        (int64_array([1, 3, 300, 300]), np.array([0]), np.array([150]), np.array([2]), np.array([1]),
+         (int64_array([0, 0]), int64_array([0])), (int64_array([0, 0]), int64_array([0])), int64_array([1, 1, 1, 1]),
+         int64_array([0, 0, 1, 0]), int64_array([0, 0, 1, 0])),
+
+        (int64_array([1, 3, 300, 300]), np.array([0, 0]), np.array([150, 150]), np.array([-2, -1]), np.array([1, 1]),
+         (int64_array([0, 0]), int64_array([])), (int64_array([0, 0]), int64_array([])), int64_array([1, 1, 1, 1]),
+         int64_array([0, 0, 1, 1]), int64_array([0, 0, 1, 1]))
+    ])
+    def test_convert_slice_to_strided_slice(self, input_shape, start, end, axes, steps,
+                                            ss_begin_parts: tuple, ss_end_parts: tuple, ss_steps,
+                                            ss_begin_mask, ss_end_mask):
+        graph = build_graph(
+            nodes_attrs={
+                **regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
+                **valued_const_with_data('start', start),
+                **valued_const_with_data('end', end),
+                **valued_const_with_data('axes', axes),
+                **valued_const_with_data('steps', steps),
+                **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'slice'),
+                *connect('start', '1:slice'),
+                *connect('end', '2:slice'),
+                *connect('axes', '3:slice'),
+                *connect('steps', '4:slice'),
+                *connect('slice', 'result')
+            ]
+        )
+        ref_graph = build_graph(
+            nodes_attrs={
+                **regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
+                **valued_const_with_data('start', start),
+                **valued_const_with_data('begin_first_part', ss_begin_parts[0]),
+                **valued_const_with_data('begin_last_part', ss_begin_parts[1]),
+                **regular_op_with_empty_data('convert_start', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
+                **regular_op_with_empty_data('ss_begin', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
+                **valued_const_with_data('end', end),
+                **valued_const_with_data('end_first_part', ss_end_parts[0]),
+                **valued_const_with_data('end_last_part', ss_end_parts[1]),
+                **regular_op_with_empty_data('convert_end', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
+                **regular_op_with_empty_data('ss_end', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
+                **const('ss_steps', ss_steps),
+                **empty_data('ss_steps_d'),
+                **regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice',
+                                                    'begin_mask': ss_begin_mask, 'end_mask': ss_end_mask,
+                                                    'new_axis_mask': np.zeros(len(input_shape), dtype=np.int64),
+                                                    'shrink_axis_mask': np.zeros(len(input_shape), dtype=np.int64),
+                                                    'ellipsis_mask': np.zeros(len(input_shape), dtype=np.int64)}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'ss'),
+                *connect('begin_first_part', 'ss_begin'),
+                *connect('start', 'convert_start'),
+                *connect('convert_start', '1:ss_begin'),
+                *connect('begin_last_part', '2:ss_begin'),
+                *connect('ss_begin', '1:ss'),
+                *connect('end_first_part', 'ss_end'),
+                *connect('end', 'convert_end'),
+                *connect('convert_end', '1:ss_end'),
+                *connect('end_last_part', '2:ss_end'),
+                *connect('ss_end', '2:ss'),
+                *connect('ss_steps', '3:ss'),
+                *connect('ss', 'result')
+            ]
+        )
+        ConvertSlice().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
         self.assertTrue(flag, resp)
 
-    def test_no_steps_no_axes(self):
-        input_shape = int64_array([5, 10, 20])
-        starts_value = int64_array([3, 2, 7])
-        ends_value = int64_array([5, 8, 15])
-        steps_value = int64_array([1, 1, 1])
-        masks_value = np.zeros([len(input_shape)], dtype=np.int64)
-        graph = build_graph(self.nodes_attributes,
-                            [('placeholder_1', 'placeholder_1_data'),
-                             ('placeholder_1_data', 'slice', {'in': 0}),
-                             ('starts', 'starts_data'),
-                             ('starts_data', 'slice', {'in': 1}),
-                             ('ends', 'ends_data'),
-                             ('ends_data', 'slice', {'in': 2}),
-                             ('slice', 'slice_data'),
-                             ('slice_data', 'output_op'),
-                             ('output_op', 'output_data'),
-                             ('output_data', 'op_output')
-                             ],
-                            {'placeholder_1_data': {'shape': input_shape},
-                             'starts': {'shape': starts_value.shape, 'value': starts_value},
-                             'starts_data': {'shape': starts_value.shape, 'value': starts_value},
-                             'ends': {'shape': ends_value.shape, 'value': ends_value},
-                             'ends_data': {'shape': ends_value.shape, 'value': ends_value},
-                             }, nodes_with_edges_only=True
-                            )
-        slice_node = Node(graph, 'slice')
-        Slice.infer(slice_node)
-
-        pattern = ConvertSlice()
-        pattern.find_and_replace_pattern(graph)
-
-        ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
-        assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
-
-        graph_ref = build_graph(self.nodes_attributes,
-                                [('placeholder_1', 'placeholder_1_data'),
-                                 ('placeholder_1_data', 'strided_slice', {'in': 0}),
-                                 ('starts', 'starts_data'),
-                                 ('starts_data', 'strided_slice', {'in': 1}),
-                                 ('ends', 'ends_data'),
-                                 ('ends_data', 'strided_slice', {'in': 2}),
-                                 ('strides', 'strides_data'),
-                                 ('strides_data', 'strided_slice', {'in': 3}),
-                                 ('strided_slice', 'slice_data'),
-                                 ('slice_data', 'output_op'),
-                                 ('output_op', 'output_data'),
-                                 ('output_data', 'op_output')
-                                 ],
-                                {'placeholder_1_data': {'shape': input_shape},
-                                 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
-                                                   'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]),
-                                                   'end_mask': np.ones([3])},
-                                 'slice_data': {'shape': int64_array([2, 6, 8])}
-                                 }, nodes_with_edges_only=True
-                                )
-        (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
-        self.assertTrue(flag, resp)
-
-    def test_no_axes(self):
-        input_shape = int64_array([5, 10, 20])
-        starts_value = int64_array([3, 2, 7])
-        ends_value = int64_array([5, 8, 15])
-        steps_value = int64_array([2, 3, 1])
-        masks_value = np.zeros([len(input_shape)], dtype=np.int64)
-        graph = build_graph(self.nodes_attributes,
-                            [('placeholder_1', 'placeholder_1_data'),
-                             ('placeholder_1_data', 'slice', {'in': 0}),
-                             ('starts', 'starts_data'),
-                             ('starts_data', 'slice', {'in': 1}),
-                             ('ends', 'ends_data'),
-                             ('ends_data', 'slice', {'in': 2}),
-                             ('steps', 'steps_data'),
-                             ('steps_data', 'slice', {'in': 4}),
-                             ('slice', 'slice_data'),
-                             ('slice_data', 'output_op'),
-                             ('output_op', 'output_data'),
-                             ('output_data', 'op_output')
-                             ],
-                            {'placeholder_1_data': {'shape': input_shape},
-                             'starts': {'shape': starts_value.shape, 'value': starts_value},
-                             'starts_data': {'shape': starts_value.shape, 'value': starts_value},
-                             'ends': {'shape': ends_value.shape, 'value': ends_value},
-                             'ends_data': {'shape': ends_value.shape, 'value': ends_value},
-                             'steps': {'shape': steps_value.shape, 'value': steps_value},
-                             'steps_data': {'shape': steps_value.shape, 'value': steps_value},
-                             }, nodes_with_edges_only=True
-                            )
-        slice_node = Node(graph, 'slice')
-        Slice.infer(slice_node)
-
-        pattern = ConvertSlice()
-        pattern.find_and_replace_pattern(graph)
-
-        ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
-        assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
-
-        graph_ref = build_graph(self.nodes_attributes,
-                                [('placeholder_1', 'placeholder_1_data'),
-                                 ('placeholder_1_data', 'strided_slice', {'in': 0}),
-                                 ('starts', 'starts_data'),
-                                 ('starts_data', 'strided_slice', {'in': 1}),
-                                 ('ends', 'ends_data'),
-                                 ('ends_data', 'strided_slice', {'in': 2}),
-                                 ('strides', 'strides_data'),
-                                 ('strides_data', 'strided_slice', {'in': 3}),
-                                 ('strided_slice', 'slice_data'),
-                                 ('slice_data', 'output_op'),
-                                 ('output_op', 'output_data'),
-                                 ('output_data', 'op_output')
-                                 ],
-                                {'placeholder_1_data': {'shape': input_shape},
-                                 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
-                                                   'ellipsis_mask': masks_value, 'begin_mask': np.ones([3]),
-                                                   'end_mask': np.ones([3])},
-                                 'slice_data': {'shape': int64_array([1, 2, 8])}
-                                 }, nodes_with_edges_only=True
-                                )
-        (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
-        self.assertTrue(flag, resp)
-
-    def test_no_steps(self):
-        input_shape = int64_array([5, 10, 20])
-        starts_value = int64_array([4, 2])
-        ends_value = int64_array([15, 8])
-        axes_value = int64_array([2, 1])
-        masks_value = np.zeros([len(input_shape)], dtype=np.int64)
-        graph = build_graph(self.nodes_attributes,
-                            [('placeholder_1', 'placeholder_1_data'),
-                             ('placeholder_1_data', 'slice', {'in': 0}),
-                             ('starts', 'starts_data'),
-                             ('starts_data', 'slice', {'in': 1}),
-                             ('ends', 'ends_data'),
-                             ('ends_data', 'slice', {'in': 2}),
-                             ('axes', 'axes_data'),
-                             ('axes_data', 'slice', {'in': 3}),
-                             ('slice', 'slice_data'),
-                             ('slice_data', 'output_op'),
-                             ('output_op', 'output_data'),
-                             ('output_data', 'op_output')
-                             ],
-                            {'placeholder_1_data': {'shape': input_shape},
-                             'starts': {'shape': starts_value.shape, 'value': starts_value},
-                             'starts_data': {'shape': starts_value.shape, 'value': starts_value},
-                             'ends': {'shape': ends_value.shape, 'value': ends_value},
-                             'ends_data': {'shape': ends_value.shape, 'value': ends_value},
-                             'axes': {'shape': axes_value.shape, 'value': axes_value},
-                             'axes_data': {'shape': axes_value.shape, 'value': axes_value},
-                             }, nodes_with_edges_only=True
-                            )
-        slice_node = Node(graph, 'slice')
-        Slice.infer(slice_node)
-
-        pattern = ConvertSlice()
-        pattern.find_and_replace_pattern(graph)
-
-        ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
-        assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'
-
-        graph_ref = build_graph(self.nodes_attributes,
-                                [('placeholder_1', 'placeholder_1_data'),
-                                 ('placeholder_1_data', 'strided_slice', {'in': 0}),
-                                 ('starts', 'starts_data'),
-                                 ('starts_data', 'strided_slice', {'in': 1}),
-                                 ('ends', 'ends_data'),
-                                 ('ends_data', 'strided_slice', {'in': 2}),
-                                 ('strides', 'strides_data'),
-                                 ('strides_data', 'strided_slice', {'in': 3}),
-                                 ('strided_slice', 'slice_data'),
-                                 ('slice_data', 'output_op'),
-                                 ('output_op', 'output_data'),
-                                 ('output_data', 'op_output')
-                                 ],
-                                {'placeholder_1_data': {'shape': input_shape},
-                                 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
-                                                   'ellipsis_mask': masks_value, 'begin_mask': int64_array([0, 1, 1]),
-                                                   'end_mask': int64_array([0, 1, 1])},
-                                 'slice_data': {'shape': int64_array([5, 6, 11])}
-                                 }, nodes_with_edges_only=True
-                                )
-        (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
+    def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
+        graph = build_graph(
+            nodes_attrs={
+                **regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
+                **valued_const_with_data('start', np.array([0, 0, 0])),
+                **valued_const_with_data('end', np.array([1, 3, 5])),
+                **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'slice'),
+                *connect('start', '1:slice'),
+                *connect('end', '2:slice'),
+                *connect('slice', 'result')
+            ]
+        )
+        ref_graph = build_graph(
+            nodes_attrs={
+                **regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
+                **valued_const_with_data('start', np.array([0, 0, 0])),
+                **valued_const_with_data('begin_first_part', int64_array([])),
+                **valued_const_with_data('begin_last_part', int64_array([])),
+                **regular_op_with_empty_data('convert_start', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
+                **regular_op_with_empty_data('ss_begin', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
+                **valued_const_with_data('end', np.array([1, 3, 5])),
+                **valued_const_with_data('end_first_part', int64_array([])),
+                **valued_const_with_data('end_last_part', int64_array([])),
+                **regular_op_with_empty_data('convert_end', {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}),
+                **regular_op_with_empty_data('ss_end', {'type': 'Concat', 'op': 'Concat', 'axis': 0}),
+                **const('ss_steps', int64_array([1, 1, 1])),
+                **empty_data('ss_steps_d'),
+                **regular_op_with_empty_data('ss', {'op': 'StridedSlice', 'type': 'StridedSlice',
+                                                    'begin_mask': int64_array([1, 1, 1]), 'end_mask': int64_array([1, 1, 1]),
+                                                    'new_axis_mask': np.zeros(3, dtype=np.int64),
+                                                    'shrink_axis_mask': np.zeros(3, dtype=np.int64),
+                                                    'ellipsis_mask': np.zeros(3, dtype=np.int64)}),
+                **result('result')
+            },
+            edges=[
+                *connect('input', 'ss'),
+                *connect('begin_first_part', 'ss_begin'),
+                *connect('start', 'convert_start'),
+                *connect('convert_start', '1:ss_begin'),
+                *connect('begin_last_part', '2:ss_begin'),
+                *connect('ss_begin', '1:ss'),
+                *connect('end_first_part', 'ss_end'),
+                *connect('end', 'convert_end'),
+                *connect('convert_end', '1:ss_end'),
+                *connect('end_last_part', '2:ss_end'),
+                *connect('ss_end', '2:ss'),
+                *connect('ss_steps', '3:ss'),
+                *connect('ss', 'result')
+            ]
+        )
+        ConvertSlice().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
         self.assertTrue(flag, resp)