Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / mx.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
17 from extensions.front.restore_ports import RestorePorts
18 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
19 from mo.utils.error import Error, FrameworkError
20 from mo.utils.utils import refer_to_faq_msg
21
22 try:
23     import mxnet
24 except ImportError:
25     raise Error('Module mxnet was not found. Please install appropriate version of mxnet via install_prerequisites '
26                 'script.' + refer_to_faq_msg(52))
27
28 import argparse
29
30 from mo.front.extractor import extract_node_attrs, remove_output_ops
31 from mo.front.mxnet.extractor import mxnet_op_extractor
32 from mo.front.mxnet.loader import symbol2nx, load_symbol_def
33 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
34 from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
35     convert_add_or_mul_to_scaleshift, fuse_pad
36 from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
37 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
38 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
39 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
40 from mo.middle.passes.fusing.resnet_optimization import stride_optimization
41 from mo.middle.passes.infer import convert_mul_add_to_power
42 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
43 from mo.middle.passes.shape import reverse_input_channels
44 from mo.pipeline.common import prepare_emit_ir
45 from mo.front.mxnet.nd_to_params import save_params_file
46 from mo.front.common.register_custom_ops import update_extractors_with_extensions
47 from mo.front.mxnet.extractor import mxnet_op_extractors
48 from mo.utils import class_registration
49 from mo.utils.cli_parser import get_meta_info
50 from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
51
52
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, output_dir: str):
    """
    Convert an MXNet model into IR by running the full Model Optimizer pipeline.

    Loads the symbol/params definition, builds the internal graph, applies the FRONT, MIDDLE and
    BACK replacement stages plus the fusing/conversion passes below (in a fixed order), and emits
    the IR through ``prepare_emit_ir``.

    :param argv: parsed command-line arguments; this function reads input_symbol, input,
        nd_prefix_name, pretrained_model_name, legacy_mxnet_model, save_params_from_nd,
        generate_deprecated_IR_V2, finegrain_fusing, disable_fusing, disable_resnet_optimization,
        reverse_input_channels, move_to_preprocess and data_type
    :param input_model: model argument forwarded to ``load_symbol_def``
    :param output_model_name: name assigned to the graph and passed to ``prepare_emit_ir``
    :param output_dir: output directory passed to ``prepare_emit_ir``
    :return: 0 after the IR has been emitted
    :raises FrameworkError: when loading the model raises ValueError or mxnet.base.MXNetError
    """
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(input_model, argv.input_symbol,
                                                                                  argv.input,
                                                                                  argv.nd_prefix_name,
                                                                                  argv.pretrained_model_name,
                                                                                  argv.legacy_mxnet_model)
    except (ValueError, mxnet.base.MXNetError) as e:
        # FrameworkError formats its first argument with the remaining ones.
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53),
            input_model,
            str(e)
        ) from e

    # Optionally dump the loaded arg/aux params; only meaningful when .nd-based loading was requested.
    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params, model_params._aux_params, iteration_number)

    # Custom-op extensions must be registered before attributes are extracted.
    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')

    # The MXNet frontend always produces an NCHW graph; feature_dim is the channel axis index.
    layout = 'NCHW'
    graph.name = output_model_name
    graph.graph['layout'] = layout
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if layout == 'NCHW' else 3
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    # Second pad-fusing pass, after the fusing passes above have run.
    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift_or_power(graph)
    graph_clean_up(graph)

    convert_mul_add_to_power(graph)
    graph_clean_up(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    # Applied to the top-level graph and, recursively, to every sub-graph.
    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0