2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 import tensorflow as tf
22 from extensions.back.CreateConstNodes import CreateConstNodesReplacement
23 from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
24 from extensions.middle.ConcatOptimization import ConcatOptimization
27 import tensorflow.contrib
29 pass # we try to import contrib for loading models that use contrib operations
31 from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
32 from mo.middle.passes.eliminate import remove_const_ops
33 from mo.front.common.register_custom_ops import check_for_duplicates
34 from mo.front.common.register_custom_ops import update_extractors_with_extensions
35 from mo.front.extractor import restore_edges, extract_node_attrs, remove_output_ops, remove_control_dependency_inputs
36 from mo.front.tf.extractor import get_tf_edges, tf_op_extractor, tf_op_extractors
37 from mo.front.tf.loader import load_tf_graph_def, protobuf2nx
38 from mo.middle.passes.conv import convert_add_or_mul_to_scaleshift, convert_matmul_to_fully_connected, \
39 convert_muladd_to_scaleshift_or_power, fuse_pad, transpose_fully_connected_weights
40 from mo.middle.passes.eliminate import graph_clean_up_tf
41 from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
42 from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
43 from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
44 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
45 from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
46 from mo.middle.passes.infer import convert_mul_add_to_power, update_fully_connected_shapes
47 from mo.middle.passes.leaky_relu import convert_mul_eltwise_to_leaky_relu
48 from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
49 from mo.middle.passes.shape import convert_squeeze, convert_reshape, reverse_input_channels, \
50 conv_flatten_concat, fuse_sequence_of_reshapes, repack_fully_connected_weights_nhwc_to_nchw, \
51 apply_nhwc_to_nchw_permutation, permute_data_nodes_attrs, permute_op_nodes_attrs, merge_nodes_permutations
52 from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
53 from mo.pipeline.common import prepare_emit_ir
54 from mo.utils import class_registration, tensorboard
55 from mo.utils.cli_parser import get_meta_info
56 from mo.utils.error import Error
57 from mo.utils.utils import refer_to_faq_msg
60 def tf2nx(argv: argparse.Namespace, model_file_name: str, output_model_name: str, output_dir: str,
63 Convert TF GraphDef object to NetworkX representation.
64 The resulting graph is still TF-specific and needs normalization passes to be applied.
65 The specific TF structure assumes each GraphDef node is converted to a single
66 NetworkX node, node id is an original TF node name, and edges go directly from one op to another op.
68 meta_info = get_meta_info(argv)
70 if argv.tensorflow_custom_layer_libraries:
71 libraries = argv.tensorflow_custom_layer_libraries.split(',')
72 for library in libraries:
73 log.info('Loading library "{}" with custom operations'.format(library))
74 tf.load_op_library(library)
76 graph_def, variables_values = load_tf_graph_def(graph_file_name=model_file_name, is_binary=is_binary,
77 checkpoint=argv.input_checkpoint,
78 user_output_node_names_list=argv.output,
79 model_dir=argv.saved_model_dir,
80 meta_graph_file=argv.input_meta_graph,
81 saved_model_tags=argv.saved_model_tags)
84 tf.import_graph_def(graph_def, name='')
86 log.warning("TensorFlow post-processing of loaded model was unsuccessful. "
87 "This is an optional step that Model Optimizer performs for any input model but it is not usually "
88 "required for all models."
89 "It likely means that the original model is ill-formed. "
90 "Model Optimizer will continue converting this model.")
92 log.debug("Number of nodes in graph_def: {}".format(len(graph_def.node))) # pylint: disable=no-member
94 if argv.tensorboard_logdir:
95 tensorboard.dump_for_tensorboard(graph_def, argv.tensorboard_logdir)
97 update_extractors_with_extensions(tf_op_extractors)
100 graph = protobuf2nx(graph_def)
101 graph.__setattr__('name', output_model_name)
102 # 'layout' parameter change may cause an issue in EltwiseInputReshape replacer
103 # and convert_nhwc_to_nchw(graph)
104 graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
105 graph.graph['cmd_params'] = argv
106 graph.graph['fw'] = 'tf'
107 graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 5
109 graph.graph['variables_values'] = variables_values
112 graph = restore_edges(graph, get_tf_edges)
113 graph = remove_control_dependency_inputs(graph)
114 except Exception as e:
116 'Cannot pre-process TensorFlow graph after reading from model file "{}". ' \
117 'File is corrupt or has unsupported format. Details: {}. ' +
118 refer_to_faq_msg(44),
123 graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
124 extract_node_attrs(graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
126 # --------------------------------- LOAD END ------------------------------------------------------
127 class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
128 class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
131 graph_clean_up_tf(graph)
133 convert_matmul_to_fully_connected(graph)
135 # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
136 for_graph_and_each_sub_graph_recursively(graph, lambda graph: mark_unfused_nodes(graph, argv.finegrain_fusing))
138 # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
139 # IE doesn't support BN with 4 inputs, so we have to split it to two ScaleShift
140 convert_batch_norm(graph)
141 graph_clean_up_tf(graph)
143 if not argv.disable_fusing:
144 # Converting ScaleShift layer to Mul->Add
145 for_graph_and_each_sub_graph_recursively(graph, convert_scale_shift_to_mul_add)
146 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
148 # Fusing the sequences of Mul/Add operations
149 for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
150 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
152 # Fusing linear operation to Convolution
153 for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
154 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
156 if not argv.disable_gfusing:
157 grouped_convolutions_fusing(graph)
158 graph_clean_up_tf(graph)
159 if not argv.disable_fusing:
160 fuse_linear_ops(graph)
161 graph_clean_up_tf(graph)
163 # Converting Mul->Add to ScaleShift node
164 for_graph_and_each_sub_graph_recursively(graph, convert_muladd_to_scaleshift_or_power)
165 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
167 for_graph_and_each_sub_graph_recursively(graph, convert_mul_add_to_power)
169 # Need to eliminate dead nodes before doing update_fully_connected_shapes
170 # because update_fully_connected_shapes does partial inference and dead
171 # nodes will lead to sporadic failures.
172 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
173 for_graph_and_each_sub_graph_recursively(graph, update_fully_connected_shapes)
175 for_graph_and_each_sub_graph_recursively(graph, convert_mul_eltwise_to_leaky_relu)
176 graph_clean_up_tf(graph)
177 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
179 for_graph_and_each_sub_graph_recursively(graph, fuse_pad)
180 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
182 for_graph_and_each_sub_graph_recursively(graph, convert_reshape)
183 for_graph_and_each_sub_graph_recursively(graph, convert_squeeze)
185 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
187 for_graph_and_each_sub_graph_recursively(graph, convert_add_or_mul_to_scaleshift) # scale = 1
188 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
190 if argv.reverse_input_channels:
191 reverse_input_channels(graph)
193 if argv.move_to_preprocess:
194 move_scaleshift_to_preprocess(graph)
195 graph_clean_up_tf(graph)
197 fuse_sequence_of_reshapes(graph)
199 pattern = EltwiseInputNormalize()
200 pattern.find_and_replace_pattern(graph)
202 conv_flatten_concat(graph)
204 if argv.enable_concat_optimization:
205 ConcatOptimization().find_and_replace_pattern(graph)
207 LayoutChangeForConstantShapePaths().find_and_replace_pattern(graph)
208 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
210 for_graph_and_each_sub_graph_recursively(graph, apply_nhwc_to_nchw_permutation)
211 for_graph_and_each_sub_graph_recursively(graph, merge_nodes_permutations)
212 for_graph_and_each_sub_graph_recursively(graph, permute_data_nodes_attrs)
213 for_graph_and_each_sub_graph_recursively(graph, permute_op_nodes_attrs)
215 for_graph_and_each_sub_graph_recursively(graph, repack_fully_connected_weights_nhwc_to_nchw)
216 for_graph_and_each_sub_graph_recursively(graph, transpose_fully_connected_weights)
218 for_graph_and_each_sub_graph_recursively(graph, graph_clean_up_tf)
220 class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
222 for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
223 CreateConstNodesReplacement().find_and_replace_pattern(graph)
225 for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)
227 prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,