Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / AddMeanScaleValues_test.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18 from argparse import Namespace
19
20 import numpy as np
21
22 from extensions.middle.AddMeanScaleValues import AddMeanScaleValues
23 from mo.graph.graph import Node
24 from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
25 from mo.utils.unittest.graph import build_graph
26
27 nodes_attributes = {'node_1': {'type': 'Identity', 'value': None, 'kind': 'op'},
28                     'node_1_data': {'value': None, 'kind': 'data', 'data_type': None},
29                     'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
30                     'concat': {'type': 'Concat', 'value': None, 'kind': 'op'},
31                     'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
32                     'node_3_data': {'value': None, 'kind': 'data', 'data_type': None},
33                     # Placeholders
34                     'placeholder_1': {'shape': None, 'type': 'Input', 'kind': 'op', 'op': 'Placeholder'},
35                     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
36                     'placeholder_2': {'shape': None, 'type': 'Input', 'kind': 'op', 'op': 'Placeholder'},
37                     'pl_1': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
38                     'pl_1_data': {'value': None, 'kind': 'data', 'data_type': None},
39                     'pl_2': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
40                     'pl_2_data': {'value': None, 'kind': 'data', 'data_type': None},
41                     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
42                     # ScaleShift layer
43                     'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
44                     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
45                     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
46                     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
47                     # Mul op
48                     'mul_1': {'type': None, 'kind': 'op', 'op': 'Mul'},
49                     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
50                     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
51                     'op_output': {'kind': 'op', 'op': 'OpOutput', 'infer': lambda x: None}
52                     }
53
54
55 class AddMeanScaleValuesTest(unittest.TestCase):
56     def test_add_mean_scale_values_with_data_name(self):
57         graph = build_graph(nodes_attributes,
58                             [('node_1', 'node_2'),
59                              ('node_2', 'op_output')
60                              ],
61                             {'node_2': {'shape': None, 'data_type': None},
62                              'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder', 'name': 'data',
63                                         'data_type': None}
64                              },
65                             nodes_with_edges_only=True)
66         graph.graph['layout'] = 'NCHW'
67         mean_values = parse_tuple_pairs('(124,117,104)')
68         scale_values = parse_tuple_pairs('')
69
70         # input = 'data'
71         mean_scale = get_mean_scale_dictionary(mean_values, scale_values, None)
72         argv = Namespace(mean_scale_values=mean_scale)
73         graph.graph['cmd_params'] = argv
74         self.assertEqual(len(graph), 3)
75         AddMeanScaleValues().find_and_replace_pattern(graph)
76         self.assertEqual(len(graph), 6)
77
78     def test_add_mean_scale_values_without_data_name(self):
79         graph = build_graph(nodes_attributes,
80                             [('node_1', 'node_2'),
81                              ('node_2', 'op_output')
82                              ],
83                             {'node_2': {'shape': None, 'data_type': None},
84                              'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder', 'name': 'data',
85                                         'data_type': None}
86                              },
87                             nodes_with_edges_only=True)
88         graph.graph['layout'] = 'NCHW'
89         mean_values = parse_tuple_pairs('(124,117,104)')
90         scale_values = parse_tuple_pairs('')
91         # input = None
92         mean_scale = get_mean_scale_dictionary(mean_values, scale_values, None)
93         argv = Namespace(mean_scale_values=mean_scale)
94         graph.graph['cmd_params'] = argv
95         self.assertEqual(len(graph), 3)
96         AddMeanScaleValues().find_and_replace_pattern(graph)
97         self.assertEqual(len(graph), 6)
98
99     def test_add_mean_scale_values1(self):
100         graph = build_graph(nodes_attributes,
101                             [('pl_1', 'pl_1_data'), ('pl_2', 'pl_2_data')],
102                             {'pl_1_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
103                              'pl_2_data': {'shape': np.array([1, 6]), 'infer': None},
104                              'pl_1': {'shape': np.array([1, 3, 38, 38])},
105                              'pl_2': {'shape': np.array([1, 6])},
106                              },
107                             nodes_with_edges_only=True)
108         graph.graph['layout'] = 'NCHW'
109         argv = Namespace(
110             mean_scale_values={'pl_1': {'mean': np.array([1., 2., 3.])}, 'pl_2': {'mean': np.array([0., 0., 0.])}})
111         graph.graph['cmd_params'] = argv
112         graph.graph['cmd_params'] = argv
113         AddMeanScaleValues().find_and_replace_pattern(graph)
114         mul_op_cnt = 0
115         add_op_cnt = 0
116         for node in graph.nodes():
117             node = Node(graph, node)
118             if node.has_valid('op') and node.op == 'Mul':
119                 mul_op_cnt += 1
120             if node.has_valid('op') and node.op == 'Add':
121                 add_op_cnt += 1
122
123         self.assertEqual(add_op_cnt, 1, "Found more than one Add op in graph")
124         self.assertEqual(mul_op_cnt, 0, "Found Mul op in graph")
125
126     def test_optimize_scale_and_add_mean_values(self):
127         graph = build_graph(
128             nodes_attributes,
129             [
130                 ('pl_1', 'pl_1_data')
131             ],
132             {
133                 'pl_1_data': {
134                     'shape': np.array([1, 3, 38, 38]),
135                     'infer': None
136                 },
137                 'pl_1': {
138                     'shape': np.array([1, 3, 38, 38])
139                 }
140             },
141             nodes_with_edges_only=True
142         )
143         graph.graph['layout'] = 'NCHW'
144         argv = Namespace(mean_scale_values={'pl_1': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
145         graph.graph['cmd_params'] = argv
146         AddMeanScaleValues().find_and_replace_pattern(graph)
147         mul_op_cnt = 0
148         add_op_cnt = 0
149         for node in graph.nodes():
150             node = Node(graph, node)
151             if node.has_valid('op') and node.op == 'Mul':
152                 mul_op_cnt += 1
153             if node.has_valid('op') and node.op == 'Add':
154                 add_op_cnt += 1
155
156         self.assertEqual(add_op_cnt, 1, "Found more than one Add op in graph")
157         self.assertEqual(mul_op_cnt, 0, "Found Mul op in graph")
158
159     def test_optimize_mean_and_add_scale_values(self):
160         graph = build_graph(
161             nodes_attributes,
162             [
163                 ('pl_1', 'pl_1_data')
164             ],
165             {
166                 'pl_1_data': {
167                     'shape': np.array([1, 3, 38, 38]),
168                     'infer': None
169                 },
170                 'pl_1': {
171                     'shape': np.array([1, 3, 38, 38])
172                 }
173             },
174             nodes_with_edges_only=True
175         )
176         graph.graph['layout'] = 'NCHW'
177         argv = Namespace(mean_scale_values={'pl_1': {'scale': np.array([1.43]), 'mean': np.array([0., 0., 0.])}})
178         graph.graph['cmd_params'] = argv
179         AddMeanScaleValues().find_and_replace_pattern(graph)
180         mul_op_cnt = 0
181         add_op_cnt = 0
182         for node in graph.nodes():
183             node = Node(graph, node)
184             if node.has_valid('op') and node.op == 'Mul':
185                 mul_op_cnt += 1
186             if node.has_valid('op') and node.op == 'Add':
187                 add_op_cnt += 1
188
189         self.assertEqual(add_op_cnt, 0, "Found more than one Add op in graph")
190         self.assertEqual(mul_op_cnt, 1, "Found Mul op in graph")
191
192     def test_add_mean_scale_values3(self):
193         graph = build_graph(nodes_attributes,
194                             [('pl_1', 'pl_1_data')],
195                             {'pl_1_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
196                              'pl_1': {'shape': np.array([1, 3, 38, 38])},
197                              },
198                             nodes_with_edges_only=True)
199         graph.graph['layout'] = 'NCHW'
200         argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
201         graph.graph['cmd_params'] = argv
202         AddMeanScaleValues().find_and_replace_pattern(graph)
203
204         mul_op_cnt = 0
205         add_op_cnt = 0
206         for node in graph.nodes():
207             node = Node(graph, node)
208             if node.has_valid('op') and node.op == 'Mul':
209                 mul_op_cnt += 1
210             if node.has_valid('op') and node.op == 'Add':
211                 add_op_cnt += 1
212
213         self.assertEqual(add_op_cnt, 1, "Found more than one Add op in graph")
214         self.assertEqual(mul_op_cnt, 1, "Found more than one Nul op in graph")
215
216     def test_add_mean_scale_values_cut_graph(self):
217         """
218         Test case when user cutted start of the network and specified mean/scale value to the new input node 'node_3'.
219         """
220         graph = build_graph(nodes_attributes,
221                             [('pl_1', 'pl_1_data'),
222                              ('pl_2', 'pl_2_data'),
223                              ('pl_2_data', 'node_3'),
224                              ('node_3', 'node_3_data'),
225                              ('pl_1_data', 'node_1'),
226                              ('node_3_data', 'node_1'),
227                              ],
228                             {'pl_1_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
229                              'pl_2_data': {'shape': np.array([1, 3, 38, 38]), 'infer': None},
230                              'pl_2': {'initial_node_name': 'node_3', 'shape': np.array([1, 3, 38, 38])},
231                              'pl_1': {'shape': np.array([1, 3, 38, 38])},
232                              },
233                             nodes_with_edges_only=True)
234         graph.graph['layout'] = 'NCHW'
235         argv = Namespace(
236             mean_scale_values={'pl_1': {'mean': np.array([1, 2, 3])}, 'node_3': {'scale': np.array([1, 2, 3])}})
237         graph.graph['cmd_params'] = argv
238         AddMeanScaleValues().find_and_replace_pattern(graph)
239
240         mul_op_cnt = 0
241         add_op_cnt = 0
242         for node in graph.nodes():
243             node = Node(graph, node)
244             if node.has_valid('op') and node.op == 'Mul':
245                 mul_op_cnt += 1
246             if node.has_valid('op') and node.op == 'Add':
247                 add_op_cnt += 1
248
249         self.assertEqual(add_op_cnt, 1, "There should be exactly one Add op")
250         self.assertEqual(mul_op_cnt, 1, "There should be exactly one Mul op")
251         self.assertEqual(Node(graph, 'pl_2').out_node().out_node().op, 'Mul', "The Mul op should be added after pl_2")
252         self.assertEqual(Node(graph, 'pl_1').out_node().out_node().op, 'Add', "The Add op should be added after pl_1")