Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ConvertGroupedStridedSlice_test.py
index 0ebdb38..24d1ca9 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,20 +19,26 @@ import unittest
 import numpy as np
 
 from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
+from mo.graph.graph import Node
 from mo.utils.unittest.graph import build_graph, compare_graphs
 
 nodes_attributes = {
     'placeholder_1': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+    'placeholder_2': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
+    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+    'placeholder_begin_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+    'placeholder_end_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+    'placeholder_stride_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
     # StridedSlice layers
     'sslice_1': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
-                 'shrink_axis_mask': np.array([False, False, False, False])},
+                 'shrink_axis_mask': np.array([0, 0, 0, 0])},
     'sslice_1_data': {'value': None, 'shape': None, 'kind': 'data'},
     'sslice_2': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
-                 'shrink_axis_mask': np.array([False, False, False, False])},
+                 'shrink_axis_mask': np.array([0, 0, 0, 0])},
     'sslice_2_data': {'value': None, 'shape': None, 'kind': 'data'},
     'sslice_3': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
-                 'shrink_axis_mask': np.array([False, False, False, False])},
+                 'shrink_axis_mask': np.array([0, 0, 0, 0])},
     'sslice_3_data': {'value': None, 'shape': None, 'kind': 'data'},
     # Split layer
     'split_1': {'type': 'Split', 'kind': 'op', 'op': 'SplitV'},
@@ -43,6 +49,16 @@ nodes_attributes = {
     # Concat1 operation
     'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
     'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'op_output': {'kind': 'op', 'op': 'OpOutput'},
+    'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
+    'op_output_2': {'kind': 'op', 'op': 'OpOutput'},
+    # Reshape layer
+    'sslice_1/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
+    'sslice_1/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'sslice_2/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
+    'sslice_2/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'sslice_2/Reshape_new': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
+    'sslice_2/Reshape_new_data': {'value': None, 'shape': None, 'kind': 'data'},
 }
 
 
@@ -59,7 +75,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
                              ('sslice_3_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -75,7 +92,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(36, 54, 1)])},
                              'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -88,14 +105,16 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('split_1_data', 'concat_1'),
                                  ('split_2_data', 'concat_1'),
                                  ('split_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output')
+
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  'split_1': {'axis': 3},
                                  'split_1_data': {'shape': np.array([1, 227, 227, 18])},
                                  'split_2_data': {'shape': np.array([1, 227, 227, 18])},
                                  'split_3_data': {'shape': np.array([1, 227, 227, 18])},
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -116,7 +135,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
                              ('sslice_3_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -132,7 +152,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
                              'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -145,14 +165,15 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('split_1_data', 'concat_1'),
                                  ('split_2_data', 'concat_1'),
                                  ('split_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output')
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  'split_1': {'axis': 3},
                                  'split_1_data': {'shape': np.array([1, 227, 227, 18])},
                                  'split_2_data': {'shape': np.array([1, 227, 227, 17])},
                                  'split_3_data': {'shape': np.array([1, 227, 227, 19])},
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -174,7 +195,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
                              ('sslice_3_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -190,7 +212,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
                              'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -205,7 +227,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('sslice_1_data', 'concat_1'),
                                  ('sslice_2_data', 'concat_1'),
                                  ('sslice_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output')
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -221,7 +244,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
                                  'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
 
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -243,7 +266,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
                              ('sslice_3_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -259,7 +283,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
                              'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -274,7 +298,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('sslice_1_data', 'concat_1'),
                                  ('sslice_2_data', 'concat_1'),
                                  ('sslice_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output')
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -290,7 +315,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
                                  'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
 
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -315,7 +340,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
                              ('sslice_3_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output'),
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -331,7 +357,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(1, 19, 1)])},
                              'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -345,7 +371,9 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('split_2_data', 'concat_1'),
                                  ('split_3_data', 'concat_1'),
                                  ('split_4_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output'),
+                                 ('split_1_data', 'op_output_1')
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  'split_1': {'axis': 3},
@@ -353,7 +381,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  'split_2_data': {'shape': np.array([1, 227, 227, 18])},
                                  'split_3_data': {'shape': np.array([1, 227, 227, 17])},
                                  'split_4_data': {'shape': np.array([1, 227, 227, 18])},
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -376,7 +404,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_2', 'sslice_2_data'),
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -388,7 +417,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
                              'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -401,7 +430,10 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('split_1', 'split_4_data'),
                                  ('split_1_data', 'concat_1'),
                                  ('split_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output'),
+                                 ('split_2_data', 'op_output_1'),
+                                 ('split_4_data', 'op_output_2'),
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  'split_1': {'axis': 3},
@@ -409,7 +441,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  'split_2_data': {'shape': np.array([1, 227, 227, 9])},
                                  'split_3_data': {'shape': np.array([1, 227, 227, 18])},
                                  'split_4_data': {'shape': np.array([1, 227, 227, 9])},
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -427,7 +459,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_2', 'sslice_2_data'),
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -439,7 +472,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
                              'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
 
-                             'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -451,7 +484,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('sslice_2', 'sslice_2_data'),
                                  ('sslice_1_data', 'concat_1'),
                                  ('sslice_2_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output')
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
 
@@ -463,7 +497,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                      [slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
                                  'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
 
-                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()
@@ -485,7 +519,8 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                              ('sslice_2', 'sslice_2_data'),
                              ('sslice_1_data', 'concat_1'),
                              ('sslice_2_data', 'concat_1'),
-                             ('concat_1', 'concat_1_data')
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output')
                              ],
                             {'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
 
@@ -497,7 +532,7 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  [slice(0, 1, 1), slice(18, 36, 1), slice(0, 54, 1), slice(0, 3, 1)])},
                              'sslice_2_data': {'shape': np.array([1, 18, 54, 3])},
 
-                             'concat_1_data': {'shape': np.array([1, 54, 54, 3]), 'is_output': True},
+                             'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
                              })
         graph.graph['layout'] = 'NHWC'
 
@@ -509,14 +544,336 @@ class ConvertGroupedStridedSliceTests(unittest.TestCase):
                                  ('split_1', 'split_3_data'),
                                  ('split_1_data', 'concat_1'),
                                  ('split_3_data', 'concat_1'),
-                                 ('concat_1', 'concat_1_data')
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output'),
+                                 ('split_2_data', 'op_output_1')
                                  ],
                                 {'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
                                  'split_1': {'axis': 1},
                                  'split_1_data': {'shape': np.array([1, 18, 54, 3])},
                                  'split_2_data': {'shape': np.array([1, 18, 54, 3])},
                                  'split_3_data': {'shape': np.array([1, 18, 54, 3])},
-                                 'concat_1_data': {'shape': np.array([1, 54, 54, 3]), 'is_output': True},
+                                 'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
+                                 })
+
+        pattern = ConvertGroupedStridedSlice()
+        pattern.find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+
+class AddReshapeAfterStridedSliceTests(unittest.TestCase):
+    def test_ss_1_shrink_last(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'sslice_1'),
+                             ('placeholder_begin_data', 'sslice_1'),
+                             ('placeholder_end_data', 'sslice_1'),
+                             ('placeholder_stride_data', 'sslice_1'),
+                             ('sslice_1', 'sslice_1_data'),
+                             ('sslice_1_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                             'sslice_1': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+                                          'shrink_axis_mask': [0, 0, 1, 0],
+                                          'new_axis_mask': np.array([0, 0, 0, 0])},
+                             'sslice_1_data': {'shape': np.array([1, 227, 54])},
+                             })
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_1_data', 'sslice_1'),
+                                 ('placeholder_begin_data', 'sslice_1'),
+                                 ('placeholder_end_data', 'sslice_1'),
+                                 ('placeholder_stride_data', 'sslice_1'),
+                                 ('sslice_1', 'sslice_1/Reshape_shrink_data'),
+                                 ('sslice_1/Reshape_shrink_data', 'sslice_1/Reshape_shrink'),
+                                 ('sslice_1/Reshape_shrink', 'sslice_1_data'),
+                                 ('sslice_1_data', 'op_output')
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                                 'sslice_1': {'slices': np.array(
+                                     [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+                                     'shrink_axis_mask': np.array([0, 0, 0, 0]),
+                                     'new_axis_mask': np.array([0, 0, 0, 0])},
+                                 'sslice_1_data': {'shape': np.array([1, 227, 54])},
+                                 'sslice_1/Reshape_shrink': {'dim': np.array([1, 227, 54])},
+                                 'sslice_1/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])}
+                                 })
+
+        pattern = ConvertGroupedStridedSlice()
+        pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_1'))
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_1_data', check_op_attrs=True)
+        graph.clear()
+        graph_ref.clear()
+        self.assertTrue(flag, resp)
+
+    def test_ss_1_shrink(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'sslice_2'),
+                             ('placeholder_begin_data', 'sslice_2'),
+                             ('placeholder_end_data', 'sslice_2'),
+                             ('placeholder_stride_data', 'sslice_2'),
+                             ('sslice_2', 'sslice_2_data'),
+                             ('sslice_2_data', 'placeholder_2'),
+                             ('placeholder_2', 'placeholder_2_data'),
+                             ('sslice_2_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                             'sslice_2': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+                                          'shrink_axis_mask': [0, 0, 1, 0],
+                                          'new_axis_mask': np.array([0, 0, 0, 0])},
+                             'sslice_2_data': {'shape': np.array([1, 227, 54])}
+                             })
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_1_data', 'sslice_2'),
+                                 ('placeholder_begin_data', 'sslice_2'),
+                                 ('placeholder_end_data', 'sslice_2'),
+                                 ('placeholder_stride_data', 'sslice_2'),
+                                 ('sslice_2', 'sslice_2/Reshape_shrink_data'),
+                                 ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
+                                 ('sslice_2/Reshape_shrink', 'sslice_2_data'),
+                                 ('sslice_2_data', 'placeholder_2'),
+                                 ('placeholder_2', 'placeholder_2_data'),
+                                 ('sslice_2_data', 'op_output')
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                                 'sslice_2': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+                                              'shrink_axis_mask': np.array([0, 0, 0, 0]),
+                                              'new_axis_mask': np.array([0, 0, 0, 0])},
+                                 'sslice_2_data': {'shape': np.array([1, 227, 54])},
+                                 'sslice_2/Reshape_shrink': {'dim': np.array([1, 227, 54])},
+                                 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
+                                 })
+
+        pattern = ConvertGroupedStridedSlice()
+        pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+        graph.clear()
+        graph_ref.clear()
+        self.assertTrue(flag, resp)
+
+    def test_ss_2_shrink(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'sslice_2'),
+                             ('placeholder_begin_data', 'sslice_2'),
+                             ('placeholder_end_data', 'sslice_2'),
+                             ('placeholder_stride_data', 'sslice_2'),
+                             ('sslice_2', 'sslice_2_data'),
+                             ('sslice_2_data', 'placeholder_2'),
+                             ('placeholder_2', 'placeholder_2_data'),
+                             ('sslice_2_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                             'sslice_2': {
+                                 'slices': np.array([slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
+                                 'shrink_axis_mask': np.array([0, 1, 0, 1]),
+                                 'new_axis_mask': np.array([0, 0, 0, 0])},
+                             'sslice_2_data': {'shape': np.array([1, 227])}
+                             })
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_1_data', 'sslice_2'),
+                                 ('placeholder_begin_data', 'sslice_2'),
+                                 ('placeholder_end_data', 'sslice_2'),
+                                 ('placeholder_stride_data', 'sslice_2'),
+                                 ('sslice_2', 'sslice_2/Reshape_shrink_data'),
+                                 ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
+                                 ('sslice_2/Reshape_shrink', 'sslice_2_data'),
+                                 ('sslice_2_data', 'placeholder_2'),
+                                 ('placeholder_2', 'placeholder_2_data'),
+                                 ('sslice_2_data', 'op_output')
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                                 'sslice_2': {'slices': np.array(
+                                     [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
+                                     'shrink_axis_mask': np.array([0, 0, 0, 0]),
+                                     'new_axis_mask': np.array([0, 0, 0, 0])},
+                                 'sslice_2_data': {'shape': np.array([1, 227])},
+                                 'sslice_2/Reshape_shrink': {'dim': np.array([1, 227])},
+                                 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1])},
+                                 })
+
+        pattern = ConvertGroupedStridedSlice()
+        pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+        graph.clear()
+        graph_ref.clear()
+        self.assertTrue(flag, resp)
+
+    def test_ss_1_new(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'sslice_2'),
+                             ('placeholder_begin_data', 'sslice_2'),
+                             ('placeholder_end_data', 'sslice_2'),
+                             ('placeholder_stride_data', 'sslice_2'),
+                             ('sslice_2', 'sslice_2_data'),
+                             ('sslice_2_data', 'placeholder_2'),
+                             ('placeholder_2', 'placeholder_2_data'), ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                             'sslice_2': {'slices': np.array(
+                                 [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 54, 1)]),
+                                 'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
+                                 'new_axis_mask': np.array([0, 1, 0, 0, 0])},
+                             'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])}
+                             })
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_1_data', 'sslice_2'),
+                                 ('placeholder_begin_data', 'sslice_2'),
+                                 ('placeholder_end_data', 'sslice_2'),
+                                 ('placeholder_stride_data', 'sslice_2'),
+                                 ('sslice_2', 'sslice_2/Reshape_new_data'),
+                                 ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
+                                 ('sslice_2/Reshape_new', 'sslice_2_data'),
+                                 ('sslice_2_data', 'placeholder_2'),
+                                 ('placeholder_2', 'placeholder_2_data')],
+                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                                 'sslice_2': {'slices': np.array(
+                                     [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1),
+                                      slice(0, 54, 1)]),
+                                     'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
+                                     'new_axis_mask': np.array([0, 0, 0, 0, 0])},
+                                 'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
+                                 'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 227, 54])},
+                                 'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 227, 54])},
+                                 })
+
+        pattern = ConvertGroupedStridedSlice()
+        pattern.add_reshape_for_new(graph, Node(graph, 'sslice_2'))
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+        graph.clear()
+        graph_ref.clear()
+        self.assertTrue(flag, resp)
+
+    def test_ss_shrink_new(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'sslice_2'),
+                             ('placeholder_begin_data', 'sslice_2'),
+                             ('placeholder_end_data', 'sslice_2'),
+                             ('placeholder_stride_data', 'sslice_2'),
+                             ('sslice_2', 'sslice_2_data'),
+                             ('sslice_2_data', 'placeholder_2'),
+                             ('placeholder_2', 'placeholder_2_data'),
+                             ('sslice_2_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                             'sslice_2': {'slices': np.array(
+                                 [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
+                                 'shrink_axis_mask': np.array([0, 0, 0, 1, 0]),
+                                 'new_axis_mask': np.array([0, 1, 0, 0, 0])},
+                             'sslice_2_data': {'shape': np.array([1, 1, 227, 54])}
+                             })
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_1_data', 'sslice_2'),
+                                 ('placeholder_begin_data', 'sslice_2'),
+                                 ('placeholder_end_data', 'sslice_2'),
+                                 ('placeholder_stride_data', 'sslice_2'),
+                                 ('sslice_2', 'sslice_2/Reshape_new_data'),
+                                 ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
+                                 ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
+                                 ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
+                                 ('sslice_2/Reshape_shrink', 'sslice_2_data'),
+                                 ('sslice_2_data', 'placeholder_2'),
+                                 ('placeholder_2', 'placeholder_2_data'),
+                                 ('sslice_2_data', 'op_output')
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                                 'sslice_2': {'slices': np.array(
+                                     [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1),
+                                      slice(0, 54, 1)]),
+                                     'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
+                                     'new_axis_mask': np.array([0, 0, 0, 0, 0])},
+                                 'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
+                                 'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
+                                 'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
+                                 'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
+                                 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
+                                 })
+
+        pattern = ConvertGroupedStridedSlice()
+        pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
+        pattern.add_reshape_for_new(graph, Node(graph, 'sslice_2'))
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
+        graph.clear()
+        graph_ref.clear()
+        self.assertTrue(flag, resp)
+
+    # test case with 2 strided slices with the same parameters but different outputs
+    def test_1(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_1_data', 'sslice_1'),
+                             ('sslice_1', 'sslice_1_data'),
+                             ('placeholder_1_data', 'sslice_2'),
+                             ('sslice_2', 'sslice_2_data'),
+                             ('placeholder_1_data', 'sslice_3'),
+                             ('sslice_3', 'sslice_3_data'),
+                             ('sslice_1_data', 'concat_1'),
+                             ('sslice_2_data', 'concat_1'),
+                             ('sslice_3_data', 'placeholder_2'),
+                             ('placeholder_2', 'placeholder_2_data'),
+                             ('concat_1', 'concat_1_data'),
+                             ('concat_1_data', 'op_output'),
+                             ('placeholder_2_data', 'op_output')
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+
+                             'sslice_1': {'slices': np.array(
+                                 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
+                             'sslice_1_data': {'shape': np.array([1, 227, 227, 27])},
+
+                             'sslice_2': {'slices': np.array(
+                                 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 54, 1)])},
+                             'sslice_2_data': {'shape': np.array([1, 227, 227, 27])},
+
+                             'sslice_3': {'slices': np.array(
+                                 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
+                             'sslice_3_data': {'shape': np.array([1, 227, 227, 27])},
+
+                             'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
+                             })
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_1_data', 'split_1'),
+                                 ('split_1', 'split_1_data'),
+                                 ('split_1', 'split_2_data'),
+                                 ('split_1_data', 'concat_1'),
+                                 ('split_2_data', 'concat_1'),
+                                 ('split_1_data', 'placeholder_2'),
+                                 ('placeholder_2', 'placeholder_2_data'),
+                                 ('concat_1', 'concat_1_data'),
+                                 ('concat_1_data', 'op_output'),
+                                 ('placeholder_2_data', 'op_output')
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
+                                 'split_1': {'axis': 3},
+                                 'split_1_data': {'shape': np.array([1, 227, 227, 27])},
+                                 'split_2_data': {'shape': np.array([1, 227, 227, 27])},
+                                 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
                                  })
 
         pattern = ConvertGroupedStridedSlice()