2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.utils.simple_proto_parser import SimpleProtoParser
# Inputs that SimpleProtoParser.parse_from_string must accept.
# Nested blocks with integer values.
correct_proto_message_1 = 'model { faster_rcnn { num_classes: 90 image_resizer { keep_aspect_ratio_resizer {' \
                          ' min_dimension: 600 max_dimension: 1024 }}}}'

# Repeated keys ("scales", "aspect_ratios") that must be collected into lists.
correct_proto_message_2 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: 0.25 scales: 0.5 scales: 1.0 scales: 2.0 aspect_ratios: 0.5 aspect_ratios:' \
                          ' 1.0 aspect_ratios: 2.0}}'

# Newlines between tokens, boolean literals and a bare identifier value.
correct_proto_message_3 = ' initializer \n{variance_scaling_initializer \n{\nfactor: 1.0 uniform: true bla: false ' \
                          'mode: FAN_AVG}}'

# Quoted string values, one of which contains a space.
correct_proto_message_4 = 'train_input_reader {label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"' \
                          ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/ mscoco_train.record" }}'

# Same content as message 3, interleaved with '#' comments.
correct_proto_message_5 = ' initializer \n # abc \n{variance_scaling_initializer \n{\nfactor: 1.0 \n # sd ' \
                          '\nuniform: true bla: false mode: FAN_AVG}}'

# Bracketed list syntax mixed with repeated keys.
correct_proto_message_6 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: 0.5 aspect_ratios:' \
                          ' 1.0 aspect_ratios: 2.0}}'

# Empty bracketed list.
correct_proto_message_7 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: [] }}'

# Trailing comma inside a list.
correct_proto_message_8 = 'model {good_list: [3.0, 5.0, ]}'

# Commas separating fields.
correct_proto_message_9 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16, width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0], aspect_ratios: [] }}'

# Windows-style path containing a literal backslash. The raw string keeps the exact same
# value the former "C:\m..." literal produced, without the invalid-escape warning.
correct_proto_message_10 = r'train_input_reader {label_map_path: "C:\mscoco_label_map.pbtxt"' \
                           ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/ mscoco_train.record" }}'

# Brackets, braces, commas and a backslash inside a quoted string must not be parsed as syntax.
correct_proto_message_11 = r'model {path: "C:\[{],}" other_value: [1, 2, 3, 4]}'

# Inputs for which SimpleProtoParser.parse_from_string must return None.
# Key with neither a value nor a nested block.
incorrect_proto_message_1 = 'model { bad_no_value }'

# Value followed by a stray '{'; the braces are unbalanced.
incorrect_proto_message_2 = 'model { abc: 3 { }'

# Two values for a single key.
incorrect_proto_message_3 = 'model { too_many_values: 3 4 }'

# Input ends before the value and the closing brace.
incorrect_proto_message_4 = 'model { missing_values: '

# Key "bb" followed by ':' with no value.
incorrect_proto_message_5 = 'model { missing_values: aa bb : }'

# Top-level key whose value is missing.
incorrect_proto_message_6 = 'model : '

# Empty element between two commas in a list.
incorrect_proto_message_7 = 'model : {bad_list: [3.0, 4, , 4.0]}'
class TestingSimpleProtoParser(unittest.TestCase):
    """Tests for SimpleProtoParser.

    ``parse_from_string`` must turn every ``correct_proto_message_*`` fixture into the
    expected nested dict and return None for every ``incorrect_proto_message_*`` fixture.
    ``parse_file`` must parse a file the same way and return None when the file is
    missing or unreadable.
    """

    def _write_temp_proto(self, content):
        """Write *content* to a fresh temporary file and return the file's path.

        The file is closed before returning. This flushes its contents and lets the parser
        reopen it; Windows does not allow reopening a NamedTemporaryFile while it is still
        open. Deletion is registered with addCleanup, so the file is removed even when the
        calling test fails.
        """
        with tempfile.NamedTemporaryFile('wt', delete=False) as tmp_file:
            tmp_file.write(content)
        self.addCleanup(self._remove_file, tmp_file.name)
        return tmp_file.name

    @staticmethod
    def _remove_file(file_name):
        """Delete *file_name*, restoring owner read/write permission first.

        A test may chmod the file to 0o000. On Windows that only sets the read-only flag,
        and a read-only file cannot be deleted until the flag is cleared again.
        """
        os.chmod(file_name, 0o600)
        os.remove(file_name)

    def test_correct_proto_reader_from_string_1(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_1)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_2(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_2)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_3(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_3)
        expected_result = {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_4(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_4)
        expected_result = {
            'train_input_reader': {'label_map_path': "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record"}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_comments(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_5)
        expected_result = {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_lists(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_6)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_empty_list(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_7)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_8)
        expected_result = {'model': {'good_list': [3.0, 5.0]}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_redundant_commas(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_9)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_windows_path(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_10)
        # Raw string: same value as the former "C:\m..." literal, without an invalid escape.
        expected_result = {
            'train_input_reader': {'label_map_path': r"C:\mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record"}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_special_characters_in_string(self):
        result = SimpleProtoParser().parse_from_string(correct_proto_message_11)
        expected_result = {'model': {'path': r"C:\[{],}",
                                     'other_value': [1, 2, 3, 4]}}
        self.assertDictEqual(result, expected_result)

    def test_incorrect_proto_reader_from_string_1(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_1)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_2(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_2)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_3(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_3)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_4(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_4)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_5(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_5)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_6(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_6)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_7(self):
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_7)
        self.assertIsNone(result)

    def test_correct_proto_reader_from_file(self):
        file_name = self._write_temp_proto(correct_proto_message_1)

        result = SimpleProtoParser().parse_file(file_name)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    def test_proto_reader_from_non_readable_file(self):
        file_name = self._write_temp_proto(correct_proto_message_1)
        os.chmod(file_name, 0o000)
        if os.access(file_name, os.R_OK):
            # root ignores permission bits, and on Windows chmod only sets the read-only
            # flag, so an unreadable file cannot be produced here.
            self.skipTest('file is still readable after chmod(0o000)')

        result = SimpleProtoParser().parse_file(file_name)
        self.assertIsNone(result)

    def test_proto_reader_from_non_existing_file(self):
        result = SimpleProtoParser().parse_file('/non/existing/file')
        self.assertIsNone(result)