Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / caffe.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import argparse
17 import logging as log
18
19 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
20 from mo.front.caffe import custom_layers_mapping, loader
21 from mo.front.caffe.extractor import caffe_type_extractors, caffe_extractor
22 from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
23 from mo.front.extractor import extract_node_attrs, remove_output_ops
24 from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift
25 from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
26     convert_matmul_to_fully_connected, batch_norm_fuse
27 from mo.middle.passes.eliminate import graph_clean_up
28 from mo.middle.passes.eliminate import remove_const_ops
29 from mo.middle.passes.fusing.decomposition import convert_bn_to_mul_add, convert_scale_shift_to_mul_add
30 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
31 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
32 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
33 from mo.middle.passes.fusing.resnet_optimization import stride_optimization
34 from mo.middle.passes.infer import convert_mul_add_to_power
35 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
36 from mo.middle.passes.shape import reverse_input_channels, fuse_sequence_of_reshapes
37 from mo.pipeline.common import prepare_emit_ir
38 from mo.utils import class_registration
39 from mo.utils.cli_parser import get_meta_info
40 from mo.utils.error import Error
41 from mo.utils.find_inputs import find_inputs
42 from mo.utils.utils import refer_to_faq_msg
43
44
def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str,
           output_dir: str, mean_file: str = "",
           mean_file_offsets: tuple = None, custom_layers_mapping_path: str = None):
    """
    Convert a Caffe model (prototxt + caffemodel) into IR written to ``output_dir``.

    Pipeline: load the Caffe protobufs into a graph and extract node attributes, apply the
    front/middle replacers and the graph-level passes (optional fusing and ResNet stride
    optimization), optionally parse a mean file, apply the back replacers and finally emit
    the IR through ``prepare_emit_ir``.

    :param argv: parsed command-line options. Read here: ``generate_deprecated_IR_V2``,
        ``finegrain_fusing``, ``disable_fusing``, ``disable_resnet_optimization``,
        ``reverse_input_channels``, ``move_to_preprocess``, ``data_type`` and, when present,
        ``disable_omitting_optional`` / ``enable_flattening_nested_params`` (default False).
    :param proto_file_name: file name of the .prototxt model description.
    :param model_file_name: file name of the .caffemodel weights.
    :param output_model_name: name of the generated IR; also used as the graph name when the
        prototxt has no non-empty ``name``.
    :param output_dir: directory the IR is emitted to.
    :param mean_file: optional mean file; only supported for single-input topologies.
    :param mean_file_offsets: offsets forwarded to ``loader.parse_mean``.
    :param custom_layers_mapping_path: optional custom layers mapping XML forwarded to
        ``custom_layers_mapping.load_layers_xml``.
    :return: 0 after the IR has been emitted.
    :raises Error: if the prototxt cannot be converted to a graph, if a mean file is given for
        a topology that does not have exactly one input, or if the mean file cannot be processed.
    """
    meta_info = get_meta_info(argv)

    proto, model = loader.load_caffe_proto_model(proto_file_name, model_file_name)

    # These Caffe-specific options may be absent from argv; default them to False.
    # Both extractor registrations below must honour the same flattening option.
    disable_omitting_optional = getattr(argv, 'disable_omitting_optional', False)
    enable_flattening_nested_params = getattr(argv, 'enable_flattening_nested_params', False)

    update_extractors_with_extensions(
        caffe_type_extractors,
        disable_omitting_optional,
        enable_flattening_nested_params
    )

    try:
        graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
    except ValueError as e:
        raise Error('Invalid prototxt file: value error {}. ' +
                    refer_to_faq_msg(11), str(e)) from e

    log.debug("After caffe_pb_to_nx")
    graph.print_graph_stat()
    graph.check_empty_graph('load_caffe_proto_model')

    graph.proto_path = proto_file_name
    graph.caffemodel_path = model_file_name
    # Prefer the model name declared in the prototxt; fall back to the requested output name.
    graph.name = getattr(proto, 'name', None) or output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'caffe'
    # IR v5 is produced unless the deprecated v2 format is explicitly requested.
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5

    custom_layers_map = custom_layers_mapping.load_layers_xml(custom_layers_mapping_path)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        disable_omitting_optional,
        enable_flattening_nested_params
    )
    # The extractor registry is fully populated at this point, so validate it once rather than
    # once per node inside the extractor callback.
    # NOTE(review): assumes caffe_extractor does not modify caffe_type_extractors — confirm.
    checked_extractors = check_for_duplicates(caffe_type_extractors)
    extract_node_attrs(graph, lambda node: caffe_extractor(node, checked_extractors))

    # --------------------------------- LOAD END ------------------------------------------------------
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # need this pass even without fusing to convert scale with 2 inputs
    convert_scale_shift_to_mul_add(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        convert_bn_to_mul_add(graph)
        graph_clean_up(graph)

        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    convert_muladd_to_scaleshift_or_power(graph)
    convert_matmul_to_fully_connected(graph)
    batch_norm_fuse(graph)
    convert_mul_add_to_power(graph)
    graph_clean_up(graph)
    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    log.debug("After graph_cleanup")
    graph.print_graph_stat()

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    fuse_sequence_of_reshapes(graph)

    input_names = find_inputs(graph)
    mf = []
    if mean_file:
        # The mean is applied to a single input shape, so only single-input topologies qualify.
        if len(original_shapes) != 1:
            raise Error('Mean file for topologies with multiple inputs is not supported. ' +
                        refer_to_faq_msg(9))
        try:
            mf = loader.parse_mean(mean_file, original_shapes[input_names[0]], mean_file_offsets)
        except ValueError as e:
            raise Error('Cannot load or process mean file: value error {}. ' +
                        refer_to_faq_msg(10), str(e)) from e

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    remove_const_ops(graph)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    remove_output_ops(graph)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    mean_data=mf,
                    input_names=input_names,
                    meta_info=meta_info)
    return 0