Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / pipeline_config.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17 import re
18
19 from mo.utils.error import Error
20 from mo.utils.simple_proto_parser import SimpleProtoParser
21
22
23 # The list of rules how to map the value from the pipeline.config file to the dictionary with attributes.
24 # The rule is either a string or a tuple with two elements. In the first case the rule string is used as a key to
25 # search in the parsed pipeline.config file attributes dictionary and a key to save found value. In the second case the
26 # first element of the tuple is the key to save found value; the second element of the tuple is a string defining the
27 # path to the value of the attribute in the pipeline.config file. The path consists of the regular expression strings
28 # defining the dictionary key to look for separated with a '/' character.
29 mapping_rules = [
30     'num_classes',
31     # preprocessing block attributes
32     ('resizer_image_height', 'image_resizer/fixed_shape_resizer/height'),
33     ('resizer_image_width', 'image_resizer/fixed_shape_resizer/width'),
34     ('resizer_min_dimension', 'image_resizer/keep_aspect_ratio_resizer/min_dimension'),
35     ('resizer_max_dimension', 'image_resizer/keep_aspect_ratio_resizer/max_dimension'),
36     # anchor generator attributes
37     ('anchor_generator_height', 'first_stage_anchor_generator/grid_anchor_generator/height$', 256),
38     ('anchor_generator_width', 'first_stage_anchor_generator/grid_anchor_generator/width$', 256),
39     ('anchor_generator_height_stride', 'first_stage_anchor_generator/grid_anchor_generator/height_stride', 16),
40     ('anchor_generator_width_stride', 'first_stage_anchor_generator/grid_anchor_generator/width_stride', 16),
41     ('anchor_generator_scales', 'first_stage_anchor_generator/grid_anchor_generator/scales'),
42     ('anchor_generator_aspect_ratios', 'first_stage_anchor_generator/grid_anchor_generator/aspect_ratios'),
43     ('multiscale_anchor_generator_min_level', 'anchor_generator/multiscale_anchor_generator/min_level'),
44     ('multiscale_anchor_generator_max_level', 'anchor_generator/multiscale_anchor_generator/max_level'),
45     ('multiscale_anchor_generator_anchor_scale', 'anchor_generator/multiscale_anchor_generator/anchor_scale'),
46     ('multiscale_anchor_generator_aspect_ratios', 'anchor_generator/multiscale_anchor_generator/aspect_ratios'),
47     ('multiscale_anchor_generator_scales_per_octave', 'anchor_generator/multiscale_anchor_generator/scales_per_octave'),
48     # SSD anchor generator attributes
49     ('ssd_anchor_generator_min_scale', 'anchor_generator/ssd_anchor_generator/min_scale'),
50     ('ssd_anchor_generator_max_scale', 'anchor_generator/ssd_anchor_generator/max_scale'),
51     ('ssd_anchor_generator_num_layers', 'anchor_generator/ssd_anchor_generator/num_layers'),
52     ('ssd_anchor_generator_aspect_ratios', 'anchor_generator/ssd_anchor_generator/aspect_ratios'),
53     ('ssd_anchor_generator_reduce_lowest', 'anchor_generator/ssd_anchor_generator/reduce_boxes_in_lowest_layer'),
54     ('ssd_anchor_generator_base_anchor_height', 'anchor_generator/ssd_anchor_generator/base_anchor_height', 1.0),
55     ('ssd_anchor_generator_base_anchor_width', 'anchor_generator/ssd_anchor_generator/base_anchor_width', 1.0),
56     # Proposal and ROI Pooling layers attributes
57     ('first_stage_nms_score_threshold', '.*_nms_score_threshold'),
58     ('first_stage_nms_iou_threshold', '.*_nms_iou_threshold'),
59     ('first_stage_max_proposals', '.*_max_proposals'),
60     ('num_spatial_bins_height', '.*/rfcn_box_predictor/num_spatial_bins_height'),
61     ('num_spatial_bins_width', '.*/rfcn_box_predictor/num_spatial_bins_width'),
62     ('crop_height', '.*/rfcn_box_predictor/crop_height'),
63     ('crop_width', '.*/rfcn_box_predictor/crop_width'),
64     'initial_crop_size',
65     # Detection Output layer attributes
66     ('postprocessing_score_converter', '.*/score_converter'),
67     ('postprocessing_score_threshold', '.*/batch_non_max_suppression/score_threshold'),
68     ('postprocessing_iou_threshold', '.*/batch_non_max_suppression/iou_threshold'),
69     ('postprocessing_max_detections_per_class', '.*/batch_non_max_suppression/max_detections_per_class'),
70     ('postprocessing_max_total_detections', '.*/batch_non_max_suppression/max_total_detections'),
71     # Variances for predicted bounding box deltas (tx, ty, tw, th)
72     ('frcnn_variance_x', 'box_coder/faster_rcnn_box_coder/x_scale', 10.0),
73     ('frcnn_variance_y', 'box_coder/faster_rcnn_box_coder/y_scale', 10.0),
74     ('frcnn_variance_width', 'box_coder/faster_rcnn_box_coder/width_scale', 5.0),
75     ('frcnn_variance_height', 'box_coder/faster_rcnn_box_coder/height_scale', 5.0)
76 ]
77
78
79 class PipelineConfig:
80     """
81     The class that parses pipeline.config files used to generate TF models generated using Object Detection API.
82     The class stores data read from the file in a plain dictionary for easier access using the get_param function.
83     """
84     _raw_data_dict = dict()
85     _model_params = dict()
86
87     def __init__(self, file_name: str):
88         self._raw_data_dict = SimpleProtoParser().parse_file(file_name)
89         if not self._raw_data_dict:
90             raise Error('Failed to parse pipeline.config file {}'.format(file_name))
91
92         self._initialize_model_params()
93
94     @staticmethod
95     def _get_value_by_path(params: dict, path: list):
96         if not path or len(path) == 0:
97             return None
98         if not isinstance(params, dict):
99             return None
100         compiled_regexp = re.compile(path[0])
101         for key in params.keys():
102             if re.match(compiled_regexp, key):
103                 if len(path) == 1:
104                     return params[key]
105                 else:
106                     value = __class__._get_value_by_path(params[key], path[1:])
107                     if value is not None:
108                         return value
109         return None
110
111     def _update_param_using_rule(self, params: dict, rule: [str, tuple]):
112         if isinstance(rule, str):
113             if rule in params:
114                 self._model_params[rule] = params[rule]
115                 log.debug('Found value "{}" for path "{}"'.format(params[rule], rule))
116         elif isinstance(rule, tuple):
117             if len(rule) != 2 and len(rule) != 3:
118                 raise Error('Invalid rule length. Rule must be a tuple with two elements: key and path, or three '
119                             'elements: key, path, default_value.')
120             value = __class__._get_value_by_path(params, rule[1].split('/'))
121             if value is not None:
122                 log.debug('Found value "{}" for path "{}"'.format(value, rule[1]))
123                 self._model_params[rule[0]] = value
124             elif len(rule) == 3:
125                 self._model_params[rule[0]] = rule[2]
126                 log.debug('There is no value path "{}". Set default value "{}"'.format(value, rule[2]))
127
128         else:
129             raise Error('Invalid rule type. Rule can be either string or tuple')
130
131     def _initialize_model_params(self):
132         """
133         Store global params in the dedicated dictionary self._model_params for easier use.
134         :return: None
135         """
136
137         if 'model' not in self._raw_data_dict:
138             raise Error('The "model" key is not found in the configuration file. Looks like the parsed file is not '
139                         'Object Detection API model configuration file.')
140         params = list(self._raw_data_dict['model'].values())[0]
141         for rule in mapping_rules:
142             self._update_param_using_rule(params, rule)
143
144     def get_param(self, param: str):
145         if param not in self._model_params:
146             return None
147         return self._model_params[param]