from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
from mo.graph.graph import Node
-from mo.utils.unittest.graph import build_graph
+from mo.utils.unittest.graph import build_graph, compare_graphs
class ReplaceSpliceNodePatternTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.nodes_attributes = {
+ 'placeholder': {'kind': 'op', 'op': None},
'in_node': {'kind': 'data', 'shape': [1, 13]},
- 'slice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 5)},
+ 'splice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 6), 'const_dim': 0},
'splice_data': {'kind': 'data', 'shape': [1, 143]},
+ 'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
}
- cls.graph = build_graph(cls.nodes_attributes,
- [('in_node', 'slice'),
- ('slice', 'splice_data')])
-
- ReplaceSpliceNodePattern().find_and_replace_pattern(cls.graph)
-
- def test_memory(self):
- memory_nodes = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Memory']
- self.assertEqual(len(memory_nodes), 2)
- for memory_node in memory_nodes:
- node = Node(self.graph, memory_node[0])
- if len(node.in_nodes()):
- self.assertEqual(node.index, 0)
- elif len(node.out_nodes()):
- self.assertEqual(node.index, 1)
- self.assertEqual(memory_nodes[0][1]['id'], memory_nodes[1][1]['id'])
-
- def test_crop(self):
- crop_node = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Crop']
- self.assertEqual(len(crop_node), 1)
- crop_node = Node(self.graph, crop_node[0][0])
- self.assertEqual(crop_node.offset, [13])
- self.assertEqual(crop_node.dim, [13 * 9])
-
- def test_concat(self):
- concat_node = [node for node in self.graph.nodes(data=True) if node[1]['kind'] == 'op' and node[1]['op'] == 'Concat']
- self.assertEqual(len(concat_node), 1)
- crop_node = Node(self.graph, concat_node[0][0])
- self.assertEqual(crop_node.axis, 1)
+
+ def test_splice(self):
+ graph = build_graph(self.nodes_attributes,
+ [('placeholder', 'in_node'),
+ ('in_node', 'splice'),
+ ('splice', 'splice_data'),
+ ('splice_data', 'out_placeholder')])
+ ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
+
+ ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
+ 'in_node': {'kind': 'data', 'shape': [1, 13]},
+ 'memory_in': {'kind': 'op', 'op': 'Memory'},
+ 'memory_in_data': {'kind': 'data'},
+ 'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
+ 'crop_mem_data': {'kind': 'data'},
+ 'concat': {'kind': 'op', 'op': 'Concat'},
+ 'concat_data': {'kind': 'data', 'shape': [1, 143]},
+ 'memory_out': {'kind': 'op', 'op': 'Memory'},
+ 'memory_out_data': {'kind': 'data'},
+ 'result': {'kind': 'op', 'op': 'Result'},
+ 'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
+ },
+ [
+ ('in_placeholder', 'in_node'),
+ ('memory_in', 'memory_in_data'),
+ ('memory_in_data', 'crop_mem'),
+ ('crop_mem', 'crop_mem_data'),
+ ('crop_mem_data', 'concat', {'in': 0}),
+ ('in_node', 'concat', {'in': 1}),
+ ('concat', 'concat_data'),
+ ('concat_data', 'memory_out'),
+ ('memory_out', 'memory_out_data'),
+ ('memory_out_data', 'result'),
+ ('concat_data', 'out_placeholder'),
+ ]
+ )
+
+ (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder')
+ self.assertTrue(flag, resp)
+
+ def test_splice_with_constdim(self):
+ graph = build_graph(self.nodes_attributes,
+ [('placeholder', 'in_node'),
+ ('in_node', 'splice'),
+ ('splice', 'splice_data'),
+ ('splice_data', 'out_placeholder')])
+ Node(graph, 'splice')['const_dim'] = 10
+ Node(graph, 'splice_data')['shape'] = [1, 43]
+ ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
+
+ ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
+ 'in_node': {'kind': 'data', 'shape': [1, 13]},
+ 'split': {'kind': 'op', 'op': 'Split'},
+ 'split_data_0': {'kind': 'data'},
+ 'split_data_1': {'kind': 'data'},
+ 'memory_in': {'kind': 'op', 'op': 'Memory'},
+ 'memory_in_data': {'kind': 'data'},
+ 'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
+ 'crop_mem_data': {'kind': 'data'},
+ 'concat': {'kind': 'op', 'op': 'Concat'},
+ 'concat_data': {'kind': 'data'},
+ 'memory_out': {'kind': 'op', 'op': 'Memory'},
+ 'memory_out_data': {'kind': 'data'},
+ 'result': {'kind': 'op', 'op': 'Result'},
+ 'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
+ 'memory_in_constdims_data': {'kind': 'data'},
+ 'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
+ 'crop_mem_constdims_data': {'kind': 'data'},
+ 'concat_constdims': {'kind': 'op', 'op': 'Concat'},
+ 'concat_constdims_data': {'kind': 'data'},
+ 'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
+ 'memory_out_constdims_data': {'kind': 'data'},
+ 'result_constdims': {'kind': 'op', 'op': 'Result'},
+ 'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
+ 'crop_first_constdims_data': {'kind': 'data'},
+ 'concat_all': {'kind': 'op', 'op': 'Concat'},
+ 'concat_all_data': {'kind': 'data', 'shape': [1, 43]},
+ 'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
+ },
+ [
+ ('in_placeholder', 'in_node'),
+ ('in_node', 'split'),
+ ('split', 'split_data_0', {'out': 0}),
+ ('split', 'split_data_1', {'out': 1}),
+ ('memory_in', 'memory_in_data'),
+ ('memory_in_data', 'crop_mem'),
+ ('crop_mem', 'crop_mem_data'),
+ ('crop_mem_data', 'concat', {'in': 0}),
+ ('split_data_0', 'concat', {'in': 1}),
+ ('concat', 'concat_data'),
+ ('concat_data', 'memory_out'),
+ ('memory_out', 'memory_out_data'),
+ ('memory_out_data', 'result'),
+ ('memory_in_constdims', 'memory_in_constdims_data'),
+ ('memory_in_constdims_data', 'crop_mem_constdims'),
+ ('crop_mem_constdims', 'crop_mem_constdims_data'),
+ ('crop_mem_constdims_data', 'concat_constdims', {'in': 0}),
+ ('split_data_1', 'concat_constdims', {'in': 1}),
+ ('concat_constdims', 'concat_constdims_data'),
+ ('concat_constdims_data', 'memory_out_constdims'),
+ ('memory_out_constdims', 'memory_out_constdims_data'),
+ ('memory_out_constdims_data', 'result_constdims'),
+ ('concat_constdims_data', 'crop_first_constdims'),
+ ('crop_first_constdims', 'crop_first_constdims_data'),
+ ('crop_first_constdims_data', 'concat_all', {'in': 1}),
+ ('concat_data', 'concat_all', {'in': 0}),
+ ('concat_all', 'concat_all_data'),
+ ('concat_all_data', 'out_placeholder'),
+ ]
+ )
+
+ (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder')
+ self.assertTrue(flag, resp)