Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / helpers_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 from mo.graph.graph import Node
20 from mo.middle.passes.fusing.helpers import forward_bfs, backward_bfs, get_next_operation
21 from mo.utils.unittest.graph import build_graph
22
23 nodes_attributes = {
24     'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
25     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
26     # ScaleShift layer
27     'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
28     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
29     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
30     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
31     # Mul and Add operations
32     'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul'},
33     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
34     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
35     'add_1': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
36     'add_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
37     'add_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
38     # Mul2 and Add2 operations
39     'mul_2': {'type': 'Mul', 'kind': 'op', 'op': 'Mul'},
40     'mul_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
41     'mul_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
42     'add_2': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
43     'add_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
44     'add_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
45     # Concat1 operation
46     'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
47     'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
48     # Convolutions
49     'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
50     'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'},
51     'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'},
52     'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
53     'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
54     'conv_2_w': {'value': None, 'shape': None, 'kind': 'data'},
55     'conv_2_b': {'value': None, 'shape': None, 'kind': 'data'},
56     'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
57     # FullyConnected
58     'fc_1': {'type': 'FullyConnected', 'kind': 'op', 'op': 'InnerProduct', 'layout': 'NHWC'},
59     'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
60     'fc_1_b': {'value': None, 'shape': None, 'kind': 'data'},
61     'fc_1_data': {'value': None, 'shape': None, 'kind': 'data'},
62     # Placeholders
63     'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
64     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
65     'placeholder_3': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
66     'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
67     'op_output': { 'kind': 'op', 'op': 'OpOutput'}
68 }
69
70
71 # Unit tests for forward and backward bfs (forward_bfs, backward_bfs)
72 class BFSTests(unittest.TestCase):
73     def test_forward_bfs_simple(self):
74         # Placeholder->ScaleShift->Mul->Add
75         graph = build_graph(nodes_attributes,
76                             [('placeholder_1', 'placeholder_1_data'),
77                              ('placeholder_1_data', 'scaleshift_1'),
78                              ('scaleshift_1_w', 'scaleshift_1'),
79                              ('scaleshift_1', 'scaleshift_1_data'),
80                              ('scaleshift_1_data', 'mul_1'),
81                              ('mul_1', 'mul_1_data'),
82                              ('mul_1_data', 'add_1'),
83                              ('add_1', 'add_1_data'),
84                              ('add_1_data', 'op_output')
85                              ])
86
87         res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul'], ['Add'])
88         self.assertTrue(len(res) == 1 and res[0].id == 'add_1', 'Add operation was not found by bfs')
89
90         res = forward_bfs(Node(graph, 'placeholder_1'), [], ['Add'], allowed_all=True)
91         self.assertTrue(len(res) == 1 and res[0].id == 'add_1', 'Add operation was not found by bfs')
92
93         res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Add'])
94         self.assertTrue(len(res) == 0, 'No one node should be found! But bfs found {} nodes'.format(len(res)))
95
96         res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Mul', 'Add'])
97         self.assertTrue(len(res) == 1 and res[0].id == 'mul_1', 'BFS should find only one Mul operation')
98
99     def test_backward_bfs_simple(self):
100         # Placeholder->ScaleShift->Mul->Add
101         graph = build_graph(nodes_attributes,
102                             [('placeholder_1', 'placeholder_1_data'),
103                              ('placeholder_1_data', 'scaleshift_1'),
104                              ('scaleshift_1_w', 'scaleshift_1'),
105                              ('scaleshift_1', 'scaleshift_1_data'),
106                              ('scaleshift_1_data', 'mul_1'),
107                              ('mul_1', 'mul_1_data'),
108                              ('mul_1_data', 'add_1'),
109                              ('add_1', 'add_1_data'),
110                              ('add_1_data', 'op_output')
111                              ])
112
113         res = backward_bfs(Node(graph, 'add_1_data'), ['Add', 'ScaleShift', 'Mul'], ['Placeholder'])
114         self.assertTrue(len(res) == 1 and res[0].id == 'placeholder_1', 'Placeholder operation was not found by bfs')
115
116         res = backward_bfs(Node(graph, 'add_1'), [], ['Placeholder'], allowed_all=True)
117         self.assertTrue(len(res) == 1 and res[0].id == 'placeholder_1', 'Placeholder operation was not found by bfs')
118
119         res = backward_bfs(Node(graph, 'add_1_data'), ['Add'], ['ScaleShift'])
120         self.assertTrue(len(res) == 0, 'No one node should be found! But bfs found {} nodes'.format(len(res)))
121
122         res = backward_bfs(Node(graph, 'add_1_data'), ['Add', 'Mul'], ['Placeholder', 'ScaleShift'])
123         self.assertTrue(len(res) == 1 and res[0].id == 'scaleshift_1', 'BFS should find only one ScaleShift operation')
124
125     def test_forward_bfs_hard(self):
126         # Placeholder->ScaleShift->Mul1->Add1---->Concat
127         #             `----------->Add2->Mul2--'
128         graph = build_graph(nodes_attributes,
129                             [('placeholder_1', 'placeholder_1_data'),
130                              ('placeholder_1_data', 'scaleshift_1'),
131                              ('placeholder_1_data', 'add_2'),
132                              ('scaleshift_1_w', 'scaleshift_1'),
133                              ('scaleshift_1', 'scaleshift_1_data'),
134                              ('scaleshift_1_data', 'mul_1'),
135                              ('mul_1', 'mul_1_data'),
136                              ('mul_1_data', 'add_1'),
137                              ('add_1', 'add_1_data'),
138                              ('add_2', 'add_2_data'),
139                              ('add_2_data', 'mul_2'),
140                              ('mul_2', 'mul_2_data'),
141                              ('add_1_data', 'concat_1'),
142                              ('mul_2_data', 'concat_1'),
143                              ('concat_1', 'concat_1_data'),
144                              ('concat_1_data', 'op_output')
145                              ])
146
147         res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul', 'Add'], ['Concat'])
148         self.assertTrue(len(res) == 1 and res[0].id == 'concat_1', 'Probably Concat operation was not found by bfs')
149
150         res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul'], ['Add'])
151         self.assertTrue(len(res) == 2 and all([res[x].id in ['add_1', 'add_2'] for x in range(len(res))]),
152                         'Add operations was not found by bfs')
153
154         res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift'], ['Add'])
155         self.assertTrue(len(res) == 0, 'BFS shouldn\'t find any operations')
156
157         res = forward_bfs(Node(graph, 'placeholder_1'), [], ['Add'], allowed_all=True)
158         self.assertTrue(len(res) == 2 and all([res[x].id in ['add_1', 'add_2'] for x in range(len(res))]),
159                         'Add operations was not found by bfs')
160
161         res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Concat'])
162         self.assertTrue(len(res) == 0, 'No one node should be found! But bfs found {} nodes'.format(len(res)))
163
164     def test_backward_bfs_hard(self):
165         # Placeholder->ScaleShift->Mul1->Add1---->Concat
166         #             `----------->Add2->Mul2--'
167         graph = build_graph(nodes_attributes,
168                             [('placeholder_1', 'placeholder_1_data'),
169                              ('placeholder_1_data', 'scaleshift_1'),
170                              ('placeholder_1_data', 'add_2'),
171                              ('scaleshift_1_w', 'scaleshift_1'),
172                              ('scaleshift_1', 'scaleshift_1_data'),
173                              ('scaleshift_1_data', 'mul_1'),
174                              ('mul_1', 'mul_1_data'),
175                              ('mul_1_data', 'add_1'),
176                              ('add_1', 'add_1_data'),
177                              ('add_2', 'add_2_data'),
178                              ('add_2_data', 'mul_2'),
179                              ('mul_2', 'mul_2_data'),
180                              ('add_1_data', 'concat_1'),
181                              ('mul_2_data', 'concat_1'),
182                              ('concat_1', 'concat_1_data'),
183                              ('concat_1_data', 'op_output')
184                              ])
185
186         res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift', 'Mul', 'Add'], ['Placeholder'])
187         self.assertTrue(len(res) == 0, 'Smth went wrong with bfs')
188
189         res = backward_bfs(Node(graph, 'concat_1'), ['Mul'], ['Add'])
190         self.assertTrue(len(res) == 2 and all([res[x].id in ['add_1', 'add_2'] for x in range(len(res))]),
191                         'Add operations was not found by bfs')
192
193         res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['Add'])
194         self.assertTrue(len(res) == 0, 'BFS shouldn\'t find any operations')
195
196         res = backward_bfs(Node(graph, 'concat_1'), [], ['Add'], allowed_all=True)
197         self.assertTrue(len(res) == 2 and all([res[x].id in ['add_1', 'add_2'] for x in range(len(res))]),
198                         'Add operations was not found by bfs')
199
200         res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['ScaleShift'])
201         self.assertTrue(len(res) == 0, 'No one node should be found! But bfs found {} nodes'.format(len(res)))
202
203     def test_backward_bfs_hard2(self):
204         # Placeholder->ScaleShift->Mul1->Add1---->Concat
205         #             `----------->Add2->Mul2--'
206         graph = build_graph(nodes_attributes,
207                             [('placeholder_1', 'placeholder_1_data'),
208                              ('placeholder_1_data', 'add_2'),
209                              ('scaleshift_1_w', 'scaleshift_1'),
210                              ('scaleshift_1', 'scaleshift_1_data'),
211                              ('scaleshift_1_data', 'mul_1'),
212                              ('mul_1', 'mul_1_data'),
213                              ('mul_1_data', 'add_1'),
214                              ('add_1', 'add_1_data'),
215                              ('add_2', 'add_2_data'),
216                              ('add_2_data', 'mul_2'),
217                              ('mul_2', 'mul_2_data'),
218                              ('add_1_data', 'concat_1'),
219                              ('mul_2_data', 'concat_1'),
220                              ('concat_1', 'concat_1_data'),
221                              ('concat_1_data', 'op_output')
222                              ])
223
224         res = backward_bfs(Node(graph, 'concat_1'), ['Mul', 'Add'], ['Placeholder'])
225         self.assertTrue(len(res) == 0, 'Smth went wrong with bfs')
226
227         res = backward_bfs(Node(graph, 'concat_1'), ['Mul'], ['Add'])
228         self.assertTrue(len(res) == 2 and all([res[x].id in ['add_1', 'add_2'] for x in range(len(res))]),
229                         'Add operations was not found by bfs')
230
231         res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['Add'])
232         self.assertTrue(len(res) == 0, 'BFS shouldn\'t find any operations')
233
234         res = backward_bfs(Node(graph, 'concat_1'), [], ['Add'], allowed_all=True)
235         self.assertTrue(len(res) == 2 and all([res[x].id in ['add_1', 'add_2'] for x in range(len(res))]),
236                         'Add operations was not found by bfs')
237
238         res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['ScaleShift'])
239         self.assertTrue(len(res) == 0, 'No one node should be found! But bfs found {} nodes'.format(len(res)))
240
241     def test_backward_bfs_cycle(self):
242         # Placeholder->ScaleShift->Mul->Add
243         graph = build_graph(nodes_attributes,
244                             [('placeholder_1', 'placeholder_1_data'),
245                              ('placeholder_1_data', 'scaleshift_1'),
246                              ('scaleshift_1_w', 'scaleshift_1'),
247                              ('scaleshift_1', 'scaleshift_1_data'),
248                              ('scaleshift_1_data', 'mul_1'),
249                              ('mul_1', 'mul_1_data'),
250                              ('mul_1_data', 'add_1'),
251                              ('add_1', 'add_1_data'),
252                              ('add_1_data', 'placeholder_1'),
253                              ('add_1_data', 'op_output')
254                              ])
255
256         res = backward_bfs(Node(graph, 'add_1_data'), ['Add', 'ScaleShift', 'Mul', 'Placeholder'], ['Conv2D'])
257         self.assertTrue(len(res) == 0, 'Sholdn\'t find any nodes due to cycle in graph')
258
259
260 # Unit tests for get_next_operation
261 class GetNextOperationTests(unittest.TestCase):
262     def test_get_next_operation_1(self):
263         # Placeholder->ScaleShift->Mul->Add
264         graph = build_graph(nodes_attributes,
265                             [('placeholder_1', 'placeholder_1_data'),
266                              ('placeholder_1_data', 'scaleshift_1'),
267                              ('scaleshift_1_w', 'scaleshift_1'),
268                              ('scaleshift_1', 'scaleshift_1_data'),
269                              ('scaleshift_1_data', 'mul_1'),
270                              ('mul_1', 'mul_1_data'),
271                              ('mul_1_data', 'add_1'),
272                              ('add_1', 'add_1_data'),
273                              ('add_1_data', 'op_output')
274                              ])
275
276         res = get_next_operation(Node(graph, 'mul_1'))
277         self.assertTrue(len(res) == 1 and res[0].id == 'add_1', 'get_nex_operation returned wrong op')
278
279     def test_get_next_operation_2(self):
280         # Placeholder->ScaleShift->Mul->Add
281         graph = build_graph(nodes_attributes,
282                             [('placeholder_1', 'placeholder_1_data'),
283                              ('placeholder_1_data', 'mul_1'),
284                              ('placeholder_1_data', 'add_1'),
285                              ('mul_1', 'mul_1_data'),
286                              ('mul_1_data', 'add_1'),
287                              ('add_1', 'add_1_data'),
288                              ('add_1_data', 'op_output')
289                              ])
290
291         res = get_next_operation(Node(graph, 'placeholder_1'))
292         self.assertTrue(len(res) == 2 and all([x.id in ['add_1', 'mul_1'] for x in res]),
293                         'get_nex_operation returned wrong op')
294
295     def test_get_next_operation_3(self):
296         # Placeholder-+--->ScaleShift
297         #             +-----^
298         graph = build_graph(nodes_attributes,
299                             [('placeholder_1', 'placeholder_1_data'),
300                              ('placeholder_1', 'placeholder_2_data'),
301                              ('placeholder_1_data', 'mul_1'),
302                              ('placeholder_2_data', 'mul_1'),
303                              ('mul_1', 'mul_1_data'),
304                              ('mul_1_data', 'op_output')
305                              ])
306
307         res = get_next_operation(Node(graph, 'placeholder_1'))
308         self.assertTrue(len(res) == 1 and res[0].id == 'mul_1', 'get_nex_operation returned wrong op')