"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import argparse
import os
import re
import sys
from collections import OrderedDict
from itertools import zip_longest

import numpy as np

from mo.front.extractor import split_node_in_port
from mo.utils import import_extensions
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg


class CanonicalizePathAction(argparse.Action):
    """
    Expand user home directory paths and convert relative-paths to absolute.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if values is not None:
            list_of_values = list()
            if isinstance(values, str):
                if values != "":
                    list_of_values = values.split(',')
            elif isinstance(values, list):
                list_of_values = values
            else:
                raise Error('Unsupported type of command line parameter "{}" value'.format(self.dest))
            list_of_values = [get_absolute_path(path) for path in list_of_values]
            setattr(namespace, self.dest, ','.join(list_of_values))
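
# A hedged usage sketch (the flag name and paths are illustrative, not part of
# this module): registering the action on an argument canonicalizes every
# comma-separated path it receives, e.g.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dirs', action=CanonicalizePathAction)
#     args = parser.parse_args(['--dirs', '~/ext,models'])
#     # args.dirs == '/home/user/ext,/abs/path/to/cwd/models'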


class CanonicalizePathCheckExistenceAction(CanonicalizePathAction):
    """
    Expand user home directory paths and convert relative-paths to absolute, and check that the specified file or
    directory exists.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        super().__call__(parser, namespace, values, option_string)
        names = getattr(namespace, self.dest)
        for name in names.split(','):
            if name != "" and not os.path.exists(name):
                raise Error('The value for command line parameter "{}" must be an existing file/directory, '
                            'but "{}" does not exist.'.format(self.dest, name))


def readable_file(path: str):
    """
    Check that the specified path is a readable file.
    :param path: path to check
    :return: path if the file is readable
    """
    if not os.path.isfile(path):
        raise Error('The "{}" is not an existing file'.format(path))
    elif not os.access(path, os.R_OK):
        raise Error('The "{}" is not readable'.format(path))
    else:
        return path


def readable_dirs(paths: str):
    """
    Checks that a comma-separated list of paths contains readable directories.
    :param paths: comma-separated list of paths.
    :return: comma-separated list of paths.
    """
    paths_list = [readable_dir(path) for path in paths.split(',')]
    return ','.join(paths_list)


def readable_dirs_or_empty(paths: str):
    """
    Checks that a comma-separated list of paths contains readable directories, or that it is empty.
    :param paths: comma-separated list of paths.
    :return: comma-separated list of paths.
    """
    if paths:
        return readable_dirs(paths)
    return paths


def readable_dir(path: str):
    """
    Check that the specified path is a readable directory.
    :param path: path to check
    :return: path if the directory is readable
    """
    if not os.path.isdir(path):
        raise Error('The "{}" is not an existing directory'.format(path))
    elif not os.access(path, os.R_OK):
        raise Error('The "{}" is not readable'.format(path))
    else:
        return path


def writable_dir(path: str):
    """
    Checks that the specified directory is writable. The directory may not exist, but its parent or grandparent
    must exist.
    :param path: path to check that it is writable.
    :return: path if it is writable
    """
    if path is None:
        raise Error('The directory parameter is None')
    if os.path.exists(path):
        if os.path.isdir(path):
            if os.access(path, os.W_OK):
                return path
            else:
                raise Error('The directory "{}" is not writable'.format(path))
        else:
            raise Error('The "{}" is not a directory'.format(path))
    else:
        cur_path = path
        while os.path.dirname(cur_path) != cur_path:
            if os.path.exists(cur_path):
                break
            cur_path = os.path.dirname(cur_path)
        if cur_path == '':
            cur_path = os.path.curdir
        if os.access(cur_path, os.W_OK):
            return path
        else:
            raise Error('The directory "{}" is not writable'.format(cur_path))
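
# Illustration of the parent walk above (paths are hypothetical): for a
# not-yet-existing '/data/out/ir', the loop climbs to the closest existing
# ancestor ('/data/out', then '/data', ...) and accepts the path if that
# ancestor is writable:
#
#     writable_dir('/data/out/ir')  # -> '/data/out/ir' if e.g. '/data' is writable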


def get_common_cli_parser(parser: argparse.ArgumentParser = None):
    if not parser:
        parser = argparse.ArgumentParser()
    common_group = parser.add_argument_group('Framework-agnostic parameters')
    common_group.add_argument('--input_model', '-w', '-m',
                              help='TensorFlow*: a file with a pre-trained model ' +
                                   '(binary or text .pb file after freezing).\n' +
                                   'Caffe*: a model proto file with model weights',
                              action=CanonicalizePathCheckExistenceAction,
                              type=readable_file)
    common_group.add_argument('--model_name', '-n',
                              help='Model_name parameter passed to the final create_ir transform. ' +
                                   'This parameter is used to name ' +
                                   'a network in a generated IR and output .xml/.bin files.')
    common_group.add_argument('--output_dir', '-o',
                              help='Directory that stores the generated IR. ' +
                                   'By default, it is the directory from where the Model Optimizer is launched.',
                              default=get_absolute_path('.'),
                              action=CanonicalizePathAction,
                              type=writable_dir)
    common_group.add_argument('--input_shape',
                              help='Input shape(s) that should be fed to an input node(s) of the model. '
                                   'Shape is defined as a comma-separated list of integer numbers enclosed in '
                                   'parentheses or square brackets, for example [1,3,227,227] or (1,227,227,3), where '
                                   'the order of dimensions depends on the framework input layout of the model. '
                                   'For example, [N,C,H,W] is used for Caffe* models and [N,H,W,C] for TensorFlow* '
                                   'models. Model Optimizer performs necessary transformations to convert the shape to '
                                   'the layout required by Inference Engine (N,C,H,W). The shape should not contain '
                                   'undefined dimensions (? or -1) and should fit the dimensions defined in the input '
                                   'operation of the graph. If there are multiple inputs in the model, --input_shape '
                                   'should contain definition of shape for each input separated by a comma, for '
                                   'example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes.')
    common_group.add_argument('--scale', '-s',
                              type=float,
                              help='All input values coming from original network inputs will be ' +
                                   'divided by this ' +
                                   'value. When a list of inputs is overridden by the --input ' +
                                   'parameter, this scale ' +
                                   'is not applied for any input that does not match with ' +
                                   'the original input of the model.')
    common_group.add_argument('--reverse_input_channels',
                              help='Switch the input channels order from RGB to BGR (or vice versa). Applied to '
                                   'original inputs of the model if and only if a number of channels equals 3. Applied '
                                   'after application of --mean_values and --scale_values options, so numbers in '
                                   '--mean_values and --scale_values go in the order of channels used in the original '
                                   'model.',
                              action='store_true')
    common_group.add_argument('--log_level',
                              help='Logger level',
                              choices=['CRITICAL', 'ERROR', 'WARN', 'WARNING', 'INFO',
                                       'DEBUG', 'NOTSET'],
                              default='ERROR')
    common_group.add_argument('--input',
                              help='The name of the input operation of the given model. ' +
                                   'Usually this is a name of the ' +
                                   'input placeholder of the model.')
    common_group.add_argument('--output',
                              help='The name of the output operation of the model. ' +
                                   'For TensorFlow*, do not add :0 to this name.')
    common_group.add_argument('--mean_values', '-ms',
                              help='Mean values to be used for the input image per channel. ' +
                                   'Values to be provided in the (R,G,B) or [R,G,B] format. ' +
                                   'Can be defined for desired input of the model, for example: ' +
                                   '"--mean_values data[255,255,255],info[255,255,255]". ' +
                                   'The exact meaning and order ' +
                                   'of channels depend on how the original model was trained.',
                              default=())
    common_group.add_argument('--scale_values',
                              help='Scale values to be used for the input image per channel. ' +
                                   'Values are provided in the (R,G,B) or [R,G,B] format. ' +
                                   'Can be defined for desired input of the model, for example: ' +
                                   '"--scale_values data[255,255,255],info[255,255,255]". ' +
                                   'The exact meaning and order ' +
                                   'of channels depend on how the original model was trained.',
                              default=())
    # TODO: isn't it a weights precision type
    common_group.add_argument('--data_type',
                              help='Data type for all intermediate tensors and weights. ' +
                                   'If original model is in FP32 and --data_type=FP16 is specified, all model weights ' +
                                   'and biases are quantized to FP16.',
                              choices=["FP16", "FP32", "half", "float"],
                              default='float')
    common_group.add_argument('--disable_fusing',
                              help='Turn off fusing of linear operations to Convolution',
                              action='store_true')
    common_group.add_argument('--disable_resnet_optimization',
                              help='Turn off resnet optimization',
                              action='store_true')
    common_group.add_argument('--finegrain_fusing',
                              help='Regex for layers/operations that won\'t be fused. ' +
                                   'Example: --finegrain_fusing Convolution1,.*Scale.*')
    common_group.add_argument('--disable_gfusing',
                              help='Turn off fusing of grouped convolutions',
                              action='store_true')
    common_group.add_argument('--enable_concat_optimization',
                              help='Turn on concat optimization',
                              action='store_true')
    common_group.add_argument('--move_to_preprocess',
                              help='Move mean values to IR preprocess section',
                              action='store_true')
    # we use CanonicalizePathCheckExistenceAction instead of readable_dirs to handle empty strings
    common_group.add_argument("--extensions",
                              help="Directory or a comma separated list of directories with extensions. To disable all "
                                   "extensions including those that are placed at the default location, pass an empty "
                                   "string.",
                              default=import_extensions.default_path(),
                              action=CanonicalizePathCheckExistenceAction,
                              type=readable_dirs_or_empty)
    common_group.add_argument("--batch", "-b",
                              type=check_positive,
                              default=None,
                              help="Input batch size")
    common_group.add_argument("--version",
                              action='store_true',
                              help="Version of Model Optimizer")
    common_group.add_argument('--silent',
                              help='Prevent any output messages except those that correspond to log level ERROR, '
                                   'which can be set with the --log_level option. '
                                   'By default, the log level is already ERROR.',
                              action='store_true',
                              default=False)
    common_group.add_argument('--freeze_placeholder_with_value',
                              help='Replaces input layer with constant node with '
                                   'provided value, e.g.: "node_name->True"',
                              default=None)
    common_group.add_argument('--generate_deprecated_IR_V2',
                              help='Force to generate legacy/deprecated IR V2 to work with previous versions of the'
                                   ' Inference Engine. The resulting IR may or may not be correctly loaded by'
                                   ' Inference Engine API (including the most recent and old versions of Inference'
                                   ' Engine) and provided as a partially-validated backup option for specific'
                                   ' deployment scenarios. Use it at your own discretion. By default, without this'
                                   ' option, the Model Optimizer generates IR V3.',
                              action='store_true')
    common_group.add_argument('--keep_shape_ops',
                              help='[ Experimental feature ] Enables keeping the `Shape` operation with all its children. '
                                   'This feature makes the model reshapable in the Inference Engine',
                              action='store_true', default=False)
    return parser
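
# A minimal usage sketch (the file name is illustrative and must exist, since
# --input_model is checked by CanonicalizePathCheckExistenceAction):
#
#     parser = get_common_cli_parser()
#     argv = parser.parse_args(['--input_model', 'squeezenet.pb', '--data_type', 'FP16'])
#     # argv.input_model holds the canonicalized absolute path;
#     # argv.output_dir defaults to the absolute path of the launch directory.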


def get_common_cli_options(model_name):
    d = OrderedDict()
    d['input_model'] = '- Path to the Input Model'
    d['output_dir'] = ['- Path for generated IR', lambda x: x if x != '.' else os.getcwd()]
    d['model_name'] = ['- IR output name', lambda x: x if x else model_name]
    d['log_level'] = '- Log level'
    d['batch'] = ['- Batch', lambda x: x if x else 'Not specified, inherited from the model']
    d['input'] = ['- Input layers', lambda x: x if x else 'Not specified, inherited from the model']
    d['output'] = ['- Output layers', lambda x: x if x else 'Not specified, inherited from the model']
    d['input_shape'] = ['- Input shapes', lambda x: x if x else 'Not specified, inherited from the model']
    d['mean_values'] = ['- Mean values', lambda x: x if x else 'Not specified']
    d['scale_values'] = ['- Scale values', lambda x: x if x else 'Not specified']
    d['scale'] = ['- Scale factor', lambda x: x if x else 'Not specified']
    d['data_type'] = ['- Precision of IR', lambda x: 'FP32' if x == 'float' else 'FP16' if x == 'half' else x]
    d['disable_fusing'] = ['- Enable fusing', lambda x: not x]
    d['disable_gfusing'] = ['- Enable grouped convolutions fusing', lambda x: not x]
    d['move_to_preprocess'] = '- Move mean values to preprocess section'
    d['reverse_input_channels'] = '- Reverse input channels'
    return d


def get_caffe_cli_options():
    d = {
        'input_proto': ['- Path to the Input prototxt', lambda x: x],
        'mean_file': ['- Path to a mean file', lambda x: x if x else 'Not specified'],
        'mean_file_offsets': ['- Offsets for a mean file', lambda x: x if x else 'Not specified'],
        'k': '- Path to CustomLayersMapping.xml',
        'disable_resnet_optimization': ['- Enable resnet optimization', lambda x: not x],
    }

    return OrderedDict(sorted(d.items(), key=lambda t: t[0]))


def get_tf_cli_options():
    d = {
        'input_model_is_text': '- Input model in text protobuf format',
        'tensorflow_subgraph_patterns': '- Patterns to offload',
        'tensorflow_operation_patterns': '- Operations to offload',
        'tensorflow_custom_operations_config_update': '- Update the configuration file with input/output node names',
        'tensorflow_use_custom_operations_config': '- Use the config file',
        'tensorflow_object_detection_api_pipeline_config': '- Use configuration file used to generate the model with '
                                                           'Object Detection API',
        'tensorflow_custom_layer_libraries': '- List of shared libraries with TensorFlow custom layers implementation',
        'tensorboard_logdir': '- Path to model dump for TensorBoard'
    }

    return OrderedDict(sorted(d.items(), key=lambda t: t[0]))


def get_mxnet_cli_options():
    d = {
        'input_symbol': '- Deploy-ready symbol file',
        'nd_prefix_name': '- Prefix name for args.nd and argx.nd files',
        'pretrained_model_name': '- Pretrained model to be merged with the .nd files',
        'save_params_from_nd': '- Enable saving built parameters file from .nd files',
        'legacy_mxnet_model': '- Enable MXNet loader for models trained with MXNet version lower than 1.0.0'
    }

    return OrderedDict(sorted(d.items(), key=lambda t: t[0]))


def get_kaldi_cli_options():
    d = {
        'counts': '- A file name with full path to the counts file',
        'remove_output_softmax': '- Removes the SoftMax layer that is the output layer'
    }

    return OrderedDict(sorted(d.items(), key=lambda t: t[0]))


def get_onnx_cli_options():
    d = {}

    return OrderedDict(sorted(d.items(), key=lambda t: t[0]))


def get_caffe_cli_parser(parser: argparse.ArgumentParser = None):
    """
    Specifies cli arguments for Model Optimizer for Caffe*

    Returns
    -------
        ArgumentParser instance
    """
    if not parser:
        parser = argparse.ArgumentParser()
        get_common_cli_parser(parser=parser)

    caffe_group = parser.add_argument_group('Caffe*-specific parameters')

    caffe_group.add_argument('--input_proto', '-d',
                             help='Deploy-ready prototxt file that contains a topology structure ' +
                                  'and layer attributes',
                             type=str,
                             action=CanonicalizePathCheckExistenceAction)
    caffe_group.add_argument('-k',
                             help='Path to CustomLayersMapping.xml to register custom layers',
                             type=str,
                             default=os.path.join(os.path.dirname(sys.argv[0]), 'extensions', 'front', 'caffe',
                                                  'CustomLayersMapping.xml'),
                             action=CanonicalizePathCheckExistenceAction)
    caffe_group.add_argument('--mean_file', '-mf',
                             help='Mean image to be used for the input. Should be a binaryproto file',
                             default=None,
                             action=CanonicalizePathCheckExistenceAction)
    caffe_group.add_argument('--mean_file_offsets', '-mo',
                             help='Mean image offsets to be used for the input binaryproto file. ' +
                                  'When the mean image is bigger than the expected input, it is cropped. By default, centers ' +
                                  'of the input image and the mean image are the same and the mean image is cropped by ' +
                                  'dimensions of the input image. The format to pass this option is the following: "-mo (x,y)". In this ' +
                                  'case, the mean file is cropped by dimensions of the input image with offset (x,y) ' +
                                  'from the upper left corner of the mean image',
                             default=None)
    caffe_group.add_argument('--disable_omitting_optional',
                             help='Disable omitting optional attributes to be used for custom layers. ' +
                                  'Use this option if you want to transfer all attributes of a custom layer to IR. ' +
                                  'Default behavior is to transfer the attributes with default values and the attributes defined by the user to IR.',
                             action='store_true',
                             default=False)
    caffe_group.add_argument('--enable_flattening_nested_params',
                             help='Enable flattening optional params to be used for custom layers. ' +
                                  'Use this option if you want to transfer attributes of a custom layer to IR with flattened nested parameters. ' +
                                  'Default behavior is to transfer the attributes without flattening nested parameters.',
                             action='store_true',
                             default=False)
    return parser


def get_tf_cli_parser(parser: argparse.ArgumentParser = None):
    """
    Specifies cli arguments for Model Optimizer for TF

    Returns
    -------
        ArgumentParser instance
    """
    if not parser:
        parser = argparse.ArgumentParser()
        get_common_cli_parser(parser=parser)

    tf_group = parser.add_argument_group('TensorFlow*-specific parameters')

    tf_group.add_argument('--input_model_is_text',
                          help='TensorFlow*: treat the input model file as a text protobuf format. If not specified, ' +
                               'the Model Optimizer treats it as a binary file by default.',
                          action='store_true')
    tf_group.add_argument('--input_checkpoint', type=str, default=None, help="TensorFlow*: variables file to load.",
                          action=CanonicalizePathCheckExistenceAction)
    tf_group.add_argument('--input_meta_graph',
                          help='TensorFlow*: a file with a meta-graph of the model before freezing',
                          action=CanonicalizePathCheckExistenceAction,
                          type=readable_file)
    tf_group.add_argument('--saved_model_dir', default=None,
                          help="TensorFlow*: directory representing a non-frozen model",
                          action=CanonicalizePathCheckExistenceAction,
                          type=readable_dirs)
    tf_group.add_argument('--saved_model_tags', type=str, default=None,
                          help="Group of tag(s) of the MetaGraphDef to load, in string format, separated by ','. "
                               "If a tag-set contains multiple tags, all tags must be passed in.")
    tf_group.add_argument('--tensorflow_subgraph_patterns',
                          help='TensorFlow*: a list of comma separated patterns that will be applied to ' +
                               'TensorFlow* node names to ' +
                               'infer a part of the graph using TensorFlow*.')
    tf_group.add_argument('--tensorflow_operation_patterns',
                          help='TensorFlow*: a list of comma separated patterns that will be applied to ' +
                               'TensorFlow* node type (ops) ' +
                               'to infer these operations using TensorFlow*.')
    tf_group.add_argument('--tensorflow_custom_operations_config_update',
                          help='TensorFlow*: update the configuration file with node name patterns with input/output '
                               'nodes information.',
                          action=CanonicalizePathCheckExistenceAction)
    tf_group.add_argument('--tensorflow_use_custom_operations_config',
                          help='TensorFlow*: use the configuration file with custom operation description.',
                          action=CanonicalizePathCheckExistenceAction)
    tf_group.add_argument('--tensorflow_object_detection_api_pipeline_config',
                          help='TensorFlow*: path to the pipeline configuration file used to generate model created '
                               'with help of Object Detection API.',
                          action=CanonicalizePathCheckExistenceAction)
    tf_group.add_argument('--tensorboard_logdir',
                          help='TensorFlow*: dump the input graph to a given directory that should be used with TensorBoard.',
                          default=None,
                          action=CanonicalizePathCheckExistenceAction)
    tf_group.add_argument('--tensorflow_custom_layer_libraries',
                          help='TensorFlow*: comma separated list of shared libraries with TensorFlow* custom '
                               'operations implementation.',
                          default=None,
                          action=CanonicalizePathCheckExistenceAction)
    tf_group.add_argument('--disable_nhwc_to_nchw',
                          help='Disables default translation from NHWC to NCHW',
                          action='store_true')
    return parser


def get_mxnet_cli_parser(parser: argparse.ArgumentParser = None):
    """
    Specifies cli arguments for Model Optimizer for MXNet*

    Returns
    -------
        ArgumentParser instance
    """
    if not parser:
        parser = argparse.ArgumentParser()
        get_common_cli_parser(parser=parser)

    mx_group = parser.add_argument_group('MXNet-specific parameters')

    mx_group.add_argument('--input_symbol',
                          help='Symbol file (for example, model-symbol.json) that contains a topology structure ' +
                               'and layer attributes',
                          type=str,
                          action=CanonicalizePathCheckExistenceAction)
    mx_group.add_argument("--nd_prefix_name",
                          help="Prefix name for args.nd and argx.nd files.",
                          default=None)
    mx_group.add_argument("--pretrained_model_name",
                          help="Name of a pretrained MXNet model without extension and epoch number. This model will be merged with args.nd and argx.nd files",
                          default=None)
    mx_group.add_argument("--save_params_from_nd",
                          action='store_true',
                          help="Enable saving built parameters file from .nd files")
    mx_group.add_argument("--legacy_mxnet_model",
                          action='store_true',
                          help="Enable MXNet loader to make a model compatible with the latest MXNet version. Use only if your model was trained with MXNet version lower than 1.0.0")
    return parser


def get_kaldi_cli_parser(parser: argparse.ArgumentParser = None):
    """
    Specifies cli arguments for Model Optimizer for Kaldi*

    Returns
    -------
        ArgumentParser instance
    """
    if not parser:
        parser = argparse.ArgumentParser()
        get_common_cli_parser(parser=parser)

    kaldi_group = parser.add_argument_group('Kaldi-specific parameters')

    kaldi_group.add_argument("--counts",
                             help="Path to the counts file",
                             default=None,
                             action=CanonicalizePathCheckExistenceAction)
    kaldi_group.add_argument("--remove_output_softmax",
                             help="Removes the SoftMax layer that is the output layer",
                             action='store_true',
                             default=False)
    return parser


def get_onnx_cli_parser(parser: argparse.ArgumentParser = None):
    """
    Specifies cli arguments for Model Optimizer for ONNX

    Returns
    -------
        ArgumentParser instance
    """
    if not parser:
        parser = argparse.ArgumentParser()
        get_common_cli_parser(parser=parser)

    onnx_group = parser.add_argument_group('ONNX*-specific parameters')

    return parser


def get_all_cli_parser():
    """
    Specifies cli arguments for Model Optimizer

    Returns
    -------
        ArgumentParser instance
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--framework',
                        help='Name of the framework used to train the input model.',
                        type=str,
                        choices=['tf', 'caffe', 'mxnet', 'kaldi', 'onnx'])

    get_common_cli_parser(parser=parser)

    get_tf_cli_parser(parser=parser)
    get_caffe_cli_parser(parser=parser)
    get_mxnet_cli_parser(parser=parser)
    get_kaldi_cli_parser(parser=parser)
    get_onnx_cli_parser(parser=parser)

    return parser
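
# Hedged example: the combined parser accepts the common and every
# framework-specific flag at once (the model file is assumed to exist):
#
#     argv = get_all_cli_parser().parse_args(['--framework', 'onnx',
#                                             '--input_model', 'model.onnx'])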


def get_placeholder_shapes(argv_input: str, argv_input_shape: str, argv_batch=None):
    """
    Parses input layers names and input shapes from the cli and returns the parsed object

    Parameters
    ----------
    argv_input
        string with a list of input layers: either an empty string, or strings separated with comma.
    argv_input_shape
        string with a list of input shapes: either an empty string, or tuples separated with comma.
        Only positive integers are accepted, except -1, which can be on any position in a shape.
    argv_batch
        integer that overrides batch size in input shape

    Returns
    -------
        parsed shapes in form of {'name of input': ndarray} if names of inputs are provided with shapes
        parsed shapes in form of {'name of input': None} if names of inputs are provided without shapes
        ndarray if only one shape is provided and no input name
        None if neither shape nor input were provided
    """
    if argv_input_shape and argv_batch:
        raise Error("Both --input_shape and --batch were provided. Please provide only one of them. " +
                    refer_to_faq_msg(56))
    shapes = ()
    inputs = ()
    placeholder_shapes = None

    first_digit_reg = r'([0-9 ]+|-1)'
    next_digits_reg = r'(,{})*'.format(first_digit_reg)
    tuple_reg = r'((\({}{}\))|(\[{}{}\]))'.format(first_digit_reg, next_digits_reg,
                                                  first_digit_reg, next_digits_reg)
    if argv_input_shape:
        full_reg = r'^{}(\s*,\s*{})*$|^$'.format(tuple_reg, tuple_reg)
        if not re.match(full_reg, argv_input_shape):
            raise Error('Input shape "{}" cannot be parsed. ' +
                        refer_to_faq_msg(57), argv_input_shape)
        shapes = re.findall(r'[(\[]([0-9, -]+)[)\]]', argv_input_shape)

    if argv_input:
        inputs = argv_input.split(',')

    # check number of shapes with no input provided
    if argv_input_shape and not argv_input:
        if len(shapes) > 1:
            raise Error('Please provide input layer names for input layer shapes. ' +
                        refer_to_faq_msg(58))
        placeholder_shapes = np.fromstring(shapes[0], dtype=np.int64, sep=',')
    # check if number of shapes does not match number of passed inputs
    elif argv_input and (len(shapes) == len(inputs) or len(shapes) == 0):
        placeholder_shapes = dict(zip_longest(inputs,
                                              map(lambda x: np.fromstring(x, dtype=np.int64,
                                                                          sep=',') if x else None, shapes)))
    else:
        raise Error('Please provide each input layer with an input layer shape. ' +
                    refer_to_faq_msg(58))

    return placeholder_shapes
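
# Worked examples of the return forms described in the docstring (names and
# shapes are illustrative):
#
#     get_placeholder_shapes('data,info', '[1,3,227,227],[1,4]')
#     # -> {'data': np.array([1, 3, 227, 227]), 'info': np.array([1, 4])}
#
#     get_placeholder_shapes('', '(1,3,227,227)')
#     # -> np.array([1, 3, 227, 227])
#
#     get_placeholder_shapes('data', '')
#     # -> {'data': None}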


def parse_tuple_pairs(argv_values: str):
    """
    Gets mean/scale values from the given string parameter

    Parameters
    ----------
    argv_values
        string with a specified input name and list of mean values: either an empty string, or a tuple
        in a form [] or ().
        E.g. 'data(1,2,3)' means 1 for the RED channel, 2 for the GREEN channel, 3 for the BLUE channel for the data
        input layer, or tuple of values in a form [] or () if input is specified separately, e.g. (1,2,3),[4,5,6].

    Returns
    -------
        dictionary with input name and tuple of values, or list of values if mean/scale value is specified without
        input, e.g.:
        "data(10,20,30),info(11,22,33)" -> { 'data': [10,20,30], 'info': [11,22,33] }
        "(10,20,30),(11,22,33)" -> [np.array(10,20,30), np.array(11,22,33)]
    """
    res = {}
    if not argv_values:
        return res

    data_str = argv_values
    while True:
        tuples_matches = re.findall(r'[(\[]([0-9., -]+)[)\]]', data_str, re.IGNORECASE)
        if not tuples_matches:
            raise Error(
                "Mean/scale values should be in format: data(1,2,3),info(2,3,4)" +
                " or just plain set of them without naming any inputs: (1,2,3),(2,3,4). " +
                refer_to_faq_msg(101), argv_values)
        tuple_value = tuples_matches[0]
        matches = data_str.split(tuple_value)

        input_name = matches[0][:-1]
        if not input_name:
            res = []
            # check that other values are specified w/o names
            words_reg = r'([a-zA-Z]+)'
            for i in range(0, len(matches)):
                if re.search(words_reg, matches[i]) is not None:
                    # error - tuple with name is also specified
                    raise Error(
                        "Mean/scale values should either contain names of input layers: data(1,2,3),info(2,3,4)" +
                        " or just plain set of them without naming any inputs: (1,2,3),(2,3,4)." +
                        refer_to_faq_msg(101), argv_values)
            for match in tuples_matches:
                res.append(np.fromstring(match, dtype=float, sep=','))
            break

        res[input_name] = np.fromstring(tuple_value, dtype=float, sep=',')

        parenthesis = matches[0][-1]
        sibling = ')' if parenthesis == '(' else ']'
        pair = '{}{}{}{}'.format(input_name, parenthesis, tuple_value, sibling)
        idx_substr = data_str.index(pair)
        data_str = data_str[idx_substr + len(pair) + 1:]

        if not len(data_str):
            break

    return res
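
# Worked examples (values illustrative):
#
#     parse_tuple_pairs('data(10,20,30),info(11,22,33)')
#     # -> {'data': np.array([10., 20., 30.]), 'info': np.array([11., 22., 33.])}
#
#     parse_tuple_pairs('(10,20,30),(11,22,33)')
#     # -> [np.array([10., 20., 30.]), np.array([11., 22., 33.])]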


def get_tuple_values(argv_values: str or tuple, num_exp_values: int = 3, t=float or int):
    """
    Gets mean values from the given string parameter

    Parameters
    ----------
    argv_values: string with list of mean values: either an empty string, or a tuple in a form [] or ().
        E.g. '(1,2,3)' means 1 for the RED channel, 2 for the GREEN channel, 3 for the BLUE channel.
    t: either float or int
    num_exp_values: number of values in tuple

    Returns
    -------
        list of strings with values, one string per matched tuple
    """

    digit_reg = r'(-?[0-9. ]+)' if t == float else r'(-?[0-9 ]+)'

    assert num_exp_values > 1, 'Can not parse tuple of size 1'
    content = r'{0}\s*,{1}\s*{0}'.format(digit_reg, (digit_reg + ',') * (num_exp_values - 2))
    tuple_reg = r'((\({0}\))|(\[{0}\]))'.format(content)

    if isinstance(argv_values, tuple) and not len(argv_values):
        return argv_values

    if not len(argv_values) or not re.match(tuple_reg, argv_values):
        raise Error('Values "{}" cannot be parsed. ' +
                    refer_to_faq_msg(59), argv_values)

    mean_values_matches = re.findall(r'[(\[]([0-9., -]+)[)\]]', argv_values)

    for mean in mean_values_matches:
        if len(mean.split(',')) != num_exp_values:
            raise Error('{} channels are expected for given values. ' +
                        refer_to_faq_msg(60), num_exp_values)

    return mean_values_matches
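
# Note that the function returns the raw matched strings rather than parsed
# numbers (illustrative call):
#
#     get_tuple_values('(255,255,255)')  # -> ['255,255,255']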


def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
    """
    This function takes mean_values and scale_values, checks and processes them into a convenient structure

    Parameters
    ----------
    mean_values: dictionary with input names and mean values passed by user (e.g. {'data': np.array([102.4, 122.1, 113.9])}),
        or list of values (e.g. np.array([102.4, 122.1, 113.9]))
    scale_values: dictionary with input names and scale values passed by user (e.g. {'data': np.array([102.4, 122.1, 113.9])}),
        or list of values (e.g. np.array([102.4, 122.1, 113.9]))

    Returns
    -------
    The function returns a dictionary, e.g.
    mean = { 'data': np.array, 'info': np.array }, scale = { 'data': np.array, 'info': np.array }, input = "data,info" ->
     { 'data': { 'mean': np.array, 'scale': np.array }, 'info': { 'mean': np.array, 'scale': np.array } }
    """
    res = {}
    # collect input names
    if argv_input:
        inputs = argv_input.split(',')
    else:
        inputs = []
        if type(mean_values) is dict:
            inputs = list(mean_values.keys())
        if type(scale_values) is dict:
            for name in scale_values.keys():
                if name not in inputs:
                    inputs.append(name)

    # create unified object containing both mean and scale for input
    if type(mean_values) is dict and type(scale_values) is dict:
        if not mean_values and not scale_values:
            return res
        for inp in inputs:
            inp, port = split_node_in_port(inp)
            if inp in mean_values or inp in scale_values:
                res.update({inp: {'mean': mean_values[inp] if inp in mean_values else None,
                                  'scale': scale_values[inp] if inp in scale_values else None}})
        return res

    # user specified input and mean/scale separately - we should return dictionary
    if inputs:
        if mean_values and scale_values:
            if len(inputs) != len(mean_values):
                raise Error('Numbers of inputs and mean values do not match. ' +
                            refer_to_faq_msg(61))
            if len(inputs) != len(scale_values):
                raise Error('Numbers of inputs and scale values do not match. ' +
                            refer_to_faq_msg(62))

            data = list(zip(mean_values, scale_values))
            for i in range(len(data)):
                res.update({inputs[i]: {'mean': data[i][0], 'scale': data[i][1]}})
            return res

        # only mean value specified
        if mean_values:
            data = list(mean_values)
            for i in range(len(data)):
                res.update({inputs[i]: {'mean': data[i], 'scale': None}})
            return res

        # only scale value specified
        if scale_values:
            data = list(scale_values)
            for i in range(len(data)):
                res.update({inputs[i]: {'mean': None, 'scale': data[i]}})
            return res

    # mean and scale are specified without inputs, return list, order is not guaranteed (?)
    return list(zip_longest(mean_values, scale_values))
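
# Worked example of the "mean only, input given" path (numbers illustrative):
#
#     mean = parse_tuple_pairs('(102.4,122.1,113.9)')
#     get_mean_scale_dictionary(mean, None, 'data')
#     # -> {'data': {'mean': np.array([102.4, 122.1, 113.9]), 'scale': None}}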


def get_model_name(path_input_model: str) -> str:
    """
    Deduces model name by a given path to the input model
    Args:
        path_input_model: path to the input model

    Returns:
        name of the output IR
    """
    parsed_name, extension = os.path.splitext(os.path.basename(path_input_model))
    return 'model' if parsed_name.startswith('.') or len(parsed_name) == 0 else parsed_name
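
# Examples (paths illustrative):
#
#     get_model_name('/models/squeezenet1.1.caffemodel')  # -> 'squeezenet1.1'
#     get_model_name('.pb')                               # -> 'model'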


def get_absolute_path(path_to_file: str) -> str:
    """
    Deduces absolute path of the file by a given path to the file
    Args:
        path_to_file: path to the file

    Returns:
        absolute path of the file
    """
    file_path = os.path.expanduser(path_to_file)
    if not os.path.isabs(file_path):
        file_path = os.path.join(os.getcwd(), file_path)
    return file_path


def check_positive(value):
    try:
        int_value = int(value)
        if int_value <= 0:
            raise ValueError
    except ValueError:
        raise argparse.ArgumentTypeError("expected a positive integer value")

    return int_value


def depersonalize(value: str):
    if not isinstance(value, str):
        return value
    res = []
    for path in value.split(','):
        if os.path.isdir(path):
            res.append('DIR')
        elif os.path.isfile(path):
            res.append(os.path.join('DIR', os.path.split(path)[1]))
        else:
            res.append(path)
    return ','.join(res)
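
# Illustration (paths hypothetical): existing directories collapse to 'DIR',
# existing files keep only their base name, and anything else passes through:
#
#     depersonalize('/home/user/models/net.pb')  # -> 'DIR/net.pb' if the file exists
#     depersonalize(42)                          # -> 42 (non-string values pass through)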


def get_meta_info(argv: argparse.Namespace):
    meta_data = {'unset': []}
    for key, value in argv.__dict__.items():
        if value is not None:
            value = depersonalize(value)
            meta_data[key] = value
        else:
            meta_data['unset'].append(key)
    # The attribute 'k' is treated separately because, by default, it points to a file that does not exist
    key = 'k'
    if key in meta_data:
        meta_data[key] = ','.join([os.path.join('DIR', os.path.split(i)[1]) for i in meta_data[key].split(',')])

    return meta_data