2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
24 from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
25 from mo.front.extractor import extract_node_attrs
26 from mo.front.onnx.extractor import onnx_op_extractor, onnx_op_extractors
27 from mo.front.onnx.loader import load_onnx_model, protobuf2nx
28 from mo.pipeline.common import prepare_emit_ir
29 from mo.utils import class_registration
30 from mo.utils.cli_parser import get_meta_info
31 from mo.utils.error import Error
32 from mo.utils.utils import refer_to_faq_msg
35 def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str):
36 meta_info = get_meta_info(argv)
38 model_proto = load_onnx_model(model_file_name)
39 model_graph = model_proto.graph # pylint: disable=no-member
41 # assert len(model_graph) == 1, "An ONNX model contains more than 1 graph: unsupported"
42 log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
43 log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
44 log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
45 log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
46 update_extractors_with_extensions(onnx_op_extractors)
49 graph = protobuf2nx(model_proto)
50 log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
51 graph.__setattr__('name',
52 output_model_name if output_model_name else model_proto.graph.name) # pylint: disable=no-member
53 graph.graph['layout'] = 'NCHW'
54 graph.graph['cmd_params'] = argv
55 graph.graph['fw'] = 'onnx'
56 graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
58 if graph.graph['cmd_params'].generate_experimental_IR_V10:
62 graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version
64 except Exception as e:
66 'Cannot pre-process ONNX graph after reading from model file "{}". ' \
67 'File is corrupt or has unsupported format. Details: {}. ' +
72 graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
73 extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
75 # --------------------------------- LOAD END ------------------------------------------------------
77 class_registration.apply_replacements(graph, [
78 class_registration.ClassType.FRONT_REPLACER,
79 class_registration.ClassType.MIDDLE_REPLACER,
80 class_registration.ClassType.BACK_REPLACER
83 prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,