Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / extractors / utils_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18 from unittest.mock import patch
19
20 import mxnet as mx
21
22 from mo.front.mxnet.extractors.utils import AttrDictionary
23 from mo.front.mxnet.extractors.utils import load_params
24
25
class TestAttrDictionary(unittest.TestCase):
    """Tests for the typed accessors of ``AttrDictionary``.

    Each accessor (``bool``, ``str``, ``float``, ``int``, ``tuple``, ``list``)
    is exercised with the key present, where the stored string must be parsed,
    and with the key absent, where the supplied default must be returned.
    """

    def testBool(self):
        attr_dict = AttrDictionary({"global_pool": "True"})
        # assertIs (not assertEqual) so a truthy non-bool cannot pass as True.
        self.assertIs(True, attr_dict.bool("global_pool", False))

    def testBoolAsDigits(self):
        # Digit strings are accepted as booleans as well as "True"/"False".
        attr_dict = AttrDictionary({"global_pool": "1"})
        self.assertIs(True, attr_dict.bool("global_pool", False))

    def testBoolWithoutAttr(self):
        attr_dict = AttrDictionary({"something": "1"})
        self.assertIs(False, attr_dict.bool("global_pool", False))

    def testStrAttr(self):
        attr_dict = AttrDictionary({"something": "Val"})
        self.assertEqual("Val", attr_dict.str("something", "Text"))

    def testStrAttrWithoutAttr(self):
        attr_dict = AttrDictionary({"something2": "Val"})
        self.assertEqual("Text", attr_dict.str("something", "Text"))

    def testFloatAttr(self):
        attr_dict = AttrDictionary({"something": "0.5"})
        self.assertEqual(0.5, attr_dict.float("something", 0.1))

    def testFloatWithoutAttr(self):
        attr_dict = AttrDictionary({"something2": "0.5"})
        self.assertEqual(0.1, attr_dict.float("something", 0.1))

    def testIntAttr(self):
        # Use the ``int`` accessor: calling ``float`` here would still compare
        # equal (5.0 == 5) and leave integer parsing untested.
        attr_dict = AttrDictionary({"something": "5"})
        attr = attr_dict.int("something", 1)
        self.assertEqual(5, attr)
        self.assertIsInstance(attr, int)

    def testIntWithoutAttr(self):
        attr_dict = AttrDictionary({"something2": "5"})
        self.assertEqual(1, attr_dict.int("something", 1))

    def testTupleAttr(self):
        attr_dict = AttrDictionary({"something": "(5,6,7)"})
        result = attr_dict.tuple("something", int, (1, 2, 3))
        self.assertEqual((5, 6, 7), tuple(result))

    def testTupleWithoutAttr(self):
        attr_dict = AttrDictionary({"something2": "(5,6,7)"})
        result = attr_dict.tuple("something", int, (1, 2, 3))
        self.assertEqual((1, 2, 3), tuple(result))

    def testTupleWithEmptyTupleAttr(self):
        # NOTE(review): the stored key is "something2" but the lookup is for
        # "something", so this exercises the missing-key fallback rather than
        # parsing of "()". Confirm the intended expectation before pointing the
        # lookup at the stored key.
        attr_dict = AttrDictionary({"something2": "()"})
        result = attr_dict.tuple("something", int, (2, 3))
        self.assertEqual((2, 3), tuple(result))

    def testTupleWithEmptyListAttr(self):
        # NOTE(review): same key mismatch as testTupleWithEmptyTupleAttr; this
        # tests the default fallback, not parsing of "[]".
        attr_dict = AttrDictionary({"something2": "[]"})
        result = attr_dict.tuple("something", int, (2, 3))
        self.assertEqual((2, 3), tuple(result))

    def testListAttr(self):
        attr_dict = AttrDictionary({"something": "5,6,7"})
        values = attr_dict.list("something", int, [1, 2, 3])
        self.assertEqual([5, 6, 7], list(values))

    def testListWithoutAttr(self):
        attr_dict = AttrDictionary({"something2": "5,6,7"})
        values = attr_dict.list("something", int, [1, 2, 3])
        self.assertEqual([1, 2, 3], list(values))

    def testIntWithAttrNone(self):
        # The literal string "None" is expected to be read back as Python None.
        attr_dict = AttrDictionary({"something": "None"})
        self.assertIsNone(attr_dict.int("something", None))
180
181
class TestUtils(unittest.TestCase):
    """Tests for ``load_params`` with ``mxnet.nd.load`` patched out.

    The patched loader returns an in-memory dict of NDArrays in place of file
    contents. The tests check how the keys are routed into the argument
    (``_param_names`` / ``_arg_params``) and auxiliary (``_aux_names`` /
    ``_aux_params``) collections for ``.params`` and ``.nd`` inputs.
    """

    @staticmethod
    def _as_list(nd_array):
        """Return the contents of an MXNet NDArray as a plain Python list."""
        return nd_array.asnumpy().tolist()

    @patch('mxnet.nd.load')
    def test_load_symbol_nodes_from_params(self, mock_nd_load):
        # ``.params`` keys carry an "arg:"/"aux:" prefix; the assertions below
        # expect it to be stripped and to select the target collection.
        mock_nd_load.return_value = {'arg:conv0_weight': mx.nd.array([1, 2], dtype='float32'),
                                     'arg:conv1_weight': mx.nd.array([2, 3], dtype='float32'),
                                     'aux:bn_data_mean': mx.nd.array([5, 6], dtype='float32')}
        model_params = load_params("model.params")
        self.assertIn('conv0_weight', model_params._param_names)
        self.assertIn('conv1_weight', model_params._param_names)
        self.assertIn('bn_data_mean', model_params._aux_names)
        self.assertEqual([1., 2.], self._as_list(model_params._arg_params['conv0_weight']))
        self.assertEqual([2., 3.], self._as_list(model_params._arg_params['conv1_weight']))
        self.assertEqual([5., 6.], self._as_list(model_params._aux_params['bn_data_mean']))

    @patch('mxnet.nd.load')
    def test_load_symbol_nodes_from_args_nd(self, mock_nd_load):
        # Unprefixed ``.nd`` keys from an "args_*" file are expected to land in
        # the argument collection.
        mock_nd_load.return_value = {'conv0_weight': mx.nd.array([1, 2], dtype='float32'),
                                     'conv1_weight': mx.nd.array([2, 3], dtype='float32')}
        model_params = load_params("args_model.nd", data_names=('data1', 'data2'))
        self.assertIn('conv0_weight', model_params._param_names)
        self.assertIn('conv1_weight', model_params._param_names)
        self.assertEqual([1., 2.], self._as_list(model_params._arg_params['conv0_weight']))
        self.assertEqual([2., 3.], self._as_list(model_params._arg_params['conv1_weight']))

    @patch('mxnet.nd.load')
    def test_load_symbol_nodes_from_auxs_nd(self, mock_nd_load):
        # Unprefixed ``.nd`` keys from an "auxs_*" file are expected to land in
        # the auxiliary collection.
        mock_nd_load.return_value = {'bn_data_mean': mx.nd.array([5, 6], dtype='float32')}
        model_params = load_params("auxs_model.nd")
        self.assertIn('bn_data_mean', model_params._aux_names)
        self.assertEqual([5., 6.], self._as_list(model_params._aux_params['bn_data_mean']))