X-Git-Url: http://review.tizen.org/git/?a=blobdiff_plain;f=model-optimizer%2Fextensions%2Fmiddle%2FReplaceSpliceNodePattern_test.py;h=4689784d945484fb950aa5866450dc385773498b;hb=0923303e0201c5b59386ab146d0e30b2ef79272d;hp=ca403364997e95b5097163095eeb50f1461d93a0;hpb=ba6e22b1b5ee4cbefcc30e8d9493cddb0bb3dfdf;p=platform%2Fupstream%2Fdldt.git diff --git a/model-optimizer/extensions/middle/ReplaceSpliceNodePattern_test.py b/model-optimizer/extensions/middle/ReplaceSpliceNodePattern_test.py index ca40336..4689784 100644 --- a/model-optimizer/extensions/middle/ReplaceSpliceNodePattern_test.py +++ b/model-optimizer/extensions/middle/ReplaceSpliceNodePattern_test.py @@ -17,43 +17,129 @@ import unittest from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern from mo.graph.graph import Node -from mo.utils.unittest.graph import build_graph +from mo.utils.unittest.graph import build_graph, compare_graphs class ReplaceSpliceNodePatternTests(unittest.TestCase): @classmethod def setUpClass(cls): cls.nodes_attributes = { + 'placeholder': {'kind': 'op', 'op': None}, 'in_node': {'kind': 'data', 'shape': [1, 13]}, - 'slice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 5)}, + 'splice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 6), 'const_dim': 0}, 'splice_data': {'kind': 'data', 'shape': [1, 143]}, + 'out_placeholder': {'kind': 'op', 'op': 'placeholder'}, } - cls.graph = build_graph(cls.nodes_attributes, - [('in_node', 'slice'), - ('slice', 'splice_data')]) - - ReplaceSpliceNodePattern().find_and_replace_pattern(cls.graph) - - def test_memory(self): - memory_nodes = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Memory'] - self.assertEqual(len(memory_nodes), 2) - for memory_node in memory_nodes: - node = Node(self.graph, memory_node[0]) - if len(node.in_nodes()): - self.assertEqual(node.index, 0) - elif len(node.out_nodes()): - self.assertEqual(node.index, 1) - self.assertEqual(memory_nodes[0][1]['id'], memory_nodes[1][1]['id']) - - def test_crop(self): - crop_node = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Crop'] - self.assertEqual(len(crop_node), 1) - crop_node = Node(self.graph, crop_node[0][0]) - self.assertEqual(crop_node.offset, [13]) - self.assertEqual(crop_node.dim, [13 * 9]) - - def test_concat(self): - concat_node = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Concat'] - self.assertEqual(len(concat_node), 1) - crop_node = Node(self.graph, concat_node[0][0]) - self.assertEqual(crop_node.axis, 1) + + def test_splice(self): + graph = build_graph(self.nodes_attributes, + [('placeholder', 'in_node'), + ('in_node', 'splice'), + ('splice', 'splice_data'), + ('splice_data', 'out_placeholder')]) + ReplaceSpliceNodePattern().find_and_replace_pattern(graph) + + ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None}, + 'in_node': {'kind': 'data', 'shape': [1, 13]}, + 'memory_in': {'kind': 'op', 'op': 'Memory'}, + 'memory_in_data': {'kind': 'data'}, + 'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130}, + 'crop_mem_data': {'kind': 'data'}, + 'concat': {'kind': 'op', 'op': 'Concat'}, + 'concat_data': {'kind': 'data', 'shape': [1, 143]}, + 'memory_out': {'kind': 'op', 'op': 'Memory'}, + 'memory_out_data': {'kind': 'data'}, + 'result': {'kind': 'op', 'op': 'Result'}, + 'out_placeholder': {'kind': 'op', 'op': 'placeholder'}, + }, + [ + ('in_placeholder', 'in_node'), + ('memory_in', 'memory_in_data'), + ('memory_in_data', 'crop_mem'), + ('crop_mem', 'crop_mem_data'), + ('crop_mem_data', 'concat', {'in': 0}), + ('in_node', 'concat', {'in': 1}), + ('concat', 'concat_data'), + ('concat_data', 'memory_out'), + ('memory_out', 'memory_out_data'), + ('memory_out_data', 'result'), + ('concat_data', 'out_placeholder'), + ] + ) + + (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder') + self.assertTrue(flag, resp) + + def test_splice_with_constdim(self): + graph = build_graph(self.nodes_attributes, + [('placeholder', 'in_node'), + ('in_node', 'splice'), + ('splice', 'splice_data'), + ('splice_data', 'out_placeholder')]) + Node(graph, 'splice')['const_dim'] = 10 + Node(graph, 'splice_data')['shape'] = [1, 43] + ReplaceSpliceNodePattern().find_and_replace_pattern(graph) + + ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None}, + 'in_node': {'kind': 'data', 'shape': [1, 13]}, + 'split': {'kind': 'op', 'op': 'Split'}, + 'split_data_0': {'kind': 'data'}, + 'split_data_1': {'kind': 'data'}, + 'memory_in': {'kind': 'op', 'op': 'Memory'}, + 'memory_in_data': {'kind': 'data'}, + 'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30}, + 'crop_mem_data': {'kind': 'data'}, + 'concat': {'kind': 'op', 'op': 'Concat'}, + 'concat_data': {'kind': 'data'}, + 'memory_out': {'kind': 'op', 'op': 'Memory'}, + 'memory_out_data': {'kind': 'data'}, + 'result': {'kind': 'op', 'op': 'Result'}, + 'memory_in_constdims': {'kind': 'op', 'op': 'Memory'}, + 'memory_in_constdims_data': {'kind': 'data'}, + 'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100}, + 'crop_mem_constdims_data': {'kind': 'data'}, + 'concat_constdims': {'kind': 'op', 'op': 'Concat'}, + 'concat_constdims_data': {'kind': 'data'}, + 'memory_out_constdims': {'kind': 'op', 'op': 'Memory'}, + 'memory_out_constdims_data': {'kind': 'data'}, + 'result_constdims': {'kind': 'op', 'op': 'Result'}, + 'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10}, + 'crop_first_constdims_data': {'kind': 'data'}, + 'concat_all': {'kind': 'op', 'op': 'Concat'}, + 'concat_all_data': {'kind': 'data', 'shape': [1, 43]}, + 'out_placeholder': {'kind': 'op', 'op': 'placeholder'}, + }, + [ + ('in_placeholder', 'in_node'), + ('in_node', 'split'), + ('split', 'split_data_0', {'out': 0}), + ('split', 'split_data_1', {'out': 1}), + ('memory_in', 'memory_in_data'), + ('memory_in_data', 'crop_mem'), + ('crop_mem', 'crop_mem_data'), + ('crop_mem_data', 'concat', {'in': 0}), + ('split_data_0', 'concat', {'in': 1}), + ('concat', 'concat_data'), + ('concat_data', 'memory_out'), + ('memory_out', 'memory_out_data'), + ('memory_out_data', 'result'), + ('memory_in_constdims', 'memory_in_constdims_data'), + ('memory_in_constdims_data', 'crop_mem_constdims'), + ('crop_mem_constdims', 'crop_mem_constdims_data'), + ('crop_mem_constdims_data', 'concat_constdims', {'in': 0}), + ('split_data_1', 'concat_constdims', {'in': 1}), + ('concat_constdims', 'concat_constdims_data'), + ('concat_constdims_data', 'memory_out_constdims'), + ('memory_out_constdims', 'memory_out_constdims_data'), + ('memory_out_constdims_data', 'result_constdims'), + ('concat_constdims_data', 'crop_first_constdims'), + ('crop_first_constdims', 'crop_first_constdims_data'), + ('crop_first_constdims_data', 'concat_all', {'in': 1}), + ('concat_data', 'concat_all', {'in': 0}), + ('concat_all', 'concat_all_data'), + ('concat_all_data', 'out_placeholder'), + ] + ) + + (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder') + self.assertTrue(flag, resp)