platform/upstream/dldt.git: model-optimizer/mo/pipeline/tf.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """

import argparse
import logging as log

import tensorflow as tf

from extensions.back.CreateConstNodes import CreateConstNodesReplacement
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.ConcatOptimization import ConcatOptimization

try:
    import tensorflow.contrib
except:
    pass  # we try to import contrib for loading models that use contrib operations

from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
from mo.middle.passes.eliminate import remove_const_ops
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.extractor import restore_edges, extract_node_attrs, remove_output_ops, remove_control_dependency_inputs
from mo.front.tf.extractor import get_tf_edges, tf_op_extractor, tf_op_extractors
from mo.front.tf.loader import load_tf_graph_def, protobuf2nx
from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift, convert_matmul_to_fully_connected, \
    convert_muladd_to_scaleshift_or_power, fuse_pad, transpose_fully_connected_weights
from mo.middle.passes.eliminate import graph_clean_up_tf
from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
from mo.middle.passes.infer import convert_mul_add_to_power, update_fully_connected_shapes
from mo.middle.passes.leaky_relu import convert_mul_eltwise_to_leaky_relu
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
from mo.middle.passes.shape import convert_squeeze, convert_reshape, reverse_input_channels, \
    conv_flatten_concat, fuse_sequence_of_reshapes, repack_fully_connected_weights_nhwc_to_nchw, \
    apply_nhwc_to_nchw_permutation, permute_data_nodes_attrs, permute_op_nodes_attrs, merge_nodes_permutations
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
from mo.pipeline.common import prepare_emit_ir
from mo.utils import class_registration, tensorboard
from mo.utils.cli_parser import get_meta_info
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg


def tf2nx(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str,
          is_binary: bool):
    """
    Convert TF GraphDef object to NetworkX representation.
    The resulting graph is still TF-specific and needs normalization passes to be applied.
    In the resulting TF-specific structure, each GraphDef node is converted to a single
    NetworkX node, the node id is the original TF node name, and edges go directly from one op to another op.
    """
    meta_info = get_meta_info(argv)

    if argv.tensorflow_custom_layer_libraries:
        libraries = argv.tensorflow_custom_layer_libraries.split(',')
        for library in libraries:
            log.info('Loading library "{}" with custom operations'.format(library))
            tf.load_op_library(library)

    graph_def, variables_values = load_tf_graph_def(graph_file_name=model_file_name, is_binary=is_binary,
                                                    checkpoint=argv.input_checkpoint,
                                                    user_output_node_names_list=argv.output,
                                                    model_dir=argv.saved_model_dir,
                                                    meta_graph_file=argv.input_meta_graph,
                                                    saved_model_tags=argv.saved_model_tags)

    try:
        tf.import_graph_def(graph_def, name='')
    except:
        log.warning("TensorFlow post-processing of loaded model was unsuccessful. "
                    "This is an optional step that Model Optimizer performs for any input model, and it is not "
                    "required for most models. "
                    "It likely means that the original model is ill-formed. "
                    "Model Optimizer will continue converting this model.")

    log.debug("Number of nodes in graph_def: {}".format(len(graph_def.node)))  # pylint: disable=no-member

    if argv.tensorboard_logdir:
        tensorboard.dump_for_tensorboard(graph_def, argv.tensorboard_logdir)

    update_extractors_with_extensions(tf_op_extractors)

    try:
        graph = protobuf2nx(graph_def)
        graph.__setattr__('name', output_model_name)
        # 'layout' parameter change may cause an issue in EltwiseInputReshape replacer
        # and convert_nhwc_to_nchw(graph)
        graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
        graph.graph['cmd_params'] = argv
        graph.graph['fw'] = 'tf'
        graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5

        graph.graph['variables_values'] = variables_values
        del variables_values

        graph = restore_edges(graph, get_tf_edges)
        graph = remove_control_dependency_inputs(graph)
    except Exception as e:
        raise Error(
            'Cannot pre-process TensorFlow graph after reading from model file "{}". '
            'File is corrupt or has unsupported format. Details: {}. ' +
            refer_to_faq_msg(44),
            model_file_name,
            str(e)
        ) from e

    graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
    extract_node_attrs(graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))

    # --------------------------------- LOAD END ------------------------------------------------------
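    # Run the registered front-phase transformations (on the framework-specific graph) followed by the
    # middle-phase transformations (on the partially inferred, framework-agnostic graph)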
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)
    graph_clean_up_tf(graph)

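    # Convert MatMul operations into FullyConnected layers where the weights allow it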
    convert_matmul_to_fully_connected(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    # IE doesn't support BN with 4 inputs, so we have to split it into two ScaleShift operations
    convert_batch_norm(graph)
    graph_clean_up_tf(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        for_graph_and_each_sub_graph_recursively(graph, convert_scale_shift_to_mul_add)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

        # Fusing sequences of Mul/Add operations
        for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

        # Fusing linear operations into Convolution
        for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
        for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    if not argv.disable_gfusing:
        grouped_convolutions_fusing(graph)
        graph_clean_up_tf(graph)
        if not argv.disable_fusing:
            fuse_linear_ops(graph)
            graph_clean_up_tf(graph)

    # Converting Mul->Add to ScaleShift node
    for_graph_and_each_sub_graph_recursively(graph, convert_muladd_to_scaleshift_or_power)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

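    # Convert remaining Mul/Add operations with scalar constant operands into Power layers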
    for_graph_and_each_sub_graph_recursively(graph, convert_mul_add_to_power)

    # Need to eliminate dead nodes before doing update_fully_connected_shapes
    # because update_fully_connected_shapes does partial inference and dead
    # nodes will lead to sporadic failures.
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
    for_graph_and_each_sub_graph_recursively(graph, update_fully_connected_shapes)

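    # Fuse the Mul -> Maximum pattern emitted by TF into a single LeakyReLU operation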
    for_graph_and_each_sub_graph_recursively(graph, convert_mul_eltwise_to_leaky_relu)
    graph_clean_up_tf(graph)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, fuse_pad)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

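    # Canonicalize Reshape and Squeeze operations before layout conversion and IR emission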
    for_graph_and_each_sub_graph_recursively(graph, convert_reshape)
    for_graph_and_each_sub_graph_recursively(graph, convert_squeeze)

    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

    for_graph_and_each_sub_graph_recursively(graph, convert_add_or_mul_to_scaleshift)  # scale = 1
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

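    # Reverse the input channel order (e.g. RGB -> BGR) when --reverse_input_channels is specified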
    if argv.reverse_input_channels:
        reverse_input_channels(graph)

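    # Move a ScaleShift applied directly to the input into the IR pre-processing section (mean/scale values)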
    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up_tf(graph)

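    # Collapse chains of consecutive Reshape operations into a single Reshape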
    fuse_sequence_of_reshapes(graph)

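    # Normalize inputs of eltwise operations (align input ranks so broadcasting is explicit)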
    pattern = EltwiseInputNormalize()
    pattern.find_and_replace_pattern(graph)

    conv_flatten_concat(graph)

    if argv.enable_concat_optimization:
        ConcatOptimization().find_and_replace_pattern(graph)

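    # Keep sub-graphs that compute shape values in the original layout during the NHWC -> NCHW conversion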
    LayoutChangeForConstantShapePaths().find_and_replace_pattern(graph)
    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

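    # Convert the graph layout from NHWC to NCHW: compute per-node permutations, merge them,
    # apply them to data and op node attributes, and adjust FullyConnected weights accordingly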
    for_graph_and_each_sub_graph_recursively(graph, apply_nhwc_to_nchw_permutation)
    for_graph_and_each_sub_graph_recursively(graph, merge_nodes_permutations)
    for_graph_and_each_sub_graph_recursively(graph, permute_data_nodes_attrs)
    for_graph_and_each_sub_graph_recursively(graph, permute_op_nodes_attrs)

    for_graph_and_each_sub_graph_recursively(graph, repack_fully_connected_weights_nhwc_to_nchw)
    for_graph_and_each_sub_graph_recursively(graph, transpose_fully_connected_weights)

    for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)

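    # Run the registered back-phase transformations that prepare the graph for IR emission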
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

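    # Replace Const operations with constant data nodes expected by the IR emitter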
    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

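    # Serialize the prepared graph into IR: an .xml topology file and a .bin weights file in the requested data type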
    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)

    return 0