Updated the README file after moving the CMake scripts to the root folder
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / onnx.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20
21 import argparse
22 import logging as log
23
24 from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
25 from mo.front.extractor import extract_node_attrs
26 from mo.front.onnx.extractor import onnx_op_extractor, onnx_op_extractors
27 from mo.front.onnx.loader import load_onnx_model, protobuf2nx
28 from mo.pipeline.common import prepare_emit_ir
29 from mo.utils import class_registration
30 from mo.utils.cli_parser import get_meta_info
31 from mo.utils.error import Error
32 from mo.utils.utils import refer_to_faq_msg
33
34
def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str):
    """
    Convert an ONNX model file into IR.

    Loads the ONNX protobuf, converts it to an NX graph, fills in framework/layout metadata,
    extracts node attributes, applies the registered front, middle and back replacers and
    finally hands the graph to prepare_emit_ir.

    :param argv: parsed command-line parameters. This function reads
        ``generate_experimental_IR_V10``, ``generate_deprecated_IR_V2`` and ``data_type`` from it
        and stores the whole namespace in the graph as ``cmd_params``.
    :param model_file_name: path of the ONNX model file to load
    :param output_model_name: name given to the graph; when empty/falsy, the ONNX graph name is used
    :param output_dir: output directory, forwarded to prepare_emit_ir
    :return: 0
    :raises Error: if the loaded ONNX protobuf cannot be converted to an NX graph
    """
    meta_info = get_meta_info(argv)

    model_proto = load_onnx_model(model_file_name)
    model_graph = model_proto.graph  # pylint: disable=no-member
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
    # The "real inputs" count assumes every initializer is also listed in graph.input.
    # NOTE(review): newer ONNX IR versions do not require that, so this debug figure may be off.
    log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
    update_extractors_with_extensions(onnx_op_extractors)

    # Only the protobuf -> NX conversion is guarded: the error raised here blames the model file,
    # so unrelated failures (e.g. a missing command-line attribute) must not be reported through it.
    try:
        graph = protobuf2nx(model_proto)
    except Exception as e:
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e
    log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))

    graph.name = output_model_name if output_model_name else model_graph.name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'onnx'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

    # The deprecated IR v2 flag takes precedence over the experimental IR v10 flag; default is v6.
    version = 10 if argv.generate_experimental_IR_V10 else 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')

    # Validate the extractor collection once up front instead of inside the per-node callback.
    # NOTE(review): assumes check_for_duplicates only validates (and returns) the collection and that
    # attribute extraction does not modify onnx_op_extractors -- confirm against register_custom_ops.
    extractors = check_for_duplicates(onnx_op_extractors)
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, extractors))

    # --------------------------------- LOAD END ------------------------------------------------------

    class_registration.apply_replacements(graph, [
        class_registration.ClassType.FRONT_REPLACER,
        class_registration.ClassType.MIDDLE_REPLACER,
        class_registration.ClassType.BACK_REPLACER
    ])

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0