Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / simple_proto_parser_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import os
18 import tempfile
19 import unittest
20
21 from mo.utils.simple_proto_parser import SimpleProtoParser
22
# Well-formed messages: each must parse into the nested dict asserted by the tests below.

# Nested messages with integer scalar values.
correct_proto_message_1 = 'model { faster_rcnn { num_classes: 90 image_resizer { keep_aspect_ratio_resizer {' \
                          ' min_dimension: 600  max_dimension: 1024 }}}}'

# Repeated scalar fields (`scales`, `aspect_ratios`) that the tests expect to be collected into lists.
correct_proto_message_2 = '    first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: 0.25 scales: 0.5 scales: 1.0 scales: 2.0  aspect_ratios: 0.5 aspect_ratios:' \
                          ' 1.0 aspect_ratios: 2.0}}'

# Newlines between tokens, boolean values and a bare enum-like identifier (FAN_AVG).
correct_proto_message_3 = '  initializer \n{variance_scaling_initializer \n{\nfactor: 1.0 uniform: true bla: false ' \
                          'mode: FAN_AVG}}'

# Quoted string values, including one with embedded spaces.
correct_proto_message_4 = 'train_input_reader {label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"' \
                          ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/  mscoco_train.record" }}'

# Same content as message 3 but interleaved with `#` comment lines.
correct_proto_message_5 = '  initializer \n  # abc \n{variance_scaling_initializer \n{\nfactor: 1.0 \n  # sd ' \
                          '\nuniform: true bla: false mode: FAN_AVG}}'

# Bracketed list syntax mixed with repeated scalar fields.
correct_proto_message_6 = '    first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: 0.5 aspect_ratios:' \
                          ' 1.0 aspect_ratios: 2.0}}'

# An empty bracketed list.
correct_proto_message_7 = '    first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: [] }}'

# A list with a trailing comma before the closing bracket.
correct_proto_message_8 = 'model {good_list: [3.0, 5.0, ]}'

# Commas used as separators between fields, outside of lists.
correct_proto_message_9 = '    first_stage_anchor_generator {grid_anchor_generator {height_stride: 16, width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0], aspect_ratios: [] }}'

# A Windows path containing a literal backslash. A raw string is used so that `\m` is not treated as an
# (invalid) escape sequence; the resulting value is the same backslash + 'm' as before.
correct_proto_message_10 = r'train_input_reader {label_map_path: "C:\mscoco_label_map.pbtxt"' \
                           ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/  mscoco_train.record" }}'

# Structural characters ([ { ] , }) and a backslash inside a quoted string. The raw string avoids the
# invalid `\[` escape warning without changing the value.
correct_proto_message_11 = r'model {path: "C:\[{],}" other_value: [1, 2, 3, 4]}'

# Malformed messages: the tests expect the parser to return None for each of these.

# A field name with no `: value` and no nested block.
incorrect_proto_message_1 = 'model { bad_no_value }'

# A scalar value followed by an unbalanced opening brace.
incorrect_proto_message_2 = 'model { abc: 3 { }'

# Two values given for a single field.
incorrect_proto_message_3 = 'model { too_many_values: 3 4 }'

# Input ends after a colon, with the enclosing block never closed.
incorrect_proto_message_4 = 'model { missing_values: '

# Extra tokens and a dangling colon inside the block.
incorrect_proto_message_5 = 'model { missing_values: aa bb : }'

# Top-level field with a colon but no value.
incorrect_proto_message_6 = 'model : '

# A list containing an empty element between two commas.
incorrect_proto_message_7 = 'model : {bad_list: [3.0, 4, , 4.0]}'
69
70
class TestingSimpleProtoParser(unittest.TestCase):
    """Tests for SimpleProtoParser.parse_from_string and SimpleProtoParser.parse_file."""

    def _write_temp_proto(self, content):
        """Write ``content`` to a new temporary file and return the file name.

        Deletion of the file is registered with ``addCleanup`` before anything is written, so the file
        is removed even if a later assertion in the test fails.
        """
        with tempfile.NamedTemporaryFile('wt', delete=False) as temp_file:
            self.addCleanup(os.unlink, temp_file.name)
            temp_file.write(content)
        return temp_file.name

    def test_correct_proto_reader_from_string_1(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_1)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_2(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_2)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_3(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_3)
        expected_result = {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_4(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_4)
        expected_result = {
            'train_input_reader': {'label_map_path': "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/  mscoco_train.record"}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_comments(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_5)
        expected_result = {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_lists(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_6)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_empty_list(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_7)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_8)
        expected_result = {'model': {'good_list': [3.0, 5.0]}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_redundant_commas(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_9)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_windows_path(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_10)
        # Raw string: the expected value contains a literal backslash, not an escape sequence.
        expected_result = {
            'train_input_reader': {'label_map_path': r"C:\mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/  mscoco_train.record"}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_special_characters_in_string(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_11)
        # Raw string: `\[` would otherwise be an invalid escape sequence (SyntaxWarning on Python 3.12+).
        expected_result = {'model': {'path': r"C:\[{],}",
                                     'other_value': [1, 2, 3, 4]}}
        self.assertDictEqual(result, expected_result)

    def test_incorrect_proto_reader_from_string_1(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_1)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_2(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_2)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_3(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_3)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_4(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_4)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_5(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_5)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_6(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_6)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_7(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_7)
        self.assertIsNone(result)

    def test_correct_proto_reader_from_file(self):
        file_name = self._write_temp_proto(correct_proto_message_1)

        result = SimpleProtoParser().parse_file(file_name)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    # Mode 0o000 does not stop root from reading a file, and on Windows os.chmod only toggles the
    # read-only flag, so in those environments the file stays readable and the premise of the test fails.
    @unittest.skipIf(os.name == 'nt' or (hasattr(os, 'geteuid') and os.geteuid() == 0),
                     'file mode 0o000 does not prevent reading on Windows or for root')
    def test_proto_reader_from_non_readable_file(self):
        file_name = self._write_temp_proto(correct_proto_message_1)
        os.chmod(file_name, 0o000)

        result = SimpleProtoParser().parse_file(file_name)
        self.assertIsNone(result)

    def test_proto_reader_from_non_existing_file(self):
        result = SimpleProtoParser().parse_file('/non/existing/file')
        self.assertIsNone(result)