"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from mo.graph.graph import Node, Graph
from mo.middle.passes.eliminate import mark_output_reachable_nodes, graph_clean_up, mark_const_producer_nodes
from mo.utils.unittest.graph import build_graph
# Attribute templates for every node name the tests below may reference.
# 'kind' separates operation nodes ('op') from tensor nodes ('data'); data nodes
# start with 'value': None unless a test overrides it.
# NOTE(review): build_graph presumably instantiates only the names used in a
# test's edge list when nodes_with_edges_only=True -- confirm in mo.utils.unittest.graph.
nodes_attributes = {'placeholder_1': {'type': 'Placeholder', 'kind': 'op'},
                    'placeholder_2': {'type': 'Placeholder', 'kind': 'op'},
                    'node_1': {'type': 'Identity', 'value': None, 'kind': 'op'},
                    'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
                    'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
                    'node_4': {'type': 'Identity', 'value': None, 'kind': 'op'},
                    'node_5': {'type': 'Identity', 'value': None, 'kind': 'op'},
                    'node_6': {'type': 'Identity', 'value': None, 'kind': 'op'},
                    'placeholder_1_data_node': {'value': None, 'kind': 'data'},
                    'placeholder_2_data_node': {'value': None, 'kind': 'data'},
                    'data_node_1': {'value': None, 'kind': 'data'},
                    'data_node_2': {'value': None, 'kind': 'data'},
                    'data_node_3': {'value': None, 'kind': 'data'},
                    'data_node_3_2': {'value': None, 'kind': 'data'},
                    'data_node_4': {'value': None, 'kind': 'data'},
                    'data_node_5': {'value': None, 'shape': None, 'kind': 'data'},
                    'data_node_6': {'value': None, 'shape': None, 'kind': 'data'},
                    'tf_call_1': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
                    'tf_call_2': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
                    'tf_call_3': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
                    'op_output': {'kind': 'op', 'op': 'OpOutput'},
                    'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
                    'op_output_2': {'kind': 'op', 'op': 'OpOutput'},
                    }
class TestEliminatePass(unittest.TestCase):
    """Tests for the passes in mo.middle.passes.eliminate and for Graph.erase_node."""

    def test_mark_output_unreachable_nodes(self):
        """
        Checks that all nodes that are unreachable from output nodes are marked correspondingly.
        The graph doesn't contain data nodes yet.
        node_4 feeds the output node op_output.

        placeholder_1->node_1->node_2
        placeholder_1->node_3->node_4->op_output
        """
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'node_1'),
                             ('node_1', 'node_2'),
                             ('placeholder_1', 'node_3'),
                             ('node_3', 'node_4'),
                             ('node_4', 'op_output')
                             ],
                            nodes_with_edges_only=True)
        mark_output_reachable_nodes(graph)

        self.assertCountEqual(['placeholder_1', 'node_3', 'op_output', 'node_4'],
                              graph.get_nodes_with_attributes(is_output_reachable=True))
        self.assertCountEqual(['node_1', 'node_2'],
                              graph.get_nodes_with_attributes(is_output_reachable=False))

    def test_mark_output_unreachable_nodes_behind_output(self):
        """
        Checks the case when an unreachable node is 'behind' (i.e. is a child of) the output node.
        The graph doesn't contain data nodes yet.
        node_2 feeds the output node op_output; node_3 is its other child.

        placeholder_1->node_1->node_2->node_3
        node_2->op_output
        """
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'node_1'),
                             ('node_1', 'node_2'),
                             ('node_2', 'node_3'),
                             ('node_2', 'op_output')
                             ],
                            nodes_with_edges_only=True)
        mark_output_reachable_nodes(graph)

        self.assertCountEqual(['node_1', 'node_2', 'op_output', 'placeholder_1'],
                              graph.get_nodes_with_attributes(is_output_reachable=True))
        self.assertFalse(graph.node['node_3']['is_output_reachable'])

    def test_mark_ops_producing_constant_values(self):
        """
        Checks the case when an operation produces only constant tensors, so it could be removed. If the node
        produces several tensors and at least one of them is not constant then we should not mark this node.
        The graph contains data nodes.
        "data_node_2" and "data_node_5" are output.
        "node_3" produces constant tensor "data_node_3" and non-constant tensor "data_node_3_2".
        "node_6" produces constant tensor "data_node_6".
        "node_4" could be eliminated since it gets constant input.

        Edges as built below:
        node_6->data_node_6->node_1
        placeholder_1->placeholder_1_data_node->node_1->data_node_1->node_2->data_node_2->op_output
        node_3->data_node_3->node_4
        data_node_4->node_1
        node_3->data_node_3_2->node_5->data_node_5->op_output_1
        """
        # NOTE(review): the edge list has no ('node_4', 'data_node_4') edge, so node_4 has no
        # outgoing edges and data_node_4 has no producer in this graph. The assertions below were
        # written against this graph -- confirm the omission is intentional before adding the edge.
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data_node'),
                             ('placeholder_1_data_node', 'node_1'),
                             ('node_1', 'data_node_1'),
                             ('data_node_1', 'node_2'),
                             ('node_2', 'data_node_2'),
                             ('node_3', 'data_node_3'),
                             ('node_3', 'data_node_3_2'),
                             ('node_6', 'data_node_6'),
                             ('data_node_6', 'node_1'),
                             ('data_node_3_2', 'node_5'),
                             ('node_5', 'data_node_5'),
                             ('data_node_3', 'node_4'),
                             ('data_node_4', 'node_1'),
                             ('data_node_2', 'op_output'),
                             ('data_node_5', 'op_output_1')
                             ],
                            {'data_node_3': {'value': np.array(1)},
                             'data_node_6': {'value': np.array(1)}},
                            nodes_with_edges_only=True)
        mark_const_producer_nodes(graph)
        self.assertTrue(graph.node['node_6']['is_const_producer'])
        self.assertCountEqual(['node_1', 'node_2', 'node_3', 'node_5', 'placeholder_1'],
                              graph.get_nodes_with_attributes(is_const_producer=False, kind='op'))

        graph_clean_up(graph)
        self.assertIn('node_3', graph.nodes())
        self.assertNotIn('node_4', graph.nodes())
        self.assertNotIn('node_6', graph.nodes())

    # An empty test body passes vacuously and hides the missing coverage; report it as skipped instead.
    @unittest.skip('Not implemented: the test builds no graph and checks nothing yet')
    def test_undead_nodes_with_constant_inputs(self):
        """
        Checks that if node of 'undead' type has constant inputs it is not removed from the graph.
        """

    def test_remove_node_from_graph(self):
        """
        Checks that erasing a node removes it from the graph.
        The graph doesn't contain data nodes.
        "node_2" should be removed.

        placeholder_1->node_1->node_2->node_3
        """
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'node_1'),
                             ('node_1', 'node_2'),
                             ('node_2', 'node_3')],
                            nodes_with_edges_only=True)
        graph.erase_node(Node(graph, 'node_2'))

        self.assertCountEqual(['placeholder_1', 'node_1', 'node_3'], graph.nodes())