model-optimizer/mo/pipeline/onnx.py (platform/upstream/dldt.git)
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import copy
import logging as log

import onnx
import os

import numpy as np

from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
from mo.front.common.find_unsupported_ops import find_unsupported_ops
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.extractor import restore_edges, add_output_ops, add_input_ops, \
    extract_node_attrs, create_tensor_nodes, remove_output_ops, user_data_repack
from mo.front.onnx.extractor import common_onnx_fields, onnx_op_extractor, onnx_op_extractors
from mo.front.onnx.loader import load_onnx_model, protobuf2nx
from mo.middle.passes.conv import convert_add_to_scaleshift, \
    convert_weights_yxio_to_oiyx, convert_weights_yxio_to_goiyx, convert_gemm_to_fully_connected, \
    convert_muladd_to_scaleshift_or_power, fuse_pad, transpose_fully_connected_weights, \
    convert_dilated_convolution, convert_mul_to_scaleshift, convert_nasnet
from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes, remove_useless_split
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
from mo.middle.passes.infer import scale_input, override_placeholder_shapes, partial_infer, convert_mul_add_to_power, \
    update_fully_connected_shapes, add_mean_scale_values, override_batch
from mo.middle.passes.l2normalization import l2_norm_to_norm
from mo.middle.passes.pool import mean_to_avgpool
from mo.middle.passes.shape import convert_squeeze, convert_reshape, convert_nhwc_to_nchw, reverse_input_channels, \
    conv_flatten_concat, fuse_sequence_of_reshapes
from mo.utils import class_registration
from mo.pipeline.common import prepare_emit_ir
from mo.utils.custom_replacement_config import update_custom_replacement_config_file
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg


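# Conversion pipeline for ONNX models: load the model, build the internal graph representation,
# run the front/middle/back replacement and fusing passes, and emit the resulting IR.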
def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, outputs: list, output_dir: str,
           scale: float,
           user_shapes: [None, list, np.array] = None,
           mean_scale_values: [dict, list] = ()):

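    # Load the ONNX model from disk into an in-memory protobuf message and log
    # basic statistics about its graph before conversion.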
    model_proto = load_onnx_model(model_file_name)
    model_graph = model_proto.graph
    #print(model_graph)
    #assert len(model_graph) == 1, "An ONNX model contains more than 1 graph: unsupported"
    log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
    log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
    log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
    log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
    update_extractors_with_extensions(onnx_op_extractors)

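    # Convert the protobuf model into the internal networkx graph and record
    # global graph attributes (name, layout, framework, command-line parameters).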
    try:
        graph = protobuf2nx(model_proto)
        log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
        graph.__setattr__('name', output_model_name if output_model_name else model_proto.graph.name)
        graph.graph['layout'] = 'NCHW'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'onnx'
        graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
        # extract basic attributes earlier to enable the passes that rely on them before the full attribute
        # extractor is called
        extract_node_attrs(graph, lambda node: (True, common_onnx_fields(node)))
    except Exception as e:
        raise Error(
            'Cannot pre-process ONNX graph after reading from model file "{}". ' \
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e

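    # Reconcile the user-specified input shapes and output names with the nodes
    # that actually exist in the graph, then attach the Input/Output operations.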
    user_shapes, outputs, _ = user_data_repack(graph, user_shapes, outputs, None)

    graph, output_op_nodes = add_output_ops(graph, outputs)
    graph, input_op_nodes = add_input_ops(graph, user_shapes, True)

    # this call of 'graph_clean_up' removes child nodes of the outputs, which is useful when a custom output is specified
    graph_clean_up(graph)

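    # Run the full per-operation attribute extractors (built-in ONNX extractors
    # plus those registered by extensions) on every node.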
    extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))

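    # Apply the front-phase replacement transformations registered through the class registration mechanism.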
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    create_tensor_nodes(graph)
    graph_clean_up(graph)

    override_placeholder_shapes(graph, user_shapes)
    override_batch(graph, argv.batch)

    graph_clean_up(graph)
    remove_op_nodes(graph, {'op': 'Identity'})

    graph_clean_up(graph)

    remove_output_ops(graph)

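    # Propagate shapes (and constant values where possible) through the whole graph.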
    partial_infer(graph)
    graph_clean_up(graph)

    graph, input_op_nodes = add_input_ops(graph, user_shapes, False)
    graph_clean_up(graph)

    #change_placeholders_types_to_FP32(graph)

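    # Apply the command-line scale factor and the mean/scale preprocessing values to the network inputs.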
    scale_input(graph, scale)
    add_mean_scale_values(graph, mean_scale_values)

    convert_dilated_convolution(graph)
    graph_clean_up(graph)

    remove_op_nodes(graph, {'op': 'Identity'})
    remove_useless_split(graph)

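    # Apply the middle-phase replacement transformations registered through the class registration mechanism.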
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    convert_gemm_to_fully_connected(graph)

    fuse_pad(graph)
    graph_clean_up(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for the specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it into two ScaleShift layers
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operations into Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_gfusing:
        grouped_convolutions_fusing(graph)
        graph_clean_up(graph)
        if not argv.disable_fusing:
            fuse_linear_ops(graph)
            graph_clean_up(graph)

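    # Collapse the remaining Mul/Add pairs into ScaleShift or Power layers supported by the IR.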
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)

    # Need to eliminate dead nodes before doing update_fully_connected_shapes
    # because update_fully_connected_shapes does partial inference and dead
    # nodes will lead to sporadic failures.
    graph_clean_up(graph)
    update_fully_connected_shapes(graph)

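    # Normalize the remaining layers: rewrite Reshape operations and express standalone Add/Mul as ScaleShift.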
    convert_reshape(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

    fuse_pad(graph)
    graph_clean_up(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    fuse_sequence_of_reshapes(graph)
    graph_clean_up(graph)

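    # Apply the back-phase replacement transformations registered through the class registration mechanism.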
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

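    # Emit the final Intermediate Representation (topology and weights) into the output directory
    # using the requested data type.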
    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name)

    return 0