2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.graph.graph import Node
20 from mo.middle.passes.fusing.helpers import forward_bfs, backward_bfs, get_next_operation
21 from mo.utils.unittest.graph import build_graph
24 'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
25 'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
27 'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
28 'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
29 'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
30 'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
31 # Mul and Add operations
32 'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul'},
33 'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
34 'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
35 'add_1': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
36 'add_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
37 'add_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
38 # Mul2 and Add2 operations
39 'mul_2': {'type': 'Mul', 'kind': 'op', 'op': 'Mul'},
40 'mul_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
41 'mul_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
42 'add_2': {'type': 'Add', 'kind': 'op', 'op': 'Add'},
43 'add_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
44 'add_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
46 'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
47 'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
49 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
50 'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'},
51 'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'},
52 'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
53 'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
54 'conv_2_w': {'value': None, 'shape': None, 'kind': 'data'},
55 'conv_2_b': {'value': None, 'shape': None, 'kind': 'data'},
56 'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
58 'fc_1': {'type': 'FullyConnected', 'kind': 'op', 'op': 'InnerProduct', 'layout': 'NHWC'},
59 'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
60 'fc_1_b': {'value': None, 'shape': None, 'kind': 'data'},
61 'fc_1_data': {'value': None, 'shape': None, 'kind': 'data'},
63 'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
64 'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
65 'placeholder_3': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
66 'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
67 'op_output': { 'kind': 'op', 'op': 'OpOutput'}
# Unit tests for forward and backward bfs (forward_bfs, backward_bfs)
class BFSTests(unittest.TestCase):
    """Checks forward_bfs / backward_bfs on small hand-built graphs."""

    # Placeholder->ScaleShift->Mul->Add->OpOutput
    # Stored as a tuple so no test can mutate the shared edge list.
    _LINEAR_EDGES = (
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'scaleshift_1'),
        ('scaleshift_1_w', 'scaleshift_1'),
        ('scaleshift_1', 'scaleshift_1_data'),
        ('scaleshift_1_data', 'mul_1'),
        ('mul_1', 'mul_1_data'),
        ('mul_1_data', 'add_1'),
        ('add_1', 'add_1_data'),
        ('add_1_data', 'op_output'),
    )

    # Placeholder-+->ScaleShift->Mul1->Add1-+->Concat->OpOutput
    #             `->Add2->Mul2-------------'
    _DIAMOND_EDGES = (
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'scaleshift_1'),
        ('placeholder_1_data', 'add_2'),
        ('scaleshift_1_w', 'scaleshift_1'),
        ('scaleshift_1', 'scaleshift_1_data'),
        ('scaleshift_1_data', 'mul_1'),
        ('mul_1', 'mul_1_data'),
        ('mul_1_data', 'add_1'),
        ('add_1', 'add_1_data'),
        ('add_2', 'add_2_data'),
        ('add_2_data', 'mul_2'),
        ('mul_2', 'mul_2_data'),
        ('add_1_data', 'concat_1'),
        ('mul_2_data', 'concat_1'),
        ('concat_1', 'concat_1_data'),
        ('concat_1_data', 'op_output'),
    )

    def _assert_found(self, res, expected_ids, msg):
        # Compare the found node ids as a multiset. This is order-insensitive,
        # but unlike an `all(id in [...])` check it also fails on duplicated
        # or missing nodes, and the failure output shows what was found.
        self.assertCountEqual([node.id for node in res], expected_ids, msg)

    def test_forward_bfs_simple(self):
        graph = build_graph(nodes_attributes, list(self._LINEAR_EDGES))

        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul'], ['Add'])
        self._assert_found(res, ['add_1'], 'Add operation was not found by bfs')

        res = forward_bfs(Node(graph, 'placeholder_1'), [], ['Add'], allowed_all=True)
        self._assert_found(res, ['add_1'], 'Add operation was not found by bfs')

        res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Add'])
        self._assert_found(res, [], 'No nodes should be found, but bfs found {} nodes'.format(len(res)))

        res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Mul', 'Add'])
        self._assert_found(res, ['mul_1'], 'BFS should find only one Mul operation')

    def test_backward_bfs_simple(self):
        graph = build_graph(nodes_attributes, list(self._LINEAR_EDGES))

        res = backward_bfs(Node(graph, 'add_1_data'), ['Add', 'ScaleShift', 'Mul'], ['Placeholder'])
        self._assert_found(res, ['placeholder_1'], 'Placeholder operation was not found by bfs')

        res = backward_bfs(Node(graph, 'add_1'), [], ['Placeholder'], allowed_all=True)
        self._assert_found(res, ['placeholder_1'], 'Placeholder operation was not found by bfs')

        res = backward_bfs(Node(graph, 'add_1_data'), ['Add'], ['ScaleShift'])
        self._assert_found(res, [], 'No nodes should be found, but bfs found {} nodes'.format(len(res)))

        res = backward_bfs(Node(graph, 'add_1_data'), ['Add', 'Mul'], ['Placeholder', 'ScaleShift'])
        self._assert_found(res, ['scaleshift_1'], 'BFS should find only one ScaleShift operation')

    def test_forward_bfs_hard(self):
        graph = build_graph(nodes_attributes, list(self._DIAMOND_EDGES))

        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul', 'Add'], ['Concat'])
        self._assert_found(res, ['concat_1'], 'Probably Concat operation was not found by bfs')

        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift', 'Mul'], ['Add'])
        self._assert_found(res, ['add_1', 'add_2'], 'Add operations were not found by bfs')

        res = forward_bfs(Node(graph, 'placeholder_1'), ['ScaleShift'], ['Add'])
        self._assert_found(res, [], 'BFS shouldn\'t find any operations')

        res = forward_bfs(Node(graph, 'placeholder_1'), [], ['Add'], allowed_all=True)
        self._assert_found(res, ['add_1', 'add_2'], 'Add operations were not found by bfs')

        res = forward_bfs(Node(graph, 'placeholder_1_data'), ['ScaleShift'], ['Concat'])
        self._assert_found(res, [], 'No nodes should be found, but bfs found {} nodes'.format(len(res)))

    def test_backward_bfs_hard(self):
        graph = build_graph(nodes_attributes, list(self._DIAMOND_EDGES))

        res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift', 'Mul', 'Add'], ['Placeholder'])
        self._assert_found(res, [], 'BFS shouldn\'t find any operations, but found {} nodes'.format(len(res)))

        res = backward_bfs(Node(graph, 'concat_1'), ['Mul'], ['Add'])
        self._assert_found(res, ['add_1', 'add_2'], 'Add operations were not found by bfs')

        res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['Add'])
        self._assert_found(res, [], 'BFS shouldn\'t find any operations')

        res = backward_bfs(Node(graph, 'concat_1'), [], ['Add'], allowed_all=True)
        self._assert_found(res, ['add_1', 'add_2'], 'Add operations were not found by bfs')

        res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['ScaleShift'])
        self._assert_found(res, [], 'No nodes should be found, but bfs found {} nodes'.format(len(res)))

    def test_backward_bfs_hard2(self):
        # Same as _DIAMOND_EDGES but without the Placeholder->ScaleShift edge,
        # so ScaleShift is not connected to Placeholder:
        #   ScaleShift->Mul1->Add1-+->Concat->OpOutput
        #   Placeholder->Add2->Mul2'
        # Edge order is otherwise preserved.
        edges = [edge for edge in self._DIAMOND_EDGES
                 if edge != ('placeholder_1_data', 'scaleshift_1')]
        graph = build_graph(nodes_attributes, edges)

        res = backward_bfs(Node(graph, 'concat_1'), ['Mul', 'Add'], ['Placeholder'])
        self._assert_found(res, [], 'BFS shouldn\'t find any operations, but found {} nodes'.format(len(res)))

        res = backward_bfs(Node(graph, 'concat_1'), ['Mul'], ['Add'])
        self._assert_found(res, ['add_1', 'add_2'], 'Add operations were not found by bfs')

        res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['Add'])
        self._assert_found(res, [], 'BFS shouldn\'t find any operations')

        res = backward_bfs(Node(graph, 'concat_1'), [], ['Add'], allowed_all=True)
        self._assert_found(res, ['add_1', 'add_2'], 'Add operations were not found by bfs')

        res = backward_bfs(Node(graph, 'concat_1'), ['ScaleShift'], ['ScaleShift'])
        self._assert_found(res, [], 'No nodes should be found, but bfs found {} nodes'.format(len(res)))

    def test_backward_bfs_cycle(self):
        # Placeholder->ScaleShift->Mul->Add->OpOutput, where Add's output data
        # node is also fed back into Placeholder, closing a cycle.
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'placeholder_1'),
                             ('add_1_data', 'op_output')
                             ])

        res = backward_bfs(Node(graph, 'add_1_data'), ['Add', 'ScaleShift', 'Mul', 'Placeholder'], ['Conv2D'])
        self._assert_found(res, [], 'Shouldn\'t find any nodes due to cycle in graph')
# Unit tests for get_next_operation
class GetNextOperationTests(unittest.TestCase):
    """Checks get_next_operation: op nodes reached through an op's output data nodes."""

    def _assert_next_ops(self, res, expected_ids):
        # Multiset comparison of ids: order-insensitive, but duplicated or
        # missing ops fail (an `all(id in [...])` check would not catch those).
        self.assertCountEqual([node.id for node in res], expected_ids,
                              'get_next_operation returned wrong op')

    def test_get_next_operation_1(self):
        # Placeholder->ScaleShift->Mul->Add->OpOutput
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output')
                             ])

        res = get_next_operation(Node(graph, 'mul_1'))
        self._assert_next_ops(res, ['add_1'])

    def test_get_next_operation_2(self):
        # Placeholder-+->Mul->Add->OpOutput
        #             `-------^
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('placeholder_1_data', 'add_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output')
                             ])

        res = get_next_operation(Node(graph, 'placeholder_1'))
        self._assert_next_ops(res, ['add_1', 'mul_1'])

    def test_get_next_operation_3(self):
        # Placeholder has two output data nodes, both consumed by the same Mul;
        # the Mul must be reported only once.
        # Placeholder-+->placeholder_1_data-+->Mul->OpOutput
        #             `->placeholder_2_data-'
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1', 'placeholder_2_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('placeholder_2_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'op_output')
                             ])

        res = get_next_operation(Node(graph, 'placeholder_1'))
        self._assert_next_ops(res, ['mul_1'])