Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ReplaceSpliceNodePattern_test.py
index ca40336..4689784 100644 (file)
@@ -17,43 +17,129 @@ import unittest
 
 from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
 from mo.graph.graph import Node
-from mo.utils.unittest.graph import build_graph
+from mo.utils.unittest.graph import build_graph, compare_graphs
 
 
 class ReplaceSpliceNodePatternTests(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.nodes_attributes = {
+            'placeholder': {'kind': 'op', 'op': None},
             'in_node': {'kind': 'data', 'shape': [1, 13]},
-            'slice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 5)},
+            'splice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 6), 'const_dim': 0},
             'splice_data': {'kind': 'data', 'shape': [1, 143]},
+            'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
         }
-        cls.graph = build_graph(cls.nodes_attributes,
-                                [('in_node', 'slice'),
-                                 ('slice', 'splice_data')])
-
-        ReplaceSpliceNodePattern().find_and_replace_pattern(cls.graph)
-
-    def test_memory(self):
-        memory_nodes = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Memory']
-        self.assertEqual(len(memory_nodes), 2)
-        for memory_node in memory_nodes:
-            node = Node(self.graph, memory_node[0])
-            if len(node.in_nodes()):
-                self.assertEqual(node.index, 0)
-            elif len(node.out_nodes()):
-                self.assertEqual(node.index, 1)
-        self.assertEqual(memory_nodes[0][1]['id'], memory_nodes[1][1]['id'])
-
-    def test_crop(self):
-        crop_node = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Crop']
-        self.assertEqual(len(crop_node), 1)
-        crop_node = Node(self.graph, crop_node[0][0])
-        self.assertEqual(crop_node.offset, [13])
-        self.assertEqual(crop_node.dim, [13 * 9])
-
-    def test_concat(self):
-        concat_node = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Concat']
-        self.assertEqual(len(concat_node), 1)
-        crop_node = Node(self.graph, concat_node[0][0])
-        self.assertEqual(crop_node.axis, 1)
+
+    def test_splice(self):
+        graph = build_graph(self.nodes_attributes,
+                            [('placeholder', 'in_node'),
+                             ('in_node', 'splice'),
+                             ('splice', 'splice_data'),
+                             ('splice_data', 'out_placeholder')])
+        ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
+
+        ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
+                                 'in_node': {'kind': 'data', 'shape': [1, 13]},
+                                 'memory_in': {'kind': 'op', 'op': 'Memory'},
+                                 'memory_in_data': {'kind': 'data'},
+                                 'crop_mem':  {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
+                                 'crop_mem_data': {'kind': 'data'},
+                                 'concat': {'kind': 'op', 'op': 'Concat'},
+                                 'concat_data': {'kind': 'data', 'shape': [1, 143]},
+                                 'memory_out': {'kind': 'op', 'op': 'Memory'},
+                                 'memory_out_data': {'kind': 'data'},
+                                 'result': {'kind': 'op', 'op': 'Result'},
+                                 'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
+                                 },
+                                [
+                                    ('in_placeholder', 'in_node'),
+                                    ('memory_in', 'memory_in_data'),
+                                    ('memory_in_data', 'crop_mem'),
+                                    ('crop_mem', 'crop_mem_data'),
+                                    ('crop_mem_data', 'concat', {'in': 0}),
+                                    ('in_node', 'concat', {'in': 1}),
+                                    ('concat', 'concat_data'),
+                                    ('concat_data', 'memory_out'),
+                                    ('memory_out', 'memory_out_data'),
+                                    ('memory_out_data', 'result'),
+                                    ('concat_data', 'out_placeholder'),
+                                ]
+                                )
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder')
+        self.assertTrue(flag, resp)
+
+    def test_splice_with_constdim(self):
+        graph = build_graph(self.nodes_attributes,
+                            [('placeholder', 'in_node'),
+                             ('in_node', 'splice'),
+                             ('splice', 'splice_data'),
+                             ('splice_data', 'out_placeholder')])
+        Node(graph, 'splice')['const_dim'] = 10
+        Node(graph, 'splice_data')['shape'] = [1, 43]
+        ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
+
+        ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
+                                 'in_node': {'kind': 'data', 'shape': [1, 13]},
+                                 'split': {'kind': 'op', 'op': 'Split'},
+                                 'split_data_0': {'kind': 'data'},
+                                 'split_data_1': {'kind': 'data'},
+                                 'memory_in': {'kind': 'op', 'op': 'Memory'},
+                                 'memory_in_data': {'kind': 'data'},
+                                 'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
+                                 'crop_mem_data': {'kind': 'data'},
+                                 'concat': {'kind': 'op', 'op': 'Concat'},
+                                 'concat_data': {'kind': 'data'},
+                                 'memory_out': {'kind': 'op', 'op': 'Memory'},
+                                 'memory_out_data': {'kind': 'data'},
+                                 'result': {'kind': 'op', 'op': 'Result'},
+                                 'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
+                                 'memory_in_constdims_data': {'kind': 'data'},
+                                 'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
+                                 'crop_mem_constdims_data': {'kind': 'data'},
+                                 'concat_constdims': {'kind': 'op', 'op': 'Concat'},
+                                 'concat_constdims_data': {'kind': 'data'},
+                                 'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
+                                 'memory_out_constdims_data': {'kind': 'data'},
+                                 'result_constdims': {'kind': 'op', 'op': 'Result'},
+                                 'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
+                                 'crop_first_constdims_data': {'kind': 'data'},
+                                 'concat_all': {'kind': 'op', 'op': 'Concat'},
+                                 'concat_all_data': {'kind': 'data', 'shape': [1, 43]},
+                                 'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
+                                 },
+                                [
+                                    ('in_placeholder', 'in_node'),
+                                    ('in_node', 'split'),
+                                    ('split', 'split_data_0', {'out': 0}),
+                                    ('split', 'split_data_1', {'out': 1}),
+                                    ('memory_in', 'memory_in_data'),
+                                    ('memory_in_data', 'crop_mem'),
+                                    ('crop_mem', 'crop_mem_data'),
+                                    ('crop_mem_data', 'concat', {'in': 0}),
+                                    ('split_data_0', 'concat', {'in': 1}),
+                                    ('concat', 'concat_data'),
+                                    ('concat_data', 'memory_out'),
+                                    ('memory_out', 'memory_out_data'),
+                                    ('memory_out_data', 'result'),
+                                    ('memory_in_constdims', 'memory_in_constdims_data'),
+                                    ('memory_in_constdims_data', 'crop_mem_constdims'),
+                                    ('crop_mem_constdims', 'crop_mem_constdims_data'),
+                                    ('crop_mem_constdims_data', 'concat_constdims', {'in': 0}),
+                                    ('split_data_1', 'concat_constdims', {'in': 1}),
+                                    ('concat_constdims', 'concat_constdims_data'),
+                                    ('concat_constdims_data', 'memory_out_constdims'),
+                                    ('memory_out_constdims', 'memory_out_constdims_data'),
+                                    ('memory_out_constdims_data', 'result_constdims'),
+                                    ('concat_constdims_data', 'crop_first_constdims'),
+                                    ('crop_first_constdims', 'crop_first_constdims_data'),
+                                    ('crop_first_constdims_data', 'concat_all', {'in': 1}),
+                                    ('concat_data', 'concat_all', {'in': 0}),
+                                    ('concat_all', 'concat_all_data'),
+                                    ('concat_all_data', 'out_placeholder'),
+                                ]
+                                )
+
+        (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder')
+        self.assertTrue(flag, resp)