"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import argparse
import os
import sys

import tensorflow as tf
# Ask TensorFlow's native (C++) runtime to suppress its log messages.
# NOTE(review): this is set after `tensorflow` has already been imported, so
# messages emitted during that import may not be suppressed.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Operation types that are never reported as graph outputs, even when no other
# operation consumes them. summarize_graph also matches these against the last
# '/'-separated component of a node's name.
unlikely_output_types = [
    "Const",
    "Assign",
    "NoOp",
    "Placeholder",
    "Assert",
]
def children(op_name: str, graph: tf.Graph):
    """Return the operations that consume any output tensor of ``op_name``.

    :param op_name: name of an operation that exists in ``graph``.
    :param graph: graph in which the operation is looked up.
    :return: set of consumer operations; empty when none of the operation's
        outputs is consumed.
    """
    producer = graph.get_operation_by_name(op_name)
    return {
        consumer
        for tensor in producer.outputs
        for consumer in tensor.consumers()
    }
def summarize_graph(graph_def):
    """Collect the likely inputs and outputs of a TensorFlow graph.

    The GraphDef is imported into a fresh ``tf.Graph`` and every node of the
    re-exported GraphDef is inspected:

    * ``Placeholder`` nodes are reported as inputs, together with their dtype
      name and a compact shape string in which unknown dimensions are ``-1``.
    * Nodes whose outputs no other operation consumes are reported as outputs,
      unless their op type, or the last ``/``-separated component of their
      name, appears in ``unlikely_output_types``.

    :param graph_def: GraphDef protobuf to summarize.
    :return: dict with key ``'inputs'`` (maps placeholder name to a dict with
        ``'type'`` and ``'shape'`` strings) and key ``'outputs'`` (list of node
        names, in the order they appear in the imported GraphDef).
    """
    placeholders = {}
    outputs = []
    graph = tf.Graph()
    with graph.as_default():  # pylint: disable=not-context-manager
        tf.import_graph_def(graph_def, name='')
    for node in graph.as_graph_def().node:  # pylint: disable=no-member
        if node.op == 'Placeholder':
            shape = str(tf.TensorShape(node.attr['shape'].shape)).replace(' ', '')
            # Unknown dimensions are printed as '?' by TF 1.x and as 'None' by
            # TF 2.x; normalize both spellings to -1.
            shape = shape.replace('?', '-1').replace('None', '-1')
            placeholders[node.name] = {
                'type': tf.DType(node.attr['dtype'].type).name,
                'shape': shape,
            }
        # A node nobody consumes is a candidate output, unless it is an
        # obviously non-output op (constants, assignments, asserts, ...).
        if not children(node.name, graph):
            if node.op not in unlikely_output_types and node.name.split('/')[-1] not in unlikely_output_types:
                outputs.append(node.name)
    return {'inputs': placeholders, 'outputs': outputs}
if __name__ == "__main__":  # pragma: no cover
    # Command-line entry point: load a TensorFlow model (frozen graph file or
    # SavedModel directory) and print its detected inputs and outputs.

    # Make the package two directories up importable so `mo` can be found.
    sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
    from mo.front.tf.loader import load_tf_graph_def

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_model", type=str, help="Path to tensorflow model", default="")
    parser.add_argument('--input_model_is_text', dest='text',
                        help='TensorFlow*: treat the input model file as a text protobuf format. If not specified, '
                             'the Model Optimizer treats it as a binary file by default.', action='store_true',
                        default=False)
    parser.add_argument('--input_meta', action='store_true',
                        help='TensorFlow*: treat the input model file as a meta graph def format', default=False)
    parser.add_argument("--input_checkpoint", type=str, help='TensorFlow variables file to load.', default="")
    parser.add_argument('--saved_model_dir', type=str, default="", help="TensorFlow saved_model_dir")
    parser.add_argument('--saved_model_tags', type=str, default="",
                        help="Group of tag(s) of the MetaGraphDef to load, in string "
                             "format, separated by ','. For tag-set contains multiple tags, all tags must be passed in.")

    argv = parser.parse_args()
    # Exactly one model source must be given; report the actual flag names and
    # stop with a non-zero status instead of continuing with a bad configuration.
    if not argv.input_model and not argv.saved_model_dir:
        print("[ ERROR ] Please, provide --input_model and --input_model_is_text if needed or --saved_model_dir for "
              "saved model directory", file=sys.stderr)
        sys.exit(1)
    if argv.input_model and argv.saved_model_dir:
        print("[ ERROR ] Both keys were provided --input_model and --saved_model_dir. Please, provide only one of them",
              file=sys.stderr)
        sys.exit(1)

    # NOTE(review): --input_meta is parsed but never passed to the loader below;
    # confirm whether load_tf_graph_def accepts a meta-graph argument and wire it through.
    # --saved_model_tags is documented as comma-separated: pass the individual
    # tags (an empty value yields no tags) rather than the raw string.
    saved_model_tags = [tag for tag in argv.saved_model_tags.split(',') if tag]
    graph_def, _ = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text,
                                     checkpoint=argv.input_checkpoint,
                                     model_dir=argv.saved_model_dir, saved_model_tags=saved_model_tags)
    summary = summarize_graph(graph_def)

    print("{} input(s) detected:".format(len(summary['inputs'])))
    for name, info in summary['inputs'].items():
        print("Name: {}, type: {}, shape: {}".format(name, info['type'], info['shape']))
    print("{} output(s) detected:".format(len(summary['outputs'])))
    print(*summary['outputs'], sep="\n")