"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest
from argparse import Namespace

import numpy as np

from extensions.middle.AddMeanScaleValues import AddMeanScaleValues
from mo.graph.graph import Node
from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
from mo.utils.unittest.graph import build_graph

# Attribute templates for every node the tests may place into a graph.
# The dictionary is passed to build_graph() together with an edge list; each
# test overrides individual attributes (shape, name, ...) through the third
# build_graph() argument.
nodes_attributes = {
    # Identity / Concat ops and their data nodes
    'node_1': {'type': 'Identity', 'value': None, 'kind': 'op'},
    'node_1_data': {'value': None, 'kind': 'data', 'data_type': None},
    'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
    'concat': {'type': 'Concat', 'value': None, 'kind': 'op'},
    'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
    'node_3_data': {'value': None, 'kind': 'data', 'data_type': None},
    # Placeholders
    'placeholder_1': {'shape': None, 'type': 'Input', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_2': {'shape': None, 'type': 'Input', 'kind': 'op', 'op': 'Placeholder'},
    'pl_1': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'pl_1_data': {'value': None, 'kind': 'data', 'data_type': None},
    'pl_2': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'pl_2_data': {'value': None, 'kind': 'data', 'data_type': None},
    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # ScaleShift op and its associated data nodes
    'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
    'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Mul op and its associated data nodes
    'mul_1': {'type': None, 'kind': 'op', 'op': 'Mul'},
    'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Graph output marker; its 'infer' callable always returns None
    'op_output': {'kind': 'op', 'op': 'OpOutput', 'infer': lambda x: None},
}
class AddMeanScaleValuesTest(unittest.TestCase):
    """Tests for ``AddMeanScaleValues.find_and_replace_pattern``.

    Every test builds a small graph from ``nodes_attributes``, stores the
    mean/scale values in ``graph.graph['cmd_params']``, runs the
    transformation and then checks either the node count of the graph or the
    number of ``Add`` / ``Mul`` operations found in it.
    """

    @staticmethod
    def _count_ops(graph, op_name):
        """Return how many nodes of ``graph`` have a valid 'op' equal to ``op_name``."""
        count = 0
        for node_id in graph.nodes():
            node = Node(graph, node_id)
            if node.has_valid('op') and node.op == op_name:
                count += 1
        return count

    @staticmethod
    def _apply_mean_scale(graph, mean_scale_values):
        """Store ``mean_scale_values`` as the graph's command-line parameters and run the pass."""
        graph.graph['cmd_params'] = Namespace(mean_scale_values=mean_scale_values)
        AddMeanScaleValues().find_and_replace_pattern(graph)

    @staticmethod
    def _build_named_input_graph():
        """Build the graph node_1 -> node_2 -> op_output where node_1 is a Placeholder named 'data'."""
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2'),
                             ('node_2', 'op_output')
                             ],
                            {'node_2': {'shape': None, 'data_type': None},
                             'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder', 'name': 'data',
                                        'data_type': None}
                             },
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'
        return graph

    @staticmethod
    def _build_single_placeholder_graph():
        """Build a graph that consists of the placeholder pl_1 (shape 1x3x38x38) and its data node."""
        graph = build_graph(
            nodes_attributes,
            [
                ('pl_1', 'pl_1_data')
            ],
            {
                'pl_1_data': {
                    'shape': np.array([1, 3, 38, 38]),
                    'infer': None
                },
                'pl_1': {
                    'shape': np.array([1, 3, 38, 38])
                }
            },
            nodes_with_edges_only=True
        )
        graph.graph['layout'] = 'NCHW'
        return graph

    def test_add_mean_scale_values_with_data_name(self):
        """Mean values parsed from the command-line string grow the graph from 3 to 6 nodes."""
        graph = self._build_named_input_graph()
        mean_values = parse_tuple_pairs('(124,117,104)')
        scale_values = parse_tuple_pairs('')

        mean_scale = get_mean_scale_dictionary(mean_values, scale_values, None)
        self.assertEqual(len(graph), 3)
        self._apply_mean_scale(graph, mean_scale)
        self.assertEqual(len(graph), 6)

    def test_add_mean_scale_values_without_data_name(self):
        """Same scenario as the 'with data name' test: the graph grows from 3 to 6 nodes."""
        graph = self._build_named_input_graph()
        mean_values = parse_tuple_pairs('(124,117,104)')
        scale_values = parse_tuple_pairs('')

        mean_scale = get_mean_scale_dictionary(mean_values, scale_values, None)
        self.assertEqual(len(graph), 3)
        self._apply_mean_scale(graph, mean_scale)
        self.assertEqual(len(graph), 6)

    def test_add_mean_scale_values1(self):
        """Two inputs with mean values, one of them all zeros: one Add and no Mul are expected."""
        graph = build_graph(nodes_attributes,
                            [('pl_1', 'pl_1_data'), ('pl_2', 'pl_2_data')],
                            {'pl_1_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
                             'pl_2_data': {'shape': np.array([1, 6]), 'infer': None},
                             'pl_1': {'shape': np.array([1, 3, 38, 38])},
                             'pl_2': {'shape': np.array([1, 6])},
                             },
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'
        self._apply_mean_scale(
            graph,
            {'pl_1': {'mean': np.array([1., 2., 3.])}, 'pl_2': {'mean': np.array([0., 0., 0.])}})

        self.assertEqual(self._count_ops(graph, 'Add'), 1, "There should be exactly one Add op")
        self.assertEqual(self._count_ops(graph, 'Mul'), 0, "Found Mul op in graph")

    def test_optimize_scale_and_add_mean_values(self):
        """A scale of 1 with non-zero mean values: one Add and no Mul are expected."""
        graph = self._build_single_placeholder_graph()
        self._apply_mean_scale(graph, {'pl_1': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})

        self.assertEqual(self._count_ops(graph, 'Add'), 1, "There should be exactly one Add op")
        self.assertEqual(self._count_ops(graph, 'Mul'), 0, "Found Mul op in graph")

    def test_optimize_mean_and_add_scale_values(self):
        """All-zero mean values with a scale of 1.43: one Mul and no Add are expected."""
        graph = self._build_single_placeholder_graph()
        self._apply_mean_scale(graph, {'pl_1': {'scale': np.array([1.43]), 'mean': np.array([0., 0., 0.])}})

        self.assertEqual(self._count_ops(graph, 'Add'), 0, "Found Add op in graph")
        self.assertEqual(self._count_ops(graph, 'Mul'), 1, "There should be exactly one Mul op")

    def test_add_mean_scale_values3(self):
        """Mean/scale values given as a list of array pairs: one Add and one Mul are expected."""
        graph = self._build_single_placeholder_graph()
        self._apply_mean_scale(graph, [[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])

        self.assertEqual(self._count_ops(graph, 'Add'), 1, "There should be exactly one Add op")
        self.assertEqual(self._count_ops(graph, 'Mul'), 1, "There should be exactly one Mul op")

    def test_add_mean_scale_values_cut_graph(self):
        """
        Test case when the user cut the start of the network and specified a mean/scale value for the new
        input node 'node_3'.
        """
        graph = build_graph(nodes_attributes,
                            [('pl_1', 'pl_1_data'),
                             ('pl_2', 'pl_2_data'),
                             ('pl_2_data', 'node_3'),
                             ('node_3', 'node_3_data'),
                             ('pl_1_data', 'node_1'),
                             ('node_3_data', 'node_1'),
                             ],
                            {'pl_1_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
                             'pl_2_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
                             'pl_2': {'initial_node_name': 'node_3', 'shape': np.array([1, 3, 38, 38])},
                             'pl_1': {'shape': np.array([1, 3, 38, 38])},
                             },
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'
        self._apply_mean_scale(
            graph,
            {'pl_1': {'mean': np.array([1, 2, 3])}, 'node_3': {'scale': np.array([1, 2, 3])}})

        self.assertEqual(self._count_ops(graph, 'Add'), 1, "There should be exactly one Add op")
        self.assertEqual(self._count_ops(graph, 'Mul'), 1, "There should be exactly one Mul op")
        self.assertEqual(Node(graph, 'pl_2').out_node().out_node().op, 'Mul', "The Mul op should be added after pl_2")
        self.assertEqual(Node(graph, 'pl_1').out_node().out_node().op, 'Add', "The Add op should be added after pl_1")