2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from unittest.mock import patch
22 from mo.front.mxnet.extractors.utils import AttrDictionary
23 from mo.front.mxnet.extractors.utils import load_params
class TestAttrDictionary(unittest.TestCase):
    """Unit tests for the typed accessors of ``AttrDictionary``.

    Each test wraps a small dict of string-valued attributes and checks that an
    accessor (``bool``, ``str``, ``float``, ``int``, ``tuple``, ``list``) either
    converts the stored value or returns the supplied default when the key is
    absent (or, for tuples, when the stored value is an empty tuple/list).
    """

    def testBool(self):
        """The string "True" is read as boolean True."""
        attrs = {"global_pool": "True"}
        attr_dict = AttrDictionary(attrs)
        global_pool = attr_dict.bool("global_pool", False)
        self.assertEqual(True, global_pool)

    def testBoolAsDigits(self):
        """A digit string ("1") is read as boolean True."""
        attrs = {"global_pool": "1"}
        attr_dict = AttrDictionary(attrs)
        global_pool = attr_dict.bool("global_pool", False)
        self.assertEqual(True, global_pool)

    def testBoolWithoutAttr(self):
        """A missing key yields the supplied default (False)."""
        attrs = {"something": "1"}
        attr_dict = AttrDictionary(attrs)
        global_pool = attr_dict.bool("global_pool", False)
        self.assertEqual(False, global_pool)

    def testStrAttr(self):
        """A present key returns its stored string, not the default."""
        attrs = {"something": "Val"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.str("something", "Text")
        self.assertEqual("Val", attr)

    def testStrAttrWithoutAttr(self):
        """A missing key returns the supplied default string."""
        attrs = {"something2": "Val"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.str("something", "Text")
        self.assertEqual("Text", attr)

    def testFloatAttr(self):
        """A stored "0.5" is converted to the float 0.5."""
        attrs = {"something": "0.5"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.float("something", 0.1)
        self.assertEqual(0.5, attr)

    def testFloatWithoutAttr(self):
        """A missing key yields the supplied float default."""
        attrs = {"something2": "0.5"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.float("something", 0.1)
        self.assertEqual(0.1, attr)

    def testIntAttr(self):
        """A stored "5" is converted by the int accessor to the int 5."""
        attrs = {"something": "5"}
        attr_dict = AttrDictionary(attrs)
        # Must use ``int`` (not ``float``): assertEqual(5, 5.0) would pass and
        # hide a wrong accessor, so the type is checked explicitly as well.
        attr = attr_dict.int("something", 1)
        self.assertEqual(5, attr)
        self.assertIsInstance(attr, int)

    def testIntWithoutAttr(self):
        """A missing key makes the int accessor return the int default."""
        attrs = {"something2": "5"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.int("something", 1)
        self.assertEqual(1, attr)
        self.assertIsInstance(attr, int)

    def testTupleAttr(self):
        """A "(5,6,7)" string is parsed element-wise with the given type."""
        attrs = {"something": "(5,6,7)"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.tuple("something", int, (1, 2, 3))
        # Compare the whole value so both length and contents are checked.
        self.assertEqual((5, 6, 7), tuple(attr))

    def testTupleWithoutAttr(self):
        """A missing key yields the supplied tuple default."""
        attrs = {"something2": "(5,6,7)"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.tuple("something", int, (1, 2, 3))
        self.assertEqual((1, 2, 3), tuple(attr))

    def testTupleWithEmptyTupleAttr(self):
        """An empty "()" value falls back to the supplied default."""
        attrs = {"something": "()"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.tuple("something", int, (2, 3))
        self.assertEqual((2, 3), tuple(attr))

    def testTupleWithEmptyListAttr(self):
        """An empty "[]" value falls back to the supplied default."""
        attrs = {"something": "[]"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.tuple("something", int, (2, 3))
        self.assertEqual((2, 3), tuple(attr))

    def testListAttr(self):
        """A comma-separated "5,6,7" string is parsed into a list of ints."""
        attrs = {"something": "5,6,7"}
        attr_dict = AttrDictionary(attrs)
        values = attr_dict.list("something", int, [1, 2, 3])
        self.assertEqual([5, 6, 7], list(values))

    def testListWithoutAttr(self):
        """A missing key yields the supplied list default."""
        attrs = {"something2": "5,6,7"}
        attr_dict = AttrDictionary(attrs)
        values = attr_dict.list("something", int, [1, 2, 3])
        self.assertEqual([1, 2, 3], list(values))

    def testIntWithAttrNone(self):
        """A stored "None" string is returned as None by the int accessor."""
        attrs = {"something": "None"}
        attr_dict = AttrDictionary(attrs)
        attr = attr_dict.int("something", None)
        self.assertIsNone(attr)
class TestUtils(unittest.TestCase):
    """Tests for ``load_params`` with ``mxnet.nd.load`` patched out.

    Each test makes the patched loader return a dict of names to ``mx.nd``
    arrays, then checks which of the returned object's argument/auxiliary
    collections each name ends up in, and that the array values survive.
    """

    @patch('mxnet.nd.load')
    def test_load_symbol_nodes_from_params(self, mock_nd_load):
        """In a ``.params`` file, ``arg:``/``aux:`` prefixes select the collection and are stripped."""
        mock_nd_load.return_value = {'arg:conv0_weight': mx.nd.array([1, 2], dtype='float32'),
                                     'arg:conv1_weight': mx.nd.array([2, 3], dtype='float32'),
                                     'aux:bn_data_mean': mx.nd.array([5, 6], dtype='float32')}
        model_params = load_params("model.params")
        # assertIn reports the container contents on failure, unlike assertTrue(x in y).
        self.assertIn('conv0_weight', model_params._param_names)
        self.assertIn('conv1_weight', model_params._param_names)
        self.assertIn('bn_data_mean', model_params._aux_names)
        self.assertEqual([1., 2.], model_params._arg_params['conv0_weight'].asnumpy().tolist())
        self.assertEqual([2., 3.], model_params._arg_params['conv1_weight'].asnumpy().tolist())
        self.assertEqual([5., 6.], model_params._aux_params['bn_data_mean'].asnumpy().tolist())

    @patch('mxnet.nd.load')
    def test_load_symbol_nodes_from_args_nd(self, mock_nd_load):
        """Unprefixed entries loaded from ``args_model.nd`` are expected among the parameters."""
        mock_nd_load.return_value = {'conv0_weight': mx.nd.array([1, 2], dtype='float32'),
                                     'conv1_weight': mx.nd.array([2, 3], dtype='float32')}
        model_params = load_params("args_model.nd", data_names=('data1', 'data2'))
        self.assertIn('conv0_weight', model_params._param_names)
        self.assertIn('conv1_weight', model_params._param_names)
        self.assertEqual([1., 2.], model_params._arg_params['conv0_weight'].asnumpy().tolist())
        self.assertEqual([2., 3.], model_params._arg_params['conv1_weight'].asnumpy().tolist())

    @patch('mxnet.nd.load')
    def test_load_symbol_nodes_from_auxs_nd(self, mock_nd_load):
        """Unprefixed entries loaded from ``auxs_model.nd`` are expected among the auxiliary states."""
        mock_nd_load.return_value = {'bn_data_mean': mx.nd.array([5, 6], dtype='float32')}
        model_params = load_params("auxs_model.nd")
        self.assertIn('bn_data_mean', model_params._aux_names)
        self.assertEqual([5., 6.], model_params._aux_params['bn_data_mean'].asnumpy().tolist())