2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
20 from mo.front.caffe import custom_layers_mapping, loader
21 from mo.front.caffe.extractor import caffe_type_extractors, caffe_extractor
22 from mo.front.common.register_custom_ops import update_extractors_with_extensions, check_for_duplicates
23 from mo.front.extractor import extract_node_attrs, remove_output_ops
24 from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift
25 from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
26 convert_matmul_to_fully_connected, batch_norm_fuse
27 from mo.middle.passes.eliminate import graph_clean_up
28 from mo.middle.passes.eliminate import remove_const_ops
29 from mo.middle.passes.fusing.decomposition import convert_bn_to_mul_add, convert_scale_shift_to_mul_add
30 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
31 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
32 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
33 from mo.middle.passes.fusing.resnet_optimization import stride_optimization
34 from mo.middle.passes.infer import convert_mul_add_to_power
35 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
36 from mo.middle.passes.shape import reverse_input_channels, fuse_sequence_of_reshapes
37 from mo.pipeline.common import prepare_emit_ir
38 from mo.utils import class_registration
39 from mo.utils.cli_parser import get_meta_info
40 from mo.utils.error import Error
41 from mo.utils.find_inputs import find_inputs
42 from mo.utils.utils import refer_to_faq_msg
45 def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str,
46 output_dir: str, mean_file: str = "",
47 mean_file_offsets: tuple = None, custom_layers_mapping_path: str = None):
48 meta_info = get_meta_info(argv)
50 proto, model = loader.load_caffe_proto_model(proto_file_name, model_file_name)
52 update_extractors_with_extensions(
53 caffe_type_extractors,
54 argv.disable_omitting_optional if hasattr(argv, 'disable_omitting_optional') else False,
55 argv.disable_flattening_optional_params if hasattr(argv, 'disable_flattening_optional_params') else False
59 graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
60 except ValueError as e:
61 raise Error('Invalid prototxt file: value error {}. ' +
62 refer_to_faq_msg(11), str(e)) from e
64 log.debug("After caffe_pb_to_nx")
65 graph.print_graph_stat()
66 graph.check_empty_graph('load_caffe_proto_model')
68 graph.__setattr__('proto_path', proto_file_name)
69 graph.__setattr__('caffemodel_path', model_file_name)
70 graph.__setattr__('name', getattr(proto, 'name', None) or output_model_name)
71 graph.graph['layout'] = 'NCHW'
72 graph.graph['cmd_params'] = argv
73 graph.graph['fw'] = 'caffe'
74 graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
76 custom_layers_map = custom_layers_mapping.load_layers_xml(custom_layers_mapping_path)
77 custom_layers_mapping.update_extractors(
78 caffe_type_extractors,
80 argv.disable_omitting_optional if hasattr(argv, 'disable_omitting_optional') else False,
81 argv.enable_flattening_nested_params if hasattr(argv, 'enable_flattening_nested_params') else False
83 extract_node_attrs(graph, lambda node: caffe_extractor(node, check_for_duplicates(caffe_type_extractors)))
85 # --------------------------------- LOAD END ------------------------------------------------------
86 class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
87 class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
89 # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
90 mark_unfused_nodes(graph, argv.finegrain_fusing)
92 # need this pass even without fusing to convert scale with 2 inputs
93 convert_scale_shift_to_mul_add(graph)
96 if not argv.disable_fusing:
97 convert_bn_to_mul_add(graph)
100 fuse_mul_add_sequence(graph)
101 graph_clean_up(graph)
103 fuse_linear_ops(graph)
104 graph_clean_up(graph)
106 if not argv.disable_resnet_optimization:
107 stride_optimization(graph)
109 convert_muladd_to_scaleshift_or_power(graph)
110 convert_matmul_to_fully_connected(graph)
111 batch_norm_fuse(graph)
112 convert_mul_add_to_power(graph)
113 graph_clean_up(graph)
114 convert_add_or_mul_to_scaleshift(graph) # scale = 1
115 graph_clean_up(graph)
117 log.debug("After graph_cleanup")
118 graph.print_graph_stat()
120 if argv.reverse_input_channels:
121 reverse_input_channels(graph)
123 if argv.move_to_preprocess:
124 move_scaleshift_to_preprocess(graph)
125 graph_clean_up(graph)
127 fuse_sequence_of_reshapes(graph)
129 input_names = find_inputs(graph)
132 if mean_file and len(original_shapes) == 1:
133 mf = loader.parse_mean(mean_file, original_shapes[input_names[0]], mean_file_offsets)
135 raise Error('Mean file for topologies with multiple inputs is not supported. ' +
137 except ValueError as e:
138 raise Error('Cannot load or process mean file: value error {}. ' +
139 refer_to_faq_msg(10), str(e)) from e
141 class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
143 remove_const_ops(graph)
144 CreateConstNodesReplacement().find_and_replace_pattern(graph)
146 remove_output_ops(graph)
148 prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
150 input_names=input_names,