Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / loader.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 import os
19 import re
20
21 from mo.utils.error import Error, FrameworkError
22 from mo.utils.utils import refer_to_faq_msg
23
24 try:
25     import tensorflow as tf
26 except ImportError:
27     raise Error('Module tensorflow was not found. Please install tensorflow 1.2 or higher. ' +
28                 refer_to_faq_msg(42))
29
30 from google.protobuf import text_format
31 from mo.graph.graph import create_graph_with_nodes, Graph
32 from mo.utils.summarize_graph import summarize_graph
33
34
35 def freeze_checkpoints(graph_def: tf.GraphDef, checkpoint_dir: str, output_node_names: list):
36     """
37     Loads all the variables in a graph and stores them in a separate dictionary. Freezes output nodes in the graph
38     :param graph_def: GraphDef object holding the network.
39     :param checkpoint_dir: path to directory with checkpoint files with values of graph variables.
40     :param output_node_names: list of output node names.
41     :return: GraphDef containing a simplified version of the original.
42     """
43     log.debug("Loading checkpoint files from directory: {}".format(checkpoint_dir))
44     checkpoint_files = []
45     for checkpoint_name in sorted(os.listdir(checkpoint_dir)):
46         checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
47         if os.path.isfile(checkpoint_path):
48             checkpoint_files.append(checkpoint_path)
49             log.debug("File {} will be loaded".format(checkpoint_path))
50         else:
51             log.debug("Path {} is not a file. Skipping")
52
53     if len(checkpoint_files) == 0:
54         raise Error("There are no checkpoint files in directory: {}".format(checkpoint_dir))
55
56     tf.import_graph_def(graph_def, name='')
57
58     with tf.Session() as sess:
59         uninitialized_variables = [str(v, 'utf-8') for v in set(sess.run(tf.report_uninitialized_variables()))]
60         all_variables = [n.name for n in sess.graph.as_graph_def().node if n.op in ['Variable', 'VariableV2']]
61         white_list = [v for v in all_variables if v not in uninitialized_variables]
62         black_list = [v for v in all_variables if v in uninitialized_variables]
63         output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names,
64                                                                         variable_names_whitelist=white_list,
65                                                                         variable_names_blacklist=black_list)
66     variable_values = {}
67     for checkpoint_file in checkpoint_files:
68         log.debug("Loading {}".format(checkpoint_file))
69         with tf.Session() as sess:
70             var_list = {}
71             var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint_file).get_variable_to_shape_map()
72             for key in var_to_shape_map:
73                 try:
74                     tensor = sess.graph.get_operation_by_name(key).outputs[0]
75                 except KeyError:
76                     continue
77                 var_list[key] = tensor
78             tf.train.Saver(var_list=var_list).restore(sess, checkpoint_file)
79             for name, tensor in var_list.items():
80                 variable_values[name] = sess.run(tensor)
81     return output_graph_def, variable_values
82
83
84 def freeze_checkpoint(graph_def, checkpoint, output_node_names):
85     """
86     Replaces all the variables in a graph with constants of the same values.
87     :param graph_def: GraphDef object holding the network.
88     :param checkpoint: path to checkpoint file with values of variables.
89     :param output_node_names: list of output node names
90     :return: GraphDef containing a simplified version of the original.
91     """
92     tf.import_graph_def(graph_def, name="")
93
94     with tf.Session() as sess:
95         var_list = {}
96         var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint).get_variable_to_shape_map()
97         for key in var_to_shape_map:
98             try:
99                 tensor = sess.graph.get_operation_by_name(key).outputs[0]
100             except KeyError:
101                 continue
102             var_list[key] = tensor
103         tf.train.Saver(var_list=var_list).restore(sess, checkpoint)
104         output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names)
105     return output_graph_def
106
107
108 def read_file_to_graph_def(graph_def: [tf.GraphDef, tf.MetaGraphDef], graph_file_name: str = "",
109                            is_binary: bool = True):
110     """
111     Reads file to protobuf
112     :param graph_def: GraphDef orr MetaGraphDef object to store the network
113     :param graph_file_name: path to file with graph
114     :param is_binary: flag to switch between binary and test protobuf format of graph file
115     :return: GraphDef or MetaGaphDef containing the network with cleared device info.
116     """
117     try:
118         if is_binary:
119             with open(graph_file_name, "rb") as f:
120                 graph_def.ParseFromString(f.read())
121         else:
122             with open(graph_file_name, "r") as f:
123                 text_format.Merge(f.read(), graph_def)
124         nodes_to_clear_device = graph_def.node if isinstance(graph_def, tf.GraphDef) else graph_def.graph_def.node
125         for node in nodes_to_clear_device:
126             node.device = ""
127     except Exception as e:
128         raise FrameworkError(
129             'TensorFlow cannot read the model file: "{}" is incorrect TensorFlow model file. '
130             '\nThe file should contain one of the following TensorFlow graphs:'
131             '\n1. frozen graph in text or binary format'
132             '\n2. inference graph for freezing with checkpoint (--input_checkpoint) in text or binary format'
133             '\n3. meta graph'
134             '\n\nMake sure that --input_model_is_text is provided for a model in text format. '
135             'By default, a model is interpreted in binary format. Framework error details: {}. ' +
136             refer_to_faq_msg(43),
137             graph_file_name,
138             str(e)
139         ) from e
140     return graph_def
141
142
143 def get_output_node_names_list(graph_def, user_defined_output_node_names_list: list):
144     return summarize_graph(graph_def)['outputs'] \
145         if user_defined_output_node_names_list is None or len(user_defined_output_node_names_list) == 0 \
146         else user_defined_output_node_names_list
147
148
149 def deducing_metagraph_path(meta_graph_file: str):
150     match = re.search('^(.*)\.(data-\d*-of-\d*|index|meta)$', meta_graph_file)
151     if match is not None:
152         deduced_meta_graph_file = match.group(1) + '.meta'
153         if not os.path.isfile(deduced_meta_graph_file):
154             raise Error('\n\nMetaGraph freezing mechanism was enabled. '
155                         '\n{} file does not represent MetaGraph. '
156                         '\n{} path to MetaGraph was deduced, but it does not exist'
157                         '\n\nModel with MetaGraph consists of 3-4 files:'
158                         '\n1. model_name.meta'
159                         '\n2. model_name.index'
160                         '\n3. model_name.data-00000-of-00001 (digit part may vary)'
161                         '\n4. checkpoint (optional)'.format(meta_graph_file, deduced_meta_graph_file))
162         else:
163             meta_graph_file = deduced_meta_graph_file
164     else:
165         raise Error('\n\nMetaGraph freezing mechanism was enabled. '
166                     '\n{} file does not represent MetaGraph. '
167                     '\n\nModel with MetaGraph consists of 3-4 files:'
168                     '\n1. model_name.meta'
169                     '\n2. model_name.index'
170                     '\n3. model_name.data-00000-of-00001 (digit part may vary)'
171                     '\n4. checkpoint (optional)'
172                     '\n\nTo load this model, simply run:'
173                     '\npython3 mo_tf.py --input_meta_graph model_name.meta'
174                     ''.format(meta_graph_file))
175     return meta_graph_file
176
177
178 def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpoint: str = "",
179                       model_dir: str = "", saved_model_tags: list = [], meta_graph_file: str = "",
180                       user_output_node_names_list: list = []):
181     # As a provisional solution, use a native TF methods to load a model protobuf
182     graph_def = tf.GraphDef()
183     if isinstance(graph_file_name, str) and (re.match('.*\.(ckpt|meta)$', graph_file_name)):
184         print('[ WARNING ] The value for the --input_model command line parameter ends with ".ckpt" or ".meta" '
185               'extension.\n'
186               'It means that the model is not frozen.\n'
187               'To load non frozen model to Model Optimizer run:'
188               '\n\n1. For "*.ckpt" file:'
189               '\n- if inference graph is in binary format'
190               '\npython3 mo_tf.py --input_model "path/to/inference_graph.pb" --input_checkpoint "path/to/*.ckpt"'
191               '\n- if inference graph is in text format'
192               '\npython3 mo_tf.py --input_model "path/to/inference_graph.pbtxt" --input_model_is_text '
193               '--input_checkpoint "path/to/*.ckpt"'
194               '\n\n2. For "*.meta" file:'
195               '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
196     variables_values = {}
197     try:
198         if graph_file_name and not meta_graph_file and not checkpoint:
199             # frozen graph
200             return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values
201         if graph_file_name and not meta_graph_file and checkpoint:
202             # inference graph and checkpoint
203             graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
204             outputs = get_output_node_names_list(graph_def, user_output_node_names_list)
205             if os.path.isfile(checkpoint):
206                 graph_def = freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs)
207             elif os.path.isdir(checkpoint):
208                 graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint,
209                                                                  output_node_names=outputs)
210             # we are sure that checkpoint is existing file or directory due to cli_parser configuration
211             return graph_def, variables_values
212         if not graph_file_name and meta_graph_file:
213             meta_graph_file = deducing_metagraph_path(meta_graph_file)
214             input_meta_graph_def = read_file_to_graph_def(tf.MetaGraphDef(), meta_graph_file, is_binary)
215             # pylint: disable=no-member
216             with tf.Session() as sess:
217                 restorer = tf.train.import_meta_graph(input_meta_graph_def)
218                 restorer.restore(sess, re.sub('\.meta$', '', meta_graph_file))
219                 outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list)
220                 graph_def = tf.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def, outputs)
221                 return graph_def, variables_values
222         if model_dir:
223             # saved model directory
224             tags = saved_model_tags if saved_model_tags is not None else [tf.saved_model.tag_constants.SERVING]
225             with tf.Session() as sess:
226                 meta_graph_def = tf.saved_model.loader.load(sess, tags, model_dir)
227                 outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list)
228                 graph_def = tf.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
229                 return graph_def, variables_values
230     except Exception as e:
231         raise FrameworkError('Cannot load input model: {}', e) from e
232     raise Error("Unknown configuration of input model parameters")
233
234
235 def protobuf_attrs(pb: tf.NodeDef):
236     return {'pb': pb}
237
238
239 def protobuf2nx(pb: tf.GraphDef):
240     graph = create_graph_with_nodes(pb.node, get_id=lambda pb: pb.name, get_attrs=protobuf_attrs)
241     # initial order of nodes in the GraphDef. It is used to specify order in
242     # which merged nodes are added to the generated sub-graph GraphDef for the TensorFlow offload feature.
243     graph.graph['initial_nodes_order'] = [node.name for node in pb.node]
244
245     # Remove data dependency edges. This is needed for the TF offload case
246     for _, attrs in list(graph.nodes(data=True)):
247         pb = attrs['pb']
248         if '_class' in pb.attr:
249             index = 0
250             while index < len(pb.attr['_class'].list.s):
251                 if re.match('^loc:@.*', pb.attr['_class'].list.s[index].decode('utf-8')):
252                     del pb.attr['_class'].list.s[index]
253                 else:
254                     index = index + 1
255
256     return graph
257
258
259 def variables_to_constants(graph: Graph, variables_values: dict):
260     """
261     Converts `Variable<V2>` operations to FakeConst operations with `value` from `variables_values` dictionary
262     :param graph: graph to operate on
263     :param variables_values: dictionary with variable names as keys and np.array data as values
264     """
265     for node in graph.get_op_nodes(op='FakeConst'):
266         node_name = node.name
267
268         if node_name not in variables_values:
269             log.debug("There is no value for '{}': {} in checkpoint variable values".format(node.op, node_name))
270             continue
271
272         node['value'] = variables_values[node_name]