2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
24 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
25 from extensions.middle.AddQuantizeFuse import AddQuantizeFuse
26 from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
27 from extensions.middle.MulQuantizeFuse import MulQuantizeFuse
28 from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
29 from mo.front.extractor import extract_node_attrs, remove_output_ops
30 from mo.front.onnx.extractor import onnx_op_extractor, onnx_op_extractors
31 from mo.front.onnx.loader import load_onnx_model, protobuf2nx
32 from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift, convert_muladd_to_scaleshift_or_power, fuse_pad
33 from mo.middle.passes.eliminate import graph_clean_up_onnx, remove_const_ops
34 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
35 from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
36 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
37 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
38 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
39 from mo.middle.passes.infer import convert_mul_add_to_power
40 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
41 from mo.middle.passes.shape import convert_reshape, reverse_input_channels, \
42 fuse_sequence_of_reshapes, merge_nodes_permutations, permute_data_nodes_attrs, permute_op_nodes_attrs
43 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
44 from mo.pipeline.common import prepare_emit_ir
45 from mo.utils import class_registration
46 from mo.utils.cli_parser import get_meta_info
47 from mo.utils.error import Error
48 from mo.utils.utils import refer_to_faq_msg
51 def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str):
52 meta_info = get_meta_info(argv)
54 model_proto = load_onnx_model(model_file_name)
55 model_graph = model_proto.graph # pylint: disable=no-member
57 # assert len(model_graph) == 1, "An ONNX model contains more than 1 graph: unsupported"
58 log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
59 log.debug("Number of all input ports (not true inputs) in graph_def: {}".format(len(model_graph.input)))
60 log.debug("Number of initializers in graph_def: {}".format(len(model_graph.initializer)))
61 log.debug("Number of real inputs in graph_def: {}".format(len(model_graph.input) - len(model_graph.initializer)))
62 update_extractors_with_extensions(onnx_op_extractors)
65 graph = protobuf2nx(model_proto)
66 log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
67 graph.__setattr__('name',
68 output_model_name if output_model_name else model_proto.graph.name) # pylint: disable=no-member
69 graph.graph['layout'] = 'NCHW'
70 graph.graph['cmd_params'] = argv
71 graph.graph['fw'] = 'onnx'
72 graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
73 graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
74 except Exception as e:
76 'Cannot pre-process ONNX graph after reading from model file "{}". ' \
77 'File is corrupt or has unsupported format. Details: {}. ' +
82 graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
83 extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
85 # --------------------------------- LOAD END ------------------------------------------------------
86 class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
87 class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
90 graph_clean_up_onnx(graph)
92 # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
93 mark_unfused_nodes(graph, argv.finegrain_fusing)
95 # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
96 # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
97 convert_batch_norm(graph)
98 graph_clean_up_onnx(graph)
100 if not argv.disable_fusing:
101 # Converting ScaleShift layer to Mul->Add
102 convert_scale_shift_to_mul_add(graph)
103 graph_clean_up_onnx(graph)
105 # Fusing the sequences of Mul/Add operations
106 fuse_mul_add_sequence(graph)
107 graph_clean_up_onnx(graph)
109 # Fusing linear operation to Convolution
110 fuse_linear_ops(graph)
111 graph_clean_up_onnx(graph)
113 if not argv.disable_gfusing:
114 grouped_convolutions_fusing(graph)
115 graph_clean_up_onnx(graph)
116 if not argv.disable_fusing:
117 fuse_linear_ops(graph)
118 graph_clean_up_onnx(graph)
120 AddQuantizeFuse().find_and_replace_pattern(graph)
121 MulQuantizeFuse().find_and_replace_pattern(graph)
123 convert_muladd_to_scaleshift_or_power(graph)
124 graph_clean_up_onnx(graph)
126 convert_mul_add_to_power(graph)
127 graph_clean_up_onnx(graph)
129 convert_reshape(graph)
130 graph_clean_up_onnx(graph)
131 convert_add_or_mul_to_scaleshift(graph) # scale = 1
132 graph_clean_up_onnx(graph)
135 graph_clean_up_onnx(graph)
137 if argv.reverse_input_channels:
138 reverse_input_channels(graph)
140 if argv.move_to_preprocess:
141 move_scaleshift_to_preprocess(graph)
142 graph_clean_up_onnx(graph)
144 fuse_sequence_of_reshapes(graph)
145 graph_clean_up_onnx(graph)
147 pattern = EltwiseInputNormalize()
148 pattern.find_and_replace_pattern(graph)
150 merge_nodes_permutations(graph)
151 permute_data_nodes_attrs(graph)
152 permute_op_nodes_attrs(graph)
154 class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
156 for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
158 CreateConstNodesReplacement().find_and_replace_pattern(graph)
160 for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)
162 prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,