model-optimizer/mo/pipeline/onnx.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20
21 import argparse
22 import logging as log
23
24 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
25 from extensions.middle.AddQuantizeFuse import AddQuantizeFuse
26 from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
27 from extensions.middle.MulQuantizeFuse import MulQuantizeFuse
28 from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
29 from mo.front.extractor import extract_node_attrs, remove_output_ops
30 from mo.front.onnx.extractor import onnx_op_extractor, onnx_op_extractors
31 from mo.front.onnx.loader import load_onnx_model, protobuf2nx
32 from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift, convert_muladd_to_scaleshift_or_power, fuse_pad
33 from mo.middle.passes.eliminate import graph_clean_up_onnx, remove_const_ops
34 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
35 from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
36 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
37 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
38 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
39 from mo.middle.passes.infer import convert_mul_add_to_power
40 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
41 from mo.middle.passes.shape import convert_reshape, reverse_input_channels, \
42     fuse_sequence_of_reshapes, merge_nodes_permutations, permute_data_nodes_attrs, permute_op_nodes_attrs
43 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
44 from mo.pipeline.common import prepare_emit_ir
45 from mo.utils import class_registration
46 from mo.utils.cli_parser import get_meta_info
47 from mo.utils.error import Error
48 from mo.utils.utils import refer_to_faq_msg
49
50
51 def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str):
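    """
    Convert an ONNX model into Inference Engine IR.

    Loads the ONNX protobuf from `model_file_name`, builds the internal graph
    representation, runs the front/middle/back transformation pipeline and
    emits the resulting IR into `output_dir`.
    """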
    meta_info = get_meta_info(argv)

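    # Load and parse the serialized ONNX protobuf model from disk.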
    model_proto = load_onnx_model(model_file_name)
    model_graph = model_proto.graph  # pylint: disable=no-member
    # print(model_graph)
    # assert len(model_graph) == 1, "An ONNX model contains more than 1 graph: unsupported"
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
    log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
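    # Merge operator extractors contributed by extensions into the base ONNX extractor registry.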
    update_extractors_with_extensions(onnx_op_extractors)

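    # Convert the protobuf model into a networkx-based graph and attach pipeline-wide
    # attributes (name, layout, command-line parameters, framework and IR version).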
    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
        graph.__setattr__('name',
                          output_model_name if output_model_name else model_proto.graph.name)  # pylint: disable=no-member
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    except Exception as e:
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". ' \
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e
    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
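    # Apply all registered front and middle phase replacers to the graph.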
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

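    # Fuse standalone Pad operations into the operations that consume them.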
    fuse_pad(graph)
    graph_clean_up_onnx(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for the specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting the FusedBatchNorm layer to a Mul->Add->Mul->Add sequence:
    # IE doesn't support BN with 4 inputs, so we have to split it into two ScaleShift layers
    convert_batch_norm(graph)
    graph_clean_up_onnx(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up_onnx(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up_onnx(graph)

        # Fusing linear operations into Convolution
        fuse_linear_ops(graph)
        graph_clean_up_onnx(graph)

    if not argv.disable_gfusing:
        grouped_convolutions_fusing(graph)
        graph_clean_up_onnx(graph)
        if not argv.disable_fusing:
            fuse_linear_ops(graph)
            graph_clean_up_onnx(graph)

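    # Fuse Add and Mul operations into neighbouring Quantize operations where possible.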
    AddQuantizeFuse().find_and_replace_pattern(graph)
    MulQuantizeFuse().find_and_replace_pattern(graph)

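    # Collapse the remaining Mul/Add pairs into ScaleShift or Power layers supported by IE.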
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up_onnx(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up_onnx(graph)

    convert_reshape(graph)
    graph_clean_up_onnx(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up_onnx(graph)

    fuse_pad(graph)
    graph_clean_up_onnx(graph)

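    # Optionally reverse the order of input channels (e.g. RGB <-> BGR) when requested on the command line.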
    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up_onnx(graph)

    fuse_sequence_of_reshapes(graph)
    graph_clean_up_onnx(graph)

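    # Normalize the inputs of element-wise operations so that their shapes are compatible.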
    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

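    # Resolve the layout permutations accumulated on nodes and apply them to data and op node attributes.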
    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

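    # Run the back phase replacers and finalize constant and output nodes before emitting IR.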
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

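    # Serialize the final graph as IR (.xml topology and .bin weights) into the output directory.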
    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)

    return 0
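

# Usage sketch (illustrative only, not part of the upstream pipeline). In a real run the
# Model Optimizer CLI builds `argv` with its command-line parser before calling this driver.
# The attributes read in this module include at least: generate_deprecated_IR_V2,
# finegrain_fusing, disable_fusing, disable_gfusing, reverse_input_channels,
# move_to_preprocess and data_type.
#
#     driver(argv, '/path/to/model.onnx', 'model', '/path/to/output_dir')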