"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from mo.graph.graph import Node
from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, convert_add_or_mul_to_scaleshift
from mo.middle.passes.eliminate import graph_clean_up
from mo.utils.unittest.graph import build_graph, compare_graphs
# Attribute templates for every node id the tests below may place in a graph.
nodes_attributes = {
    # Graph input
    'placeholder_1': dict(shape=None, type='Placeholder', kind='op', op='Placeholder'),
    'placeholder_1_data': dict(value=None, shape=None, kind='data', data_type=None),
    # ScaleShift layer: the op, its weights/biases const ops with their data nodes, and its output data node
    'scaleshift_1': dict(type='ScaleShift', value=None, kind='op', op='ScaleShift'),
    'const_scaleshift_1_w': dict(value=None, shape=None, kind='op'),
    'scaleshift_1_w': dict(value=None, shape=None, kind='data'),
    'const_scaleshift_1_b': dict(value=None, shape=None, kind='op'),
    'scaleshift_1_b': dict(value=None, shape=None, kind='data'),
    'scaleshift_1_data': dict(value=None, shape=None, kind='data'),
    # Mul and Add operations, each with a weights const op, its data node and an output data node
    'mul_1': dict(value=None, kind='op', op='Mul'),
    'const_mul_1_w': dict(value=None, shape=None, kind='op'),
    'mul_1_w': dict(value=None, shape=None, kind='data'),
    'mul_1_data': dict(value=None, shape=None, kind='data'),
    'add_1': dict(value=None, kind='op', op='Add'),
    'const_add_1_w': dict(value=None, shape=None, kind='op'),
    'add_1_w': dict(value=None, shape=None, kind='data'),
    'add_1_data': dict(value=None, shape=None, kind='data'),
    # Power layer
    'power_1': dict(type='Power', kind='op', op='Power', scale=None, shift=None, power=None),
    'power_1_data': dict(value=None, shape=None, kind='data'),
    # Graph output
    'op_output': dict(kind='op', op='OpOutput'),
}
class MulAddToScaleShiftOrPower(unittest.TestCase):
    """Tests for ``convert_muladd_to_scaleshift_or_power``.

    Each test starts from ``placeholder_1 -> mul_1 -> add_1 -> op_output``, where Mul and Add each take a
    constant second input. It runs the pass followed by ``graph_clean_up`` and compares the result with a
    reference graph.
    """

    def _create_graph_with_mul_add(self, mul_w, add_w):
        """Build the ``placeholder_1 -> mul_1 -> add_1 -> op_output`` test graph.

        :param mul_w: numpy array used as the constant second input of ``mul_1``, or None to leave that
                      constant (and its data node) without shape and value
        :param add_w: same as ``mul_w`` but for ``add_1``
        :return: the built graph
        """
        def const_attrs(weights):
            # Called once per node so the const op and its data node get separate array copies.
            if weights is None:
                return {'shape': None, 'value': None}
            return {'shape': np.array(weights.shape), 'value': np.array(weights)}

        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('const_add_1_w', 'add_1_w'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': const_attrs(mul_w),
                             'mul_1_w': const_attrs(mul_w),
                             'const_add_1_w': const_attrs(add_w),
                             'add_1_w': const_attrs(add_w),
                             }, nodes_with_edges_only=True)
        # Drop the 'in' attribute from the op -> data output edges.
        # NOTE(review): presumably build_graph sets 'in' on every edge; confirm against build_graph.
        del graph['mul_1']['mul_1_data'][0]['in']
        del graph['add_1']['add_1_data'][0]['in']
        return graph

    @staticmethod
    def _build_scaleshift_ref(weights, biases, output_data='scaleshift_1_data'):
        """Build the expected ``placeholder_1 -> scaleshift_1 -> op_output`` graph.

        :param weights: expected ScaleShift weights; their shape becomes the weights ``shape`` attribute
        :param biases: expected ScaleShift biases; their shape becomes the biases ``shape`` attribute
        :param output_data: id of the data node produced by ``scaleshift_1``
        :return: the reference graph
        """
        def const_attrs(values):
            # Fresh array per call so const ops and data nodes never share one object.
            values = np.array(values)
            return {'shape': np.array(values.shape), 'value': values}

        return build_graph(nodes_attributes,
                           [('placeholder_1', 'placeholder_1_data'),
                            ('placeholder_1_data', 'scaleshift_1'),
                            ('const_scaleshift_1_w', 'scaleshift_1_w'),
                            ('scaleshift_1_w', 'scaleshift_1'),
                            ('const_scaleshift_1_b', 'scaleshift_1_b'),
                            ('scaleshift_1_b', 'scaleshift_1'),
                            ('scaleshift_1', output_data),
                            (output_data, 'op_output'),
                            ],
                           {'const_scaleshift_1_w': const_attrs(weights),
                            'scaleshift_1_w': const_attrs(weights),
                            'const_scaleshift_1_b': const_attrs(biases),
                            'scaleshift_1_b': const_attrs(biases),
                            output_data: {},
                            }, nodes_with_edges_only=True)

    def _assert_not_transformed(self, mul_w, add_w):
        """Check that the pass leaves the Mul -> Add graph built from ``mul_w``/``add_w`` unchanged.

        The reference is built from exactly the same weights, so input and reference are identical
        before the pass runs.
        """
        graph = self._create_graph_with_mul_add(mul_w, add_w)
        graph_ref = self._create_graph_with_mul_add(mul_w, add_w)

        convert_muladd_to_scaleshift_or_power(graph)
        graph_clean_up(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'add_1_data', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_mul_add_to_scaleshift_1(self):
        """Mul by [1, 2, 3] followed by Add of [1, 2, 3] is fused into one ScaleShift."""
        graph = self._create_graph_with_mul_add(np.array([1, 2, 3]), np.array([1, 2, 3]))
        graph_ref = self._build_scaleshift_ref(weights=[1, 2, 3], biases=[1, 2, 3])

        convert_muladd_to_scaleshift_or_power(graph)
        graph_clean_up(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'scaleshift_1_data')
        self.assertTrue(flag, resp)

    def test_mul_add_to_power_1(self):
        """Single-element Mul (3) and Add (2) become Power(scale=3, shift=2, power=1)."""
        graph = self._create_graph_with_mul_add(np.array([3]), np.array([2]))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'power_1'),
                                 ('power_1', 'power_1_data'),
                                 ('power_1_data', 'op_output'),
                                 ],
                                {'power_1': {'scale': 3, 'shift': 2, 'power': 1},
                                 }, nodes_with_edges_only=True)

        convert_muladd_to_scaleshift_or_power(graph)
        graph_clean_up(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'power_1_data', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_mul_add_neg_1(self):
        """The Mul constant has no value: the graph must stay unchanged."""
        self._assert_not_transformed(None, np.array([2]))

    def test_mul_add_neg_2(self):
        """The Add constant has no value: the graph must stay unchanged."""
        self._assert_not_transformed(np.array([2]), None)

    def test_mul_add_neg_3(self):
        """Neither constant has a value: the graph must stay unchanged."""
        self._assert_not_transformed(None, None)

    def test_mul_add_neg_4(self):
        """Three-element Mul weights with a single-element Add: the graph must stay unchanged."""
        self._assert_not_transformed(np.array([1, 2, 3]), np.array([3]))

    def test_mul_add_neg_5(self):
        """Single-element Mul with a three-element Add is still fused into a ScaleShift.

        The Mul weight is expected broadcast to ``[3, 3, 3]``. The ScaleShift output data node keeps the
        id ``add_1_data``.
        """
        graph = self._create_graph_with_mul_add(np.array([3]), np.array([3, 2, 1]))
        graph_ref = self._build_scaleshift_ref(weights=[3, 3, 3], biases=[3, 2, 1], output_data='add_1_data')

        convert_muladd_to_scaleshift_or_power(graph)
        graph_clean_up(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'add_1_data', check_op_attrs=True)
        self.assertTrue(flag, resp)
class AddToScaleShift(unittest.TestCase):
    """Tests for ``convert_add_or_mul_to_scaleshift`` on a single Add or Mul with a constant second input."""

    @staticmethod
    def _create_graph_with_eltwise(op: str, weights: np.ndarray):
        """Build ``placeholder_1 -> <op>_1 -> op_output`` where ``<op>_1`` takes ``weights`` as a constant.

        :param op: node-name prefix used in ``nodes_attributes``, i.e. ``'add'`` or ``'mul'``
        :param weights: values of the constant second input, or None to leave its shape and value unset
        :return: the built graph
        """
        node = op + '_1'
        const_node = 'const_' + node + '_w'
        weights_node = node + '_w'
        output_node = node + '_data'

        def const_attrs():
            # Called once per node so the const op and its data node get separate array copies.
            if weights is None:
                return {'shape': None, 'value': None}
            return {'shape': np.array(weights.shape), 'value': np.array(weights)}

        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', node),
                             (const_node, weights_node),
                             (weights_node, node),
                             (node, output_node),
                             (output_node, 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             output_node: {'shape': np.array([1, 227, 227, 3])},
                             const_node: const_attrs(),
                             weights_node: const_attrs(),
                             }, nodes_with_edges_only=True)
        # Drop the 'in' attribute from the op -> data output edge.
        # NOTE(review): presumably build_graph sets 'in' on every edge; confirm against build_graph.
        del graph[node][output_node][0]['in']
        return graph

    @staticmethod
    def _create_graph_with_add(add_w: np.ndarray):
        """Build ``placeholder_1 -> add_1 -> op_output`` with ``add_w`` as the constant addend."""
        return AddToScaleShift._create_graph_with_eltwise('add', add_w)

    @staticmethod
    def _create_graph_with_mul(mul_w: np.ndarray):
        """Build ``placeholder_1 -> mul_1 -> op_output`` with ``mul_w`` as the constant multiplier."""
        return AddToScaleShift._create_graph_with_eltwise('mul', mul_w)

    @staticmethod
    def _build_scaleshift_ref(weights, biases):
        """Build the expected ``placeholder_1 -> scaleshift_1 -> op_output`` graph.

        :param weights: expected ScaleShift weights; their shape becomes the weights ``shape`` attribute
        :param biases: expected ScaleShift biases; their shape becomes the biases ``shape`` attribute
        :return: the reference graph
        """
        def const_attrs(values):
            # Fresh array per call so const ops and data nodes never share one object.
            values = np.array(values)
            return {'shape': np.array(values.shape), 'value': values}

        return build_graph(nodes_attributes,
                           [('placeholder_1', 'placeholder_1_data'),
                            ('placeholder_1_data', 'scaleshift_1'),
                            ('const_scaleshift_1_w', 'scaleshift_1_w'),
                            ('const_scaleshift_1_b', 'scaleshift_1_b'),
                            ('scaleshift_1_w', 'scaleshift_1'),
                            ('scaleshift_1_b', 'scaleshift_1'),
                            ('scaleshift_1', 'scaleshift_1_data'),
                            ('scaleshift_1_data', 'op_output'),
                            ],
                           {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                            'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
                            'const_scaleshift_1_w': const_attrs(weights),
                            'scaleshift_1_w': const_attrs(weights),
                            'const_scaleshift_1_b': const_attrs(biases),
                            'scaleshift_1_b': const_attrs(biases),
                            }, nodes_with_edges_only=True)

    def _check_scaleshift(self, graph, graph_ref):
        """Compare ``graph`` with ``graph_ref`` and check the ScaleShift's ``bin`` edge tags.

        The node feeding ``op_output`` must receive its input 1 tagged ``bin='weights'`` and its input 2
        tagged ``bin='biases'``.
        """
        (flag, resp) = compare_graphs(graph, graph_ref, 'op_output')
        self.assertTrue(flag, resp)

        scsh_node = Node(graph, 'op_output').in_port(0).get_source().node

        self.assertEqual(graph.get_edge_data(scsh_node.in_node(1).id, scsh_node.id)[0]['bin'], 'weights')
        self.assertEqual(graph.get_edge_data(scsh_node.in_node(2).id, scsh_node.id)[0]['bin'], 'biases')

    def test_add_to_scaleshift_1(self):
        """A lone Add of [1, 2, 3] becomes ScaleShift(weights=[1, 1, 1], biases=[1, 2, 3])."""
        graph = AddToScaleShift._create_graph_with_add(np.array([1, 2, 3], dtype=np.float32))
        graph.stage = 'middle'

        graph_ref = self._build_scaleshift_ref(weights=[1, 1, 1], biases=[1, 2, 3])

        convert_add_or_mul_to_scaleshift(graph)
        graph_clean_up(graph)
        self._check_scaleshift(graph, graph_ref)

    def test_mul_to_scaleshift_1(self):
        """A lone Mul by [1, 2, 3] becomes ScaleShift(weights=[1, 2, 3], biases=[0, 0, 0])."""
        graph = AddToScaleShift._create_graph_with_mul(np.array([1, 2, 3], dtype=np.float32))
        graph.stage = 'middle'

        graph_ref = self._build_scaleshift_ref(weights=[1, 2, 3], biases=[0, 0, 0])

        convert_add_or_mul_to_scaleshift(graph)
        graph_clean_up(graph)
        self._check_scaleshift(graph, graph_ref)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()