Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / bcq-tools / preserve_bcq_info
1 #!/usr/bin/env python3
2
3 import tensorflow as tf
4 import numpy as np
5
6 import argparse
7 import sys
8
9
10 def _get_parser():
11     """
12     Returns an ArgumentParser for preserving BCQ information.
13     """
14     parser = argparse.ArgumentParser(
15         description=("Command line tool to preserve BCQ information"))
16
17     # Input and output path.
18     parser.add_argument(
19         "-i",
20         "--input_path",
21         type=str,
22         help="Full filepath of the input file.",
23         required=True)
24     parser.add_argument(
25         "-o",
26         "--output_path",
27         type=str,
28         help="Full filepath of the output file.",
29         required=True)
30
31     return parser
32
33
def load_graph(frozen_graph_filename):
    """
    Read a frozen pb file and return a tf.Graph containing its nodes.

    The serialized GraphDef is parsed from the file and imported into a
    fresh graph with an empty name prefix, so node names are kept as-is.
    """
    parsed_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        parsed_def.ParseFromString(pb_file.read())

    imported_graph = tf.Graph()
    with imported_graph.as_default():
        tf.import_graph_def(parsed_def, name='')
    return imported_graph
44
45
def preserve_bcq_info(flags):
    """
    Tag the BCQ information constants of a frozen graph with unique dummy values.

    The frozen graph at ``flags.input_path`` is loaded. Every ``Const`` node
    whose name contains one of the BCQ information markers is replaced by an
    INT32 ``Const`` node holding the original value(s) followed by one unique
    dummy value. All other nodes are copied unchanged. The resulting GraphDef
    is written in binary form to ``flags.output_path``.

    Dummy values are generated from -1 to -N.
    We use negative values to preserve BCQ information because
    positive values may cause some confusion with real BCQ information values.

    Args:
        flags: Parsed command line arguments providing ``input_path`` and
            ``output_path`` attributes.
    """

    class UniqueValueGen:
        """Hands out -1, -2, -3, ... on successive gen() calls."""

        def __init__(self):
            self.unique_value = -1

        def gen(self):
            val = self.unique_value
            self.unique_value = val - 1
            return val

    unique_value = UniqueValueGen()

    original_graph_model = load_graph(flags.input_path)
    original_graph_model_def = original_graph_model.as_graph_def()

    new_graph = tf.compat.v1.GraphDef()
    substitution_dict = {}

    # Loop-invariant: name fragments of the constants that get a dummy
    # value appended to their existing contents.
    preserved_bcqinfo_list = (
        "/bcqinfo_number_of_clusters",
        "/bcqinfo_size_of_clusters",
        "/bcqinfo_qbits_of_clusters",
    )

    for node in original_graph_model_def.node:
        if node.op != "Const":
            continue

        # Because bcqinfo_do_w_x is BOOL type, we cannot add dummy value at the end.
        # Therefore we should convert the type to INT32 type.
        # Only the first element is kept; this assumes the tensor has at
        # least one dimension - TODO confirm for scalar constants.
        if "/bcqinfo_do_w_x" in node.name:
            original_tensor = tf.make_ndarray(node.attr["value"].tensor)
            substitution_dict[node.name] = tf.make_tensor_proto(
                [int(original_tensor[0]), unique_value.gen()], tf.int32)

        if any(name in node.name for name in preserved_bcqinfo_list):
            original_tensor = tf.make_ndarray(node.attr["value"].tensor)
            substitution_dict[node.name] = tf.make_tensor_proto(
                np.append(original_tensor, unique_value.gen()), tf.int32)

    # Every substituted tensor is built as tf.int32 above, so the dtype
    # attribute is set explicitly. Borrowing it from a previously seen node
    # fails (CopyFrom(None)) when the graph contains only bcqinfo_do_w_x
    # constants, and could disagree with the tensor if that node was not INT32.
    int32_dtype_enum = tf.int32.as_datatype_enum

    for node in original_graph_model_def.node:
        new_node = new_graph.node.add()
        if node.name in substitution_dict:
            new_node.op = "Const"
            new_node.name = node.name
            new_node.attr["dtype"].type = int32_dtype_enum
            new_node.attr["value"].tensor.CopyFrom(substitution_dict[node.name])
        else:
            new_node.CopyFrom(node)

    # as_text=False: write the GraphDef as a binary protobuf.
    tf.io.write_graph(new_graph, '.', flags.output_path, False)
104
105
def main():
    """
    Entry point: read the command line and run the BCQ preserving pass.

    Unrecognized command line arguments are tolerated and ignored.
    """
    # Parse argument.
    flags, _unrecognized = _get_parser().parse_known_args(args=sys.argv[1:])

    # Generate a new pb file, which BCQ information is preserved.
    preserve_bcq_info(flags)


if __name__ == "__main__":
    main()