"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log
import os
import re

from mo.utils.error import Error, FrameworkError
from mo.utils.utils import refer_to_faq_msg

try:
    import tensorflow as tf
except ImportError:
    # TensorFlow is an optional external dependency of the Model Optimizer;
    # fail with an actionable message instead of a raw ImportError.
    raise Error('Module tensorflow was not found. Please install tensorflow 1.2 or higher. ' +
                refer_to_faq_msg(42))

from google.protobuf import text_format
from mo.graph.graph import create_graph_with_nodes, Graph
from mo.utils.summarize_graph import summarize_graph
def freeze_checkpoints(graph_def: tf.GraphDef, checkpoint_dir: str, output_node_names: list):
    """
    Loads all the variables in a graph and stores them in a separate dictionary. Freezes output nodes in the graph.
    :param graph_def: GraphDef object holding the network.
    :param checkpoint_dir: path to directory with checkpoint files with values of graph variables.
    :param output_node_names: list of output node names.
    :return: tuple of (GraphDef containing a simplified version of the original,
             dict mapping variable names to their values read from the checkpoints).
    :raises Error: if the directory contains no checkpoint files.
    """
    log.debug("Loading checkpoint files from directory: {}".format(checkpoint_dir))
    checkpoint_files = []
    for checkpoint_name in sorted(os.listdir(checkpoint_dir)):
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        if os.path.isfile(checkpoint_path):
            checkpoint_files.append(checkpoint_path)
            log.debug("File {} will be loaded".format(checkpoint_path))
        else:
            # fix: the original log call had a '{}' placeholder but no .format() argument
            log.debug("Path {} is not a file. Skipping".format(checkpoint_path))

    if len(checkpoint_files) == 0:
        raise Error("There are no checkpoint files in directory: {}".format(checkpoint_dir))

    tf.import_graph_def(graph_def, name='')

    with tf.Session() as sess:
        # Variables reported as uninitialized have no value in the default session and
        # cannot be folded into constants; they are black-listed below.
        uninitialized_variables = [str(v, 'utf-8') for v in set(sess.run(tf.report_uninitialized_variables()))]
        all_variables = [n.name for n in sess.graph.as_graph_def().node if n.op in ['Variable', 'VariableV2']]
        white_list = [v for v in all_variables if v not in uninitialized_variables]
        black_list = [v for v in all_variables if v in uninitialized_variables]
        output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names,
                                                                        variable_names_whitelist=white_list,
                                                                        variable_names_blacklist=black_list)
    variable_values = {}
    for checkpoint_file in checkpoint_files:
        log.debug("Loading {}".format(checkpoint_file))
        with tf.Session() as sess:
            var_list = {}
            var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint_file).get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_operation_by_name(key).outputs[0]
                except KeyError:
                    # the checkpoint may contain variables that are absent from the graph — skip them
                    continue
                var_list[key] = tensor
            tf.train.Saver(var_list=var_list).restore(sess, checkpoint_file)
            for name, tensor in var_list.items():
                variable_values[name] = sess.run(tensor)
    return output_graph_def, variable_values
def freeze_checkpoint(graph_def, checkpoint, output_node_names):
    """
    Replaces all the variables in a graph with constants of the same values.
    :param graph_def: GraphDef object holding the network.
    :param checkpoint: path to checkpoint file with values of variables.
    :param output_node_names: list of output node names.
    :return: GraphDef containing a simplified version of the original.
    """
    tf.import_graph_def(graph_def, name="")

    with tf.Session() as sess:
        var_list = {}
        var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint).get_variable_to_shape_map()
        for key in var_to_shape_map:
            try:
                tensor = sess.graph.get_operation_by_name(key).outputs[0]
            except KeyError:
                # the checkpoint may contain variables that are absent from the graph — skip them
                continue
            var_list[key] = tensor
        tf.train.Saver(var_list=var_list).restore(sess, checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names)
    return output_graph_def
def read_file_to_graph_def(graph_def: [tf.GraphDef, tf.MetaGraphDef], graph_file_name: str = "",
                           is_binary: bool = True):
    """
    Reads file to protobuf.
    :param graph_def: GraphDef or MetaGraphDef object to store the network.
    :param graph_file_name: path to file with graph.
    :param is_binary: flag to switch between binary and text protobuf format of graph file.
    :return: GraphDef or MetaGraphDef containing the network with cleared device info.
    :raises FrameworkError: if the file cannot be parsed as a TensorFlow graph.
    """
    try:
        if is_binary:
            with open(graph_file_name, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            with open(graph_file_name, "r") as f:
                text_format.Merge(f.read(), graph_def)
        # Clear device placement so the model can be converted on a machine
        # without the devices it was trained on.
        nodes_to_clear_device = graph_def.node if isinstance(graph_def, tf.GraphDef) else graph_def.graph_def.node
        for node in nodes_to_clear_device:
            node.device = ""
    except Exception as e:
        raise FrameworkError(
            'TensorFlow cannot read the model file: "{}" is incorrect TensorFlow model file. '
            '\nThe file should contain one of the following TensorFlow graphs:'
            '\n1. frozen graph in text or binary format'
            '\n2. inference graph for freezing with checkpoint (--input_checkpoint) in text or binary format'
            '\n3. meta graph'
            '\n\nMake sure that --input_model_is_text is provided for a model in text format. '
            'By default, a model is interpreted in binary format. Framework error details: {}. ' +
            refer_to_faq_msg(43),
            graph_file_name,
            str(e)
        ) from e
    return graph_def
def get_output_node_names_list(graph_def, user_defined_output_node_names_list: list):
    """
    Returns the list of output node names: the user-defined list when it is non-empty,
    otherwise the outputs deduced from the graph by summarize_graph.
    :param graph_def: GraphDef object holding the network.
    :param user_defined_output_node_names_list: output node names provided by the user (may be None or empty).
    :return: list of output node names.
    """
    return summarize_graph(graph_def)['outputs'] \
        if user_defined_output_node_names_list is None or len(user_defined_output_node_names_list) == 0 \
        else user_defined_output_node_names_list
def deducing_metagraph_path(meta_graph_file: str):
    """
    Deduces the path to the '.meta' file of a MetaGraph model from any of its companion
    files ('.index', '.data-XXXXX-of-XXXXX' or '.meta' itself).
    :param meta_graph_file: path supplied by the user for --input_meta_graph.
    :return: path to the existing '.meta' file.
    :raises Error: if the path does not look like a MetaGraph file or the deduced '.meta' file does not exist.
    """
    # raw string avoids the invalid-escape-sequence deprecation warning
    match = re.search(r'^(.*)\.(data-\d*-of-\d*|index|meta)$', meta_graph_file)
    if match is not None:
        deduced_meta_graph_file = match.group(1) + '.meta'
        if not os.path.isfile(deduced_meta_graph_file):
            raise Error('\n\nMetaGraph freezing mechanism was enabled. '
                        '\n{} file does not represent MetaGraph. '
                        '\n{} path to MetaGraph was deduced, but it does not exist'
                        '\n\nModel with MetaGraph consists of 3-4 files:'
                        '\n1. model_name.meta'
                        '\n2. model_name.index'
                        '\n3. model_name.data-00000-of-00001 (digit part may vary)'
                        '\n4. checkpoint (optional)'.format(meta_graph_file, deduced_meta_graph_file))
        else:
            meta_graph_file = deduced_meta_graph_file
    else:
        raise Error('\n\nMetaGraph freezing mechanism was enabled. '
                    '\n{} file does not represent MetaGraph. '
                    '\n\nModel with MetaGraph consists of 3-4 files:'
                    '\n1. model_name.meta'
                    '\n2. model_name.index'
                    '\n3. model_name.data-00000-of-00001 (digit part may vary)'
                    '\n4. checkpoint (optional)'
                    '\n\nTo load this model, simply run:'
                    '\npython3 mo_tf.py --input_meta_graph model_name.meta'
                    ''.format(meta_graph_file))
    return meta_graph_file
def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpoint: str = "",
                      model_dir: str = "", saved_model_tags: list = [], meta_graph_file: str = "",
                      user_output_node_names_list: list = []):
    """
    Loads a TensorFlow model in one of the supported configurations and returns a frozen GraphDef.
    :param graph_file_name: path to a frozen graph or an inference graph (binary or text protobuf).
    :param is_binary: flag to switch between binary and text protobuf format of graph file.
    :param checkpoint: path to a checkpoint file or a directory with checkpoint files.
    :param model_dir: path to a SavedModel directory.
    :param saved_model_tags: tags of the MetaGraph to load from a SavedModel.
    :param meta_graph_file: path to a '.meta' file of a MetaGraph model.
    :param user_output_node_names_list: output node names provided by the user.
    :return: tuple of (frozen GraphDef, dict of variable values — non-empty only for checkpoint directories).
    :raises FrameworkError: if TensorFlow fails to load the model.
    :raises Error: if the combination of parameters does not match any supported configuration.
    """
    # As a provisional solution, use a native TF methods to load a model protobuf
    graph_def = tf.GraphDef()
    if isinstance(graph_file_name, str) and (re.match(r'.*\.(ckpt|meta)$', graph_file_name)):
        print('[ WARNING ] The value for the --input_model command line parameter ends with ".ckpt" or ".meta" '
              'extension.\n'
              'It means that the model is not frozen.\n'
              'To load non frozen model to Model Optimizer run:'
              '\n\n1. For "*.ckpt" file:'
              '\n- if inference graph is in binary format'
              '\npython3 mo_tf.py --input_model "path/to/inference_graph.pb" --input_checkpoint "path/to/*.ckpt"'
              '\n- if inference graph is in text format'
              '\npython3 mo_tf.py --input_model "path/to/inference_graph.pbtxt" --input_model_is_text '
              '--input_checkpoint "path/to/*.ckpt"'
              '\n\n2. For "*.meta" file:'
              '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
    variables_values = {}
    try:
        if graph_file_name and not meta_graph_file and not checkpoint:
            # frozen graph
            return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values
        if graph_file_name and not meta_graph_file and checkpoint:
            # inference graph and checkpoint
            graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
            outputs = get_output_node_names_list(graph_def, user_output_node_names_list)
            if os.path.isfile(checkpoint):
                graph_def = freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs)
            elif os.path.isdir(checkpoint):
                graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint,
                                                                 output_node_names=outputs)
            # we are sure that checkpoint is existing file or directory due to cli_parser configuration
            return graph_def, variables_values
        if not graph_file_name and meta_graph_file:
            meta_graph_file = deducing_metagraph_path(meta_graph_file)
            input_meta_graph_def = read_file_to_graph_def(tf.MetaGraphDef(), meta_graph_file, is_binary)
            # pylint: disable=no-member
            with tf.Session() as sess:
                restorer = tf.train.import_meta_graph(input_meta_graph_def)
                restorer.restore(sess, re.sub(r'\.meta$', '', meta_graph_file))
                outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def,
                                                                         outputs)
                return graph_def, variables_values
        if model_dir:
            # saved model directory
            tags = saved_model_tags if saved_model_tags is not None else [tf.saved_model.tag_constants.SERVING]
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(sess, tags, model_dir)
                outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
    except Exception as e:
        raise FrameworkError('Cannot load input model: {}', e) from e
    raise Error("Unknown configuration of input model parameters")
def protobuf_attrs(pb: tf.NodeDef):
    """
    Returns the initial attribute dictionary for a graph node: the raw NodeDef protobuf
    is stored under the 'pb' key (used as get_attrs callback by protobuf2nx).
    :param pb: NodeDef protobuf message for one node.
    :return: dict with the protobuf stored under the 'pb' key.
    """
    return {'pb': pb}
def protobuf2nx(pb: tf.GraphDef):
    """
    Converts a GraphDef protobuf to a Model Optimizer Graph, keeping the raw NodeDef of
    each node in its 'pb' attribute.
    :param pb: GraphDef protobuf message holding the network.
    :return: Graph object with one node per NodeDef.
    """
    graph = create_graph_with_nodes(pb.node, get_id=lambda pb: pb.name, get_attrs=protobuf_attrs)
    # initial order of nodes in the GraphDef. It is used to specify order in
    # which merged nodes are added to the generated sub-graph GraphDef for the TensorFlow offload feature.
    graph.graph['initial_nodes_order'] = [node.name for node in pb.node]

    # Remove data dependency edges. This is needed for the TF offload case
    for _, attrs in list(graph.nodes(data=True)):
        pb = attrs['pb']
        if '_class' in pb.attr:
            index = 0
            while index < len(pb.attr['_class'].list.s):
                if re.match('^loc:@.*', pb.attr['_class'].list.s[index].decode('utf-8')):
                    # drop 'loc:@<node>' colocation constraints; other '_class' entries are kept
                    del pb.attr['_class'].list.s[index]
                else:
                    index = index + 1

    return graph
def variables_to_constants(graph: Graph, variables_values: dict):
    """
    Converts `Variable<V2>` operations to FakeConst operations with `value` from `variables_values` dictionary.
    :param graph: graph to operate on.
    :param variables_values: dictionary with variable names as keys and np.array data as values.
    """
    for node in graph.get_op_nodes(op='FakeConst'):
        node_name = node.name

        if node_name not in variables_values:
            # leave the node without a value; downstream passes will report it if required
            log.debug("There is no value for '{}': {} in checkpoint variable values".format(node.op, node_name))
            continue

        node['value'] = variables_values[node_name]