2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20 from __future__ import unicode_literals
31 from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
32 from mo.front.common.find_unsupported_ops import find_unsupported_ops
33 from mo.front.common.register_custom_ops import check_for_duplicates
34 from mo.front.common.register_custom_ops import update_extractors_with_extensions
35 from mo.front.extractor import restore_edges, add_output_ops, add_input_ops, \
36 extract_node_attrs, create_tensor_nodes, remove_output_ops, user_data_repack
37 from mo.front.onnx.extractor import common_onnx_fields, onnx_op_extractor, onnx_op_extractors
38 from mo.front.onnx.loader import load_onnx_model, protobuf2nx
39 from mo.middle.passes.conv import convert_add_to_scaleshift, \
40 convert_weights_yxio_to_oiyx, convert_weights_yxio_to_goiyx, convert_gemm_to_fully_connected, \
41 convert_muladd_to_scaleshift_or_power, fuse_pad, transpose_fully_connected_weights, \
42 convert_dilated_convolution, convert_mul_to_scaleshift, convert_nasnet
43 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
44 from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes, remove_useless_split
45 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
46 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
47 from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
48 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
49 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
50 from mo.middle.passes.infer import scale_input, override_placeholder_shapes, partial_infer, convert_mul_add_to_power, \
51 update_fully_connected_shapes, add_mean_scale_values, override_batch
52 from mo.middle.passes.l2normalization import l2_norm_to_norm
53 from mo.middle.passes.pool import mean_to_avgpool
54 from mo.middle.passes.shape import convert_squeeze, convert_reshape, convert_nhwc_to_nchw, reverse_input_channels, \
55 conv_flatten_concat, fuse_sequence_of_reshapes
56 from mo.utils import class_registration
57 from mo.pipeline.common import prepare_emit_ir
58 from mo.utils.custom_replacement_config import update_custom_replacement_config_file
59 from mo.utils.error import Error
60 from mo.utils.utils import refer_to_faq_msg
# Convert an ONNX model file into IR.
#
# Visible pipeline:
#   1. Load the ONNX protobuf and build an NX graph with protobuf2nx.
#   2. Extract node attributes.
#   3. Run the FRONT_REPLACER / MIDDLE_REPLACER / BACK_REPLACER replacement
#      stages, with the explicit graph passes between them.
#   4. Call prepare_emit_ir with output_dir, output_model_name and
#      argv.data_type.
#
# Parameters:
#   argv              -- parsed command-line options. Stored as
#                        graph.graph['cmd_params'] and read for: batch,
#                        finegrain_fusing, disable_fusing, disable_gfusing,
#                        reverse_input_channels, move_to_preprocess, data_type.
#   model_file_name   -- path passed to load_onnx_model.
#   output_model_name -- name given to the graph. When falsy, the ONNX graph's
#                        own name is used instead. Passed unchanged to
#                        prepare_emit_ir.
#   outputs           -- user-requested outputs. Normalized by
#                        user_data_repack, then used by add_output_ops.
#   output_dir        -- passed to prepare_emit_ir.
#   user_shapes       -- optional user-provided input shapes. Normalized by
#                        user_data_repack, then used by add_input_ops and
#                        override_placeholder_shapes.
#   mean_scale_values -- forwarded to add_mean_scale_values. The default is an
#                        immutable tuple, so sharing it across calls is safe.
#
# The bracketed annotations (e.g. [None, list, np.array]) are list literals
# used as informal "one of" hints, not typing constructs.
#
# Raises: any exception while building the NX graph is caught by the broad
# `except` below and turned into a "Cannot pre-process ONNX graph ..." message.
# It is presumably re-raised as mo.utils.error.Error (imported at the top of
# the file), but the raise statement is not visible in this extract.
#
# Returns: not visible in this extract.
#
# NOTE(review): this listing appears to be missing lines:
#   - `scale` is used by scale_input() below but is not in the visible
#     signature;
#   - the `try:` matching the `except` is not visible;
#   - the error-message expression is cut off after its trailing `+`.
# Confirm against the full file.
63 def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, outputs: list, output_dir: str,
65 user_shapes: [None, list, np.array] = None,
66 mean_scale_values: [dict, list] = ()):
68 model_proto = load_onnx_model(model_file_name)
69 model_graph = model_proto.graph
# Disabled sanity check, kept for reference only.
71 #assert len(model_graph) == 1, "An ONNX model contains more than 1 graph: unsupported"
# Diagnostic counts. Per the log messages, the graph's input list also
# contains initializer-backed ports, so the number of "real" inputs is
# computed as inputs minus initializers.
72 log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
73 log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
74 log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
75 log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
# Update the ONNX extractor table with extensions (per the helper's name).
# This runs before either extract_node_attrs call below.
76 update_extractors_with_extensions(onnx_op_extractors)
# NOTE(review): the graph-construction lines below are handled by the
# `except` further down. Its matching `try:` is not visible in this extract.
79 graph = protobuf2nx(model_proto)
80 log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
# Equivalent to `graph.name = ...`. Falls back to the ONNX graph's name when
# no output model name was given.
81 graph.__setattr__('name', output_model_name if output_model_name else model_proto.graph.name)
# Layout is hard-coded to NCHW here, so feature_dim below is always 1.
82 graph.graph['layout'] = 'NCHW'
83 graph.graph['cmd_params'] = argv
84 graph.graph['fw'] = 'onnx'
85 graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
86 # Extract basic (common ONNX) attributes early, so that passes which rely on them can run before the full attribute extractor (onnx_op_extractor, further below) is applied.
88 extract_node_attrs(graph, lambda node: (True, common_onnx_fields(node)))
# The broad catch reports any failure in the graph-construction lines above
# as a corrupt or unsupported model file (see the message text below).
89 except Exception as e:
91 'Cannot pre-process ONNX graph after reading from model file "{}". ' \
92 'File is corrupt or has unsupported format. Details: {}. ' +
98 user_shapes, outputs, _ = user_data_repack(graph, user_shapes, outputs, None)  # normalize user outputs/shapes; the third result is unused
# Insert output and input ops for the requested outputs/shapes.
# NOTE(review): add_input_ops is called here with True and again below with
# False (after remove_output_ops). The flag's meaning is defined in
# mo.front.extractor; confirm there. output_op_nodes and input_op_nodes are
# not read anywhere in the visible code.
100 graph, output_op_nodes = add_output_ops(graph, outputs)
101 graph, input_op_nodes = add_input_ops(graph, user_shapes, True)
103 # this call of 'graph_clean_up' removes child nodes of outputs which is useful when custom output is specified
104 graph_clean_up(graph)
# Full attribute extraction with the registered ONNX extractors.
# check_for_duplicates is applied to the extractor table first.
106 extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
108 class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
110 create_tensor_nodes(graph)
111 graph_clean_up(graph)
# Apply user-provided input shapes and the batch override (argv.batch).
113 override_placeholder_shapes(graph, user_shapes)
114 override_batch(graph, argv.batch)
116 graph_clean_up(graph)
117 remove_op_nodes(graph, {'op': 'Identity'})
119 graph_clean_up(graph)
121 remove_output_ops(graph)
# NOTE(review): partial_infer is imported at the top of the file but is never
# called in the visible code. Confirm whether shape inference runs on a line
# not shown in this extract.
124 graph_clean_up(graph)
127 graph, input_op_nodes = add_input_ops(graph, user_shapes, False)
128 graph_clean_up(graph)
# NOTE(review): commented-out call left in place; either remove it or restore it.
130 #change_placeholders_types_to_FP32(graph)
# Input preprocessing: a global input scale (see the NOTE in the header about
# `scale`), then the per-input mean/scale values.
132 scale_input(graph, scale)
133 add_mean_scale_values(graph, mean_scale_values)
135 convert_dilated_convolution(graph)
136 graph_clean_up(graph)
# NOTE(review): two consecutive graph_clean_up calls. A pass between them may
# be missing from this extract; confirm.
138 graph_clean_up(graph)
140 remove_op_nodes(graph, {'op': 'Identity'})
141 remove_useless_split(graph)
143 class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
145 convert_gemm_to_fully_connected(graph)
148 graph_clean_up(graph)
150 # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
151 mark_unfused_nodes(graph, argv.finegrain_fusing)
153 # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
154 # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
155 convert_batch_norm(graph)
156 graph_clean_up(graph)
# Optional linear-op fusing, skipped when argv.disable_fusing is set.
# Indentation is not preserved in this listing. Presumably the lines up to
# the next `if` form this branch's body; confirm.
158 if not argv.disable_fusing:
159 # Converting ScaleShift layer to Mul->Add
160 convert_scale_shift_to_mul_add(graph)
161 graph_clean_up(graph)
163 # Fusing the sequences of Mul/Add operations
164 fuse_mul_add_sequence(graph)
165 graph_clean_up(graph)
167 # Fusing linear operation to Convolution
168 fuse_linear_ops(graph)
169 graph_clean_up(graph)
# Grouped-convolution fusing, skipped when argv.disable_gfusing is set.
# The inner `if` presumably re-runs linear-op fusing inside this branch when
# fusing is enabled. The nesting is not visible here; confirm.
171 if not argv.disable_gfusing:
172 grouped_convolutions_fusing(graph)
173 graph_clean_up(graph)
174 if not argv.disable_fusing:
175 fuse_linear_ops(graph)
176 graph_clean_up(graph)
# Per the pass names, remaining Mul/Add patterns are converted into
# ScaleShift/Power layers.
178 convert_muladd_to_scaleshift_or_power(graph)
179 graph_clean_up(graph)
181 convert_mul_add_to_power(graph)
183 # Need to eliminate dead nodes before doing update_fully_connected_shapes
184 # because update_fully_connected_shapes does partial inference and dead
185 # nodes will lead to sporadic failures.
186 graph_clean_up(graph)
187 update_fully_connected_shapes(graph)
# Per the pass names, remaining Reshape, Add and Mul nodes are converted to
# their IR-level counterparts.
189 convert_reshape(graph)
190 convert_add_to_scaleshift(graph) # scale = 1
191 convert_mul_to_scaleshift(graph) # biases = 0
194 graph_clean_up(graph)
# Optional passes driven by argv flags.
196 if argv.reverse_input_channels:
197 reverse_input_channels(graph)
199 if argv.move_to_preprocess:
200 move_scaleshift_to_preprocess(graph)
201 graph_clean_up(graph)
# Fuse sequences of Reshape nodes (per the pass name), run the back-stage
# replacers, then hand the graph to prepare_emit_ir.
203 fuse_sequence_of_reshapes(graph)
204 graph_clean_up(graph)
206 class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
208 prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name)