2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
16 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
17 from extensions.front.restore_ports import RestorePorts
18 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
19 from mo.utils.error import Error, FrameworkError
20 from mo.utils.utils import refer_to_faq_msg
25 raise Error('Module mxnet was not found. Please install appropriate version of mxnet via install_prerequisites '
26 'script.' + refer_to_faq_msg(52))
30 from mo.front.extractor import extract_node_attrs, remove_output_ops
31 from mo.front.mxnet.extractor import mxnet_op_extractor
32 from mo.front.mxnet.loader import symbol2nx, load_symbol_def
33 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
34 from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
35 convert_add_or_mul_to_scaleshift, fuse_pad
36 from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
37 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
38 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
39 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
40 from mo.middle.passes.fusing.resnet_optimization import stride_optimization
41 from mo.middle.passes.infer import convert_mul_add_to_power
42 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
43 from mo.middle.passes.shape import reverse_input_channels
44 from mo.pipeline.common import prepare_emit_ir
45 from mo.front.mxnet.nd_to_params import save_params_file
46 from mo.front.common.register_custom_ops import update_extractors_with_extensions
47 from mo.front.mxnet.extractor import mxnet_op_extractors
48 from mo.utils import class_registration
49 from mo.utils.cli_parser import get_meta_info
50 from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
53 def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, output_dir: str):
54 meta_info = get_meta_info(argv)
57 model_nodes, model_params, model_name, iteration_number = load_symbol_def(input_model, argv.input_symbol,
60 argv.pretrained_model_name,
61 argv.legacy_mxnet_model)
62 except (ValueError, mxnet.base.MXNetError) as e:
64 'The following error happened while loading mxnet model {}: {}. ' +
70 if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
71 save_params_file(model_name, model_params._arg_params, model_params._aux_params, iteration_number)
73 update_extractors_with_extensions(mxnet_op_extractors)
74 graph = symbol2nx(model_nodes, model_params, argv.input)
75 graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')
77 graph.__setattr__('name', output_model_name)
78 graph.graph['layout'] = 'NCHW'
79 graph.graph['cmd_params'] = argv
80 graph.graph['fw'] = 'mxnet'
81 graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
82 graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
83 extract_node_attrs(graph, mxnet_op_extractor)
85 # --------------------------------- LOAD END ------------------------------------------------------
86 class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
87 class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
91 # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
92 mark_unfused_nodes(graph, argv.finegrain_fusing)
94 # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
95 convert_batch_norm(graph)
98 if not argv.disable_fusing:
99 # Converting ScaleShift layer to Mul->Add
100 convert_scale_shift_to_mul_add(graph)
101 graph_clean_up(graph)
103 # Fusing the sequences of Mul/Add operations
104 fuse_mul_add_sequence(graph)
105 graph_clean_up(graph)
107 # Fusing linear operation to Convolution
108 fuse_linear_ops(graph)
109 graph_clean_up(graph)
111 if not argv.disable_resnet_optimization:
112 stride_optimization(graph)
116 # Converting Mul->Add to ScaleShift node
117 convert_muladd_to_scaleshift_or_power(graph)
118 graph_clean_up(graph)
120 convert_mul_add_to_power(graph)
121 graph_clean_up(graph)
122 convert_add_or_mul_to_scaleshift(graph) # scale = 1
123 graph_clean_up(graph)
125 if argv.reverse_input_channels:
126 reverse_input_channels(graph)
128 if argv.move_to_preprocess:
129 move_scaleshift_to_preprocess(graph)
130 graph_clean_up(graph)
132 pattern = EltwiseInputNormalize()
133 pattern.find_and_replace_pattern(graph)
135 class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
137 for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
138 CreateConstNodesReplacement().find_and_replace_pattern(graph)
140 for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)
142 prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,