Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / summarize_graph.py
1 #!/usr/bin/env python3
2
3 """
4  Copyright (c) 2018-2019 Intel Corporation
5
6  Licensed under the Apache License, Version 2.0 (the "License");
7  you may not use this file except in compliance with the License.
8  You may obtain a copy of the License at
9
10       http://www.apache.org/licenses/LICENSE-2.0
11
12  Unless required by applicable law or agreed to in writing, software
13  distributed under the License is distributed on an "AS IS" BASIS,
14  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  See the License for the specific language governing permissions and
16  limitations under the License.
17 """
18
19 import argparse
20 import os
21 import sys
22
23 import tensorflow as tf
24
25 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
26 unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert']
27
28
29 def children(op_name: str, graph: tf.Graph):
30     op = graph.get_operation_by_name(op_name)
31     return set(op for out in op.outputs for op in out.consumers())
32
33
34 def summarize_graph(graph_def):
35     placeholders = dict()
36     outputs = list()
37     graph = tf.Graph()
38     with graph.as_default():  # pylint: disable=not-context-manager
39         tf.import_graph_def(graph_def, name='')
40     for node in graph.as_graph_def().node:  # pylint: disable=no-member
41         if node.op == 'Placeholder':
42             node_dict = dict()
43             node_dict['type'] = tf.DType(node.attr['dtype'].type).name
44             node_dict['shape'] = str(tf.TensorShape(node.attr['shape'].shape)).replace(' ', '').replace('?', '-1')
45             placeholders[node.name] = node_dict
46         if len(children(node.name, graph)) == 0:
47             if node.op not in unlikely_output_types and node.name.split('/')[-1] not in unlikely_output_types:
48                 outputs.append(node.name)
49     result = dict()
50     result['inputs'] = placeholders
51     result['outputs'] = outputs
52     return result
53
54
55 if __name__ == "__main__":  # pragma: no cover
56     sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
57     from mo.front.tf.loader import load_tf_graph_def
58
59     parser = argparse.ArgumentParser()
60     parser.add_argument("--input_model", type=str, help="Path to tensorflow model", default="")
61     parser.add_argument('--input_model_is_text', dest='text',
62                         help='TensorFlow*: treat the input model file as a text protobuf format. If not specified, '
63                              'the Model Optimizer treats it as a binary file by default.', action='store_true',
64                         default=False)
65     parser.add_argument('--input_meta', action='store_true',
66                         help='TensorFlow*: treat the input model file as a meta graph def format', default=False)
67     parser.add_argument("--input_checkpoint", type=str, help='TensorFlow variables file to load.', default="")
68     parser.add_argument('--saved_model_dir', type=str, default="", help="TensorFlow saved_model_dir")
69     parser.add_argument('--saved_model_tags', type=str, default="",
70                         help="Group of tag(s) of the MetaGraphDef to load, in string \
71                           format, separated by ','. For tag-set contains multiple tags, all tags must be passed in.")
72
73     argv = parser.parse_args()
74     if not argv.input_model and not argv.saved_model_dir:
75         print("[ ERROR ] Please, provide --input_model and --input_model_is_text if needed or --input_dir for saved "
76               "model directory")
77         sys.exit(1)
78     if argv.input_model and argv.saved_model_dir:
79         print("[ ERROR ] Both keys were provided --input_model and --input_dir. Please, provide only one of them")
80         sys.exit(1)
81     graph_def, _ = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text,
82                                      checkpoint=argv.input_checkpoint,
83                                      model_dir=argv.saved_model_dir, saved_model_tags=argv.saved_model_tags)
84     summary = summarize_graph(graph_def)
85     print("{} input(s) detected:".format(len(summary['inputs'])))
86     for input in summary['inputs']:
87         print("Name: {}, type: {}, shape: {}".format(input, summary['inputs'][input]['type'],
88                                                      summary['inputs'][input]['shape']))
89     print("{} output(s) detected:".format(len(summary['outputs'])))
90     print(*summary['outputs'], sep="\n")