Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / eliminate_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.graph.graph import Node, Graph
22 from mo.middle.passes.eliminate import mark_output_reachable_nodes, graph_clean_up, mark_const_producer_nodes
23 from mo.utils.unittest.graph import build_graph
24
25 nodes_attributes = {'placeholder_1': {'type': 'Placeholder', 'kind': 'op'},
26                     'placeholder_2': {'type': 'Placeholder', 'kind': 'op'},
27                     'node_1': {'type': 'Identity', 'value': None, 'kind': 'op'},
28                     'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
29                     'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
30                     'node_4': {'type': 'Identity', 'value': None, 'kind': 'op'},
31                     'node_5': {'type': 'Identity', 'value': None, 'kind': 'op'},
32                     'node_6': {'type': 'Identity', 'value': None, 'kind': 'op'},
33                     'placeholder_1_data_node': {'value': None, 'kind': 'data'},
34                     'placeholder_2_data_node': {'value': None, 'kind': 'data'},
35                     'data_node_1': {'value': None, 'kind': 'data'},
36                     'data_node_2': {'value': None, 'kind': 'data'},
37                     'data_node_3': {'value': None, 'kind': 'data'},
38                     'data_node_3_2': {'value': None, 'kind': 'data'},
39                     'data_node_4': {'value': None, 'kind': 'data'},
40                     'data_node_5': {'value': None, 'shape': None, 'kind': 'data'},
41                     'data_node_6': {'value': None, 'shape': None, 'kind': 'data'},
42                     'tf_call_1': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
43                     'tf_call_2': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
44                     'tf_call_3': {'type': 'TFCustomSubgraphCall', 'kind': 'op'},
45                     'op_output': {'kind': 'op', 'op': 'OpOutput'},
46                     'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
47                     'op_output_2': {'kind': 'op', 'op': 'OpOutput'}
48                     }
49
50
51 class TestEliminatePass(unittest.TestCase):
52     def test_mark_output_unreachable_nodes(self):
53         """
54         Checks that all nodes that are unreachable from output nodes are marked correspondingly.
55         The graph doesn't contain data nodes yet.
56         "node_4" is output.
57
58         placeholder_1->node_1->node_2
59               \
60                -> node_3->node_4
61
62         :return: None
63         """
64         graph = build_graph(nodes_attributes,
65                             [('placeholder_1', 'node_1'),
66                              ('node_1', 'node_2'),
67                              ('placeholder_1', 'node_3'),
68                              ('node_3', 'node_4'),
69                              ('node_4', 'op_output')
70                              ],
71                             {'node_4': {}},
72                             nodes_with_edges_only=True)
73         mark_output_reachable_nodes(graph)
74
75         self.assertListEqual(sorted(['placeholder_1', 'node_3', 'op_output', 'node_4']),
76                              sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
77         self.assertListEqual(sorted(['node_1', 'node_2']),
78                              sorted(graph.get_nodes_with_attributes(is_output_reachable=False)))
79
80     def test_mark_output_unreachable_nodes_behind_output(self):
81         """
82         Checks case when unreachable node is 'behind' (i.e. is the child) of the output node.
83         The graph doesn't contain data nodes yet.
84         "node_2" is output.
85
86         placeholder_1->node_1->node_2->node_3
87
88         :return: None
89         """
90         graph = build_graph(nodes_attributes,
91                             [('placeholder_1', 'node_1'),
92                              ('node_1', 'node_2'),
93                              ('node_2', 'node_3'),
94                              ('node_2', 'op_output')
95                              ],
96                             {'node_2': {}},
97                             nodes_with_edges_only=True)
98         mark_output_reachable_nodes(graph)
99
100         self.assertListEqual(sorted(['node_1', 'node_2', 'op_output', 'placeholder_1']),
101                              sorted(graph.get_nodes_with_attributes(is_output_reachable=True)))
102         self.assertFalse(graph.node['node_3']['is_output_reachable'])
103
104     def test_mark_ops_producing_constant_values(self):
105         """
106         Checks case when operation produces only constant tensors so it could be removed. If the node produces several
107         tensors and at least one of them is not constant then we should not mark this node.
108         The graph contains data nodes.
109         "data_node_2" and "data_node_5" are output.
110         "node_3" produces constant tensor "data_node_3" and non-constant tensor "data_node_3_2".
111         "node_6" produces constant tensor "data_node_6".
112         "node_4" could be eliminated since it gets constant input.
113
114                              node_6->data_node_6->
115                                                   \
116         placeholder_1->placeholder_1_data_node->node_1->data_node_1->node_2->data_node_2
117                                                   /
118         node_3->data_node_3->node_4->data_node_4->
119            \
120             ->data_node_3_2->node_5->data_node_5
121
122         :return: None
123         """
124         graph = build_graph(nodes_attributes,
125                             [('placeholder_1', 'placeholder_1_data_node'),
126                              ('placeholder_1_data_node', 'node_1'),
127                              ('node_1', 'data_node_1'),
128                              ('data_node_1', 'node_2'),
129                              ('node_2', 'data_node_2'),
130                              ('node_3', 'data_node_3'),
131                              ('node_3', 'data_node_3_2'),
132                              ('node_6', 'data_node_6'),
133                              ('data_node_6', 'node_1'),
134                              ('data_node_3_2', 'node_5'),
135                              ('node_5', 'data_node_5'),
136                              ('data_node_3', 'node_4'),
137                              ('data_node_4', 'node_1'),
138                              ('data_node_2', 'op_output'),
139                              ('data_node_5', 'op_output_1')
140                              ],
141                             {'data_node_2': {},
142                              'data_node_5': {},
143                              'data_node_3': {'value': np.array(1)},
144                              'data_node_6': {'value': np.array(1)}},
145                             nodes_with_edges_only=True)
146         mark_const_producer_nodes(graph)
147         self.assertTrue((graph.node['node_6']['is_const_producer']))
148         self.assertListEqual(sorted(['node_1', 'node_2', 'node_3', 'node_5', 'placeholder_1']),
149                              sorted(graph.get_nodes_with_attributes(is_const_producer=False, kind='op')))
150
151         graph_clean_up(graph)
152         self.assertTrue('node_3' in graph.nodes())
153         self.assertTrue('node_4' not in graph.nodes())
154         self.assertTrue('node_6' not in graph.nodes())
155
156     def test_undead_nodes_with_constant_inputs(self):
157         """
158         Checks that if node of 'undead' type has constant inputs it is not removed from the graph.
159         :return: None
160         """
161         pass
162
163     def test_remove_node_from_graph(self):
164         """
165         Checks case when remove node from graph.
166         The graph doesn't contain removed node yet.
167         "node_2" should be removed.
168
169         placeholder_1->node_1->node_2->node_3
170
171         :return: None
172         """
173         graph = build_graph(nodes_attributes,
174                             [('placeholder_1', 'node_1'),
175                              ('node_1', 'node_2'),
176                              ('node_2', 'node_3')],
177                             nodes_with_edges_only=True)
178         graph.erase_node(Node(graph, 'node_2'))
179
180         self.assertListEqual(sorted(['placeholder_1', 'node_1', 'node_3']), sorted(graph.nodes()))