"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import numpy as np
-from mo.graph.graph import Node, erase_node
-from mo.middle.passes.eliminate import mark_output_reachable_nodes, graph_clean_up, \
- get_nodes_with_attributes, mark_const_producer_nodes
+from mo.graph.graph import Node, Graph
+from mo.middle.passes.eliminate import mark_output_reachable_nodes, graph_clean_up, mark_const_producer_nodes
from mo.utils.unittest.graph import build_graph
nodes_attributes = {'placeholder_1': {'type': 'Placeholder', 'kind': 'op'},
'data_node_3': {'value': None, 'kind': 'data'},
'data_node_3_2': {'value': None, 'kind': 'data'},
'data_node_4': {'value': None, 'kind': 'data'},
- 'data_node_5': {'value': None, 'kind': 'data'},
- 'data_node_6': {'value': None, 'kind': 'data'},
+ 'data_node_5': {'value': None, 'shape': None, 'kind': 'data'},
+ 'data_node_6': {'value': None, 'shape': None, 'kind': 'data'},
'tf_call_1': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
'tf_call_2': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
'tf_call_3': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
+ 'op_output': {'kind': 'op', 'op': 'OpOutput'},
+ 'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
+ 'op_output_2': {'kind': 'op', 'op': 'OpOutput'}
}
[('placeholder_1', 'node_1'),
('node_1', 'node_2'),
('placeholder_1', 'node_3'),
- ('node_3', 'node_4')],
- {'node_4': {'is_output': True}},
+ ('node_3', 'node_4'),
+ ('node_4', 'op_output')
+ ],
+ {'node_4': {}},
nodes_with_edges_only=True)
mark_output_reachable_nodes(graph)
- self.assertListEqual(sorted(['placeholder_1', 'node_3', 'node_4']),
- sorted(get_nodes_with_attributes(graph, is_output_reachable=True)))
+ self.assertListEqual(sorted(['placeholder_1', 'node_3', 'op_output', 'node_4']),
+ sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
self.assertListEqual(sorted(['node_1', 'node_2']),
- sorted(get_nodes_with_attributes(graph, is_output_reachable=False)))
+ sorted(graph.get_nodes_with_attributes(is_output_reachable=False)))
def test_mark_output_unreachable_nodes_behind_output(self):
"""
graph = build_graph(nodes_attributes,
[('placeholder_1', 'node_1'),
('node_1', 'node_2'),
- ('node_2', 'node_3')],
- {'node_2': {'is_output': True}},
+ ('node_2', 'node_3'),
+ ('node_2', 'op_output')
+ ],
+ {'node_2': {}},
nodes_with_edges_only=True)
mark_output_reachable_nodes(graph)
- self.assertListEqual(sorted(['placeholder_1', 'node_1', 'node_2']),
- sorted(get_nodes_with_attributes(graph, is_output_reachable=True)))
+ self.assertListEqual(sorted(['node_1', 'node_2', 'op_output', 'placeholder_1']),
+ sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
self.assertFalse(graph.node['node_3']['is_output_reachable'])
def test_mark_ops_producing_constant_values(self):
('data_node_3_2', 'node_5'),
('node_5', 'data_node_5'),
('data_node_3', 'node_4'),
- ('data_node_4', 'node_1')],
- {'data_node_2': {'is_output': True},
- 'data_node_5': {'is_output': True},
+ ('data_node_4', 'node_1'),
+ ('data_node_2', 'op_output'),
+ ('data_node_5', 'op_output_1')
+ ],
+ {'data_node_2': {},
+ 'data_node_5': {},
'data_node_3': {'value': np.array(1)},
'data_node_6': {'value': np.array(1)}},
nodes_with_edges_only=True)
mark_const_producer_nodes(graph)
self.assertTrue((graph.node['node_6']['is_const_producer']))
self.assertListEqual(sorted(['node_1', 'node_2', 'node_3', 'node_5', 'placeholder_1']),
- sorted(get_nodes_with_attributes(graph, is_const_producer=False, kind='op')))
+ sorted(graph.get_nodes_with_attributes(is_const_producer=False, kind='op')))
graph_clean_up(graph)
self.assertTrue('node_3' in graph.nodes())
('node_1', 'node_2'),
('node_2', 'node_3')],
nodes_with_edges_only=True)
- erase_node(Node(graph, 'node_2'))
+ graph.erase_node(Node(graph, 'node_2'))
self.assertListEqual(sorted(['placeholder_1', 'node_1', 'node_3']), sorted(graph.nodes()))