2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import logging as log
import re
from typing import Union

from mo.utils.error import Error
from mo.utils.simple_proto_parser import SimpleProtoParser
# The list of rules describing how to map values from the pipeline.config file to the dictionary with attributes.
# A rule is either a string or a tuple:
#  * string: the rule is used both as the key to look up in the parsed attributes dictionary and as the key to save
#    the found value under. Nothing is saved if the key is absent.
#  * tuple (key, path) or (key, path, default_value): 'key' is the key to save the found value under and 'path' defines
#    where the value is located in the pipeline.config file. The optional 'default_value' is saved when nothing is
#    found at 'path'.
# The path consists of regular expression strings separated with a '/' character; each of them selects the dictionary
# key to descend into at the corresponding nesting level. The expressions are matched from the beginning of the key
# only (re.match), so a trailing '$' is required where a prefix match would be ambiguous (e.g. 'height$' must not match
# 'height_stride').
mapping_rules = [
    'num_classes',
    # preprocessing block attributes
    ('resizer_image_height', 'image_resizer/fixed_shape_resizer/height'),
    ('resizer_image_width', 'image_resizer/fixed_shape_resizer/width'),
    ('resizer_min_dimension', 'image_resizer/keep_aspect_ratio_resizer/min_dimension'),
    ('resizer_max_dimension', 'image_resizer/keep_aspect_ratio_resizer/max_dimension'),
    # anchor generator attributes
    ('anchor_generator_height', 'first_stage_anchor_generator/grid_anchor_generator/height$', 256),
    ('anchor_generator_width', 'first_stage_anchor_generator/grid_anchor_generator/width$', 256),
    ('anchor_generator_height_stride', 'first_stage_anchor_generator/grid_anchor_generator/height_stride', 16),
    ('anchor_generator_width_stride', 'first_stage_anchor_generator/grid_anchor_generator/width_stride', 16),
    ('anchor_generator_scales', 'first_stage_anchor_generator/grid_anchor_generator/scales'),
    ('anchor_generator_aspect_ratios', 'first_stage_anchor_generator/grid_anchor_generator/aspect_ratios'),
    ('multiscale_anchor_generator_min_level', 'anchor_generator/multiscale_anchor_generator/min_level'),
    ('multiscale_anchor_generator_max_level', 'anchor_generator/multiscale_anchor_generator/max_level'),
    ('multiscale_anchor_generator_anchor_scale', 'anchor_generator/multiscale_anchor_generator/anchor_scale'),
    ('multiscale_anchor_generator_aspect_ratios', 'anchor_generator/multiscale_anchor_generator/aspect_ratios'),
    ('multiscale_anchor_generator_scales_per_octave', 'anchor_generator/multiscale_anchor_generator/scales_per_octave'),
    # SSD anchor generator attributes
    ('ssd_anchor_generator_min_scale', 'anchor_generator/ssd_anchor_generator/min_scale'),
    ('ssd_anchor_generator_max_scale', 'anchor_generator/ssd_anchor_generator/max_scale'),
    ('ssd_anchor_generator_num_layers', 'anchor_generator/ssd_anchor_generator/num_layers'),
    ('ssd_anchor_generator_aspect_ratios', 'anchor_generator/ssd_anchor_generator/aspect_ratios'),
    ('ssd_anchor_generator_reduce_lowest', 'anchor_generator/ssd_anchor_generator/reduce_boxes_in_lowest_layer'),
    ('ssd_anchor_generator_base_anchor_height', 'anchor_generator/ssd_anchor_generator/base_anchor_height', 1.0),
    ('ssd_anchor_generator_base_anchor_width', 'anchor_generator/ssd_anchor_generator/base_anchor_width', 1.0),
    # Proposal and ROI Pooling layers attributes
    ('first_stage_nms_score_threshold', '.*_nms_score_threshold'),
    ('first_stage_nms_iou_threshold', '.*_nms_iou_threshold'),
    ('first_stage_max_proposals', '.*_max_proposals'),
    ('num_spatial_bins_height', '.*/rfcn_box_predictor/num_spatial_bins_height'),
    ('num_spatial_bins_width', '.*/rfcn_box_predictor/num_spatial_bins_width'),
    ('crop_height', '.*/rfcn_box_predictor/crop_height'),
    ('crop_width', '.*/rfcn_box_predictor/crop_width'),
    'initial_crop_size',
    # Detection Output layer attributes
    ('postprocessing_score_converter', '.*/score_converter'),
    ('postprocessing_score_threshold', '.*/batch_non_max_suppression/score_threshold'),
    ('postprocessing_iou_threshold', '.*/batch_non_max_suppression/iou_threshold'),
    ('postprocessing_max_detections_per_class', '.*/batch_non_max_suppression/max_detections_per_class'),
    ('postprocessing_max_total_detections', '.*/batch_non_max_suppression/max_total_detections'),
    # Variances for predicted bounding box deltas (tx, ty, tw, th)
    ('frcnn_variance_x', 'box_coder/faster_rcnn_box_coder/x_scale', 10.0),
    ('frcnn_variance_y', 'box_coder/faster_rcnn_box_coder/y_scale', 10.0),
    ('frcnn_variance_width', 'box_coder/faster_rcnn_box_coder/width_scale', 5.0),
    ('frcnn_variance_height', 'box_coder/faster_rcnn_box_coder/height_scale', 5.0),
]
class PipelineConfig:
    """
    Parser of the pipeline.config files describing TensorFlow models created with the Object Detection API.

    The file is parsed with SimpleProtoParser and the values selected by the module-level ``mapping_rules`` are stored
    in a plain dictionary for easier access using the get_param method.
    """

    def __init__(self, file_name: str):
        """
        Parse the pipeline.config file and extract the model parameters from it.

        :param file_name: path to the pipeline.config file.
        :raises Error: if the file cannot be parsed or has no "model" section.
        """
        # Per-instance dictionary: a class-level mutable default would be shared by (and leak parameters between)
        # all PipelineConfig objects, because it is only ever mutated through self.
        self._model_params = dict()
        self._raw_data_dict = SimpleProtoParser().parse_file(file_name)
        if not self._raw_data_dict:
            raise Error('Failed to parse pipeline.config file {}'.format(file_name))

        self._initialize_model_params()

    @staticmethod
    def _get_value_by_path(params: dict, path: list):
        """
        Find a value in a nested dictionary by following a list of regular expressions.

        Each element of ``path`` is a regular expression matched with ``re.match`` (i.e. from the beginning of the key
        only) against the keys of the current nesting level. At the last level the value of the first matching key is
        returned; at intermediate levels every matching key is searched until a non-None value is found.

        :param params: dictionary to search in; non-dict values cannot be descended into.
        :param path: list of regular expression strings, one per nesting level.
        :return: the found value or None if the path cannot be followed.
        """
        if not path or not isinstance(params, dict):
            return None
        compiled_regexp = re.compile(path[0])
        for key, sub_params in params.items():
            if not compiled_regexp.match(key):
                continue
            if len(path) == 1:
                return sub_params
            value = PipelineConfig._get_value_by_path(sub_params, path[1:])
            if value is not None:
                return value
        return None

    def _update_param_using_rule(self, params: dict, rule: Union[str, tuple]):
        """
        Apply a single mapping rule to the parsed attributes and store the result in self._model_params.

        :param params: parsed attributes dictionary to look the value up in.
        :param rule: either a string (the key looked up directly in params and also used to save the value), or a
            tuple (key, path) / (key, path, default_value) where path is a '/'-separated list of regular expressions
            and default_value is saved when nothing is found at path.
        :raises Error: if the rule is neither a string nor a tuple of two or three elements.
        """
        if isinstance(rule, str):
            if rule in params:
                self._model_params[rule] = params[rule]
                log.debug('Found value "{}" for path "{}"'.format(params[rule], rule))
        elif isinstance(rule, tuple):
            if len(rule) not in (2, 3):
                raise Error('Invalid rule length. Rule must be a tuple with two elements: key and path, or three '
                            'elements: key, path, default_value.')
            key, path = rule[0], rule[1]
            value = PipelineConfig._get_value_by_path(params, path.split('/'))
            if value is not None:
                log.debug('Found value "{}" for path "{}"'.format(value, path))
                self._model_params[key] = value
            elif len(rule) == 3:
                self._model_params[key] = rule[2]
                # Report the searched path: the found value is always None in this branch.
                log.debug('There is no value for path "{}". Set default value "{}"'.format(path, rule[2]))
        else:
            raise Error('Invalid rule type. Rule can be either string or tuple')

    def _initialize_model_params(self):
        """
        Store global params in the dedicated dictionary self._model_params for easier use.

        Only the first entry of the "model" section is searched by the rules. NOTE(review): presumably this is the
        meta-architecture block (e.g. faster_rcnn / ssd) — confirm against the configs in use.

        :raises Error: if the "model" section is missing, empty or is not a dictionary.
        :return: None
        """
        if 'model' not in self._raw_data_dict:
            raise Error('The "model" key is not found in the configuration file. Looks like the parsed file is not '
                        'Object Detection API model configuration file.')
        model_section = self._raw_data_dict['model']
        if not isinstance(model_section, dict) or not model_section:
            raise Error('The "model" section of the configuration file is empty or malformed.')
        params = next(iter(model_section.values()))
        for rule in mapping_rules:
            self._update_param_using_rule(params, rule)

    def get_param(self, param: str):
        """
        Return the value of a model parameter extracted from the pipeline.config file.

        :param param: name of the parameter (the key of the corresponding rule in mapping_rules).
        :return: the parameter value, or None if it was not found in the file and its rule has no default value.
        """
        return self._model_params.get(param)