Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / bcq-tools / generate_bcq_output_arrays
1 #!/usr/bin/env python3
2
3 import tensorflow as tf
4
5 import argparse
6 import sys
7
8
9 def _get_parser():
10     """
11     Returns an ArgumentParser for generating output_arrays.
12     """
13     parser = argparse.ArgumentParser(
14         description=("Command line tool to generated output_arrays of BCQ nodes"))
15
16     # Input and output path.
17     parser.add_argument(
18         "-i",
19         "--input_path",
20         type=str,
21         help="Full filepath of the input file.",
22         required=True)
23     parser.add_argument(
24         "-o",
25         "--output_path",
26         type=str,
27         help="Full filepath of the output file.",
28         required=True)
29
30     return parser
31
32
def load_graph(frozen_graph_filename):
    """
    Read a frozen GraphDef from disk and import it into a brand-new graph.

    The file is read in binary mode and parsed as a serialized GraphDef
    protobuf. Its nodes are imported with ``name=''``, so they keep their
    original names rather than gaining a name-scope prefix.

    Args:
        frozen_graph_filename: path of the frozen .pb file.

    Returns:
        tf.Graph: the graph containing every imported node.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph
43
44
def dtype2str(dtype):
    """
    Map a dtype name to the corresponding ``TF_*`` type-name string.

    Supported inputs and results:
        "int32"   -> "TF_INT32"
        "int64"   -> "TF_INT64"
        "float32" -> "TF_FLOAT"
        "bool"    -> "TF_BOOL"

    Args:
        dtype: the dtype to translate. It is compared with ``==`` against
            each name in the order above.

    Returns:
        str: the ``TF_*`` name for ``dtype``.

    Raises:
        ValueError: if ``dtype`` matches none of the supported names.
            ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    # An ordered scan using ``==`` (not a dict lookup) keeps the original
    # matching semantics for objects that define their own equality with
    # strings. NOTE(review): possibly relied on for tf.DType inputs - confirm.
    supported = (
        ("int32", "TF_INT32"),
        ("int64", "TF_INT64"),
        ("float32", "TF_FLOAT"),
        ("bool", "TF_BOOL"),
    )
    for name, tf_name in supported:
        if dtype == name:
            return tf_name
    raise ValueError("Not supported dtype: {!r}".format(dtype))
56
57
def print_output_arrays(flags):
    """
    Write the node names of all BCQ information constants to a text file.

    A node is selected when it is a ``Const`` op whose first output tensor
    name contains ``"bcqinfo_"``. Each tensor name is reduced to its node
    name by dropping everything from the first ``":"`` onwards, for example
    ``"bcqinfo_a:0"`` -> ``"bcqinfo_a"``.

    The names are written as one string in which every name is preceded by a
    comma, e.g. ``",bcqinfo_a,bcqinfo_b"``. The string is empty when nothing
    matches. NOTE(review): the leading comma presumably lets a caller append
    this text directly to a user-supplied output_arrays list - confirm with
    the consumer before changing the format.

    Args:
        flags: parsed arguments providing ``input_path`` (frozen .pb graph to
            read) and ``output_path`` (text file that is created/overwritten).
    """
    graph_model = load_graph(flags.input_path)

    node_names = [
        op.outputs[0].name.partition(":")[0]
        for op in graph_model.get_operations()
        if op.type == "Const" and "bcqinfo_" in op.outputs[0].name
    ]

    # Every name gets a leading separator (see docstring); join avoids
    # quadratic string concatenation.
    output_arrays = "".join("," + name for name in node_names)

    # The context manager closes the file even if the write fails.
    with open(flags.output_path, 'w') as f:
        f.write(output_arrays)
79
80
def main():
    """Script entry point: parse the command line, then write the BCQ output arrays."""
    # Options the parser does not recognise are tolerated and ignored.
    known_flags, _ignored = _get_parser().parse_known_args(sys.argv[1:])
    print_output_arrays(known_flags)
87
88
if __name__ == '__main__':
    # Run the tool only when executed as a script, not when imported.
    main()