Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / bcq-tools / generate_bcq_output_arrays
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 #
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
9 #
10 #    http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
17
18 import tensorflow as tf
19
20 import argparse
21 import sys
22
23
24 def _get_parser():
25     """
26     Returns an ArgumentParser for generating output_arrays.
27     """
28     parser = argparse.ArgumentParser(
29         description=("Command line tool to generated output_arrays of BCQ nodes"))
30
31     # Input and output path.
32     parser.add_argument(
33         "-i",
34         "--input_path",
35         type=str,
36         help="Full filepath of the input file.",
37         required=True)
38     parser.add_argument(
39         "-m",
40         "--metadata_path",
41         type=str,
42         help="Full filepath for the file that provides metadata.",
43         required=True)
44     parser.add_argument(
45         "-A",
46         "--output_arrays_path",
47         type=str,
48         help="Full filepath for the file that provides output arrays",
49         required=True)
50
51     return parser
52
53
54 # This function is copied from
55 # https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
56 def load_graph(model_file):
57     graph = tf.Graph()
58     graph_def = tf.compat.v1.GraphDef()
59
60     with open(model_file, "rb") as f:
61         graph_def.ParseFromString(f.read())
62     with graph.as_default():
63         tf.import_graph_def(graph_def, name="")
64
65     return graph
66
67
68 def find_bcq_version(flags):
69     """
70     If BCQ metadata exists, BCQ version is in the second element.
71     Return -1 when the metadata is not found.
72     """
73     graph = load_graph(flags.input_path)
74     graph_def = graph.as_graph_def()
75     for node in graph_def.node:
76         if node.op == "Const" and "one_compiler/bcqinfo_one_metadata" in node.name:
77             metadata_tensor = tf.make_ndarray(node.attr["value"].tensor)
78             return metadata_tensor[1]
79     return -1
80
81
82 def print_bcqinfo_output_arrays_v1(flags):
83     """
84     This function generates a file which includes output arrays of BCQ v1
85     information bundles. Each bundle is consisted with one of candidate
86     operations (BCQ may be applied) and BCQ constant nodes related with
87     the operation.
88     """
89     graph = load_graph(flags.input_path)
90     graph_def = graph.as_graph_def()
91     ops = graph.get_operations()
92
93     # If there is a constant node named PREFIX_1/bcqinfo_alpha,
94     # it is used for applying BCQ to constant node named PREFIX_1.
95     # Collected prefixes will be used for connecting
96     # bcqinfo nodes and user operations of prefix nodes.
97     prefix_set = set()
98     has_dequant_weight = False
99     for op in ops:
100         if op.type == "Const" and "/bcqinfo_" in op.outputs[0].name:
101             # Metadata do not have prefix
102             if "one_compiler/bcqinfo_one_metadata" in op.outputs[0].name:
103                 continue
104
105             prefix_index = op.outputs[0].name.index("/bcqinfo_")
106             prefix = op.outputs[0].name[:prefix_index]
107             prefix_set.add(prefix)
108
109             # Usually, output name of op is like "outputname:0"
110             # -2 is for removing ":0"
111             infoname = op.outputs[0].name[prefix_index + 1:-2]
112             if infoname == "bcqinfo_dequant_weight":
113                 has_dequant_weight = True
114
115     # Write the name of metadata node
116     with open(flags.metadata_path, 'w') as f_metadata:
117         f_metadata.write("one_compiler/bcqinfo_one_metadata,")
118
119     # Write all pairs of a constant node and related BCQ information nodes.
120     with open(flags.output_arrays_path, 'w') as f_arrays:
121         for prefix in prefix_set:
122             f_arrays.write("," + prefix + "/bcqinfo_do_w_x")
123             f_arrays.write("," + prefix + "/bcqinfo_alpha")
124             f_arrays.write("," + prefix + "/bcqinfo_packed_binary_code")
125             f_arrays.write("," + prefix + "/bcqinfo_number_of_clusters")
126             f_arrays.write("," + prefix + "/bcqinfo_size_of_clusters")
127             f_arrays.write("," + prefix + "/bcqinfo_qbits_of_clusters")
128             f_arrays.write("," + prefix)
129             if has_dequant_weight:
130                 f_arrays.write("," + prefix + "/bcqinfo_dequant_weight")
131
132
133 def print_bcq_output_arrays(flags):
134     program_version = 1
135     model_version = find_bcq_version(flags)
136     
137     if model_version == 1:
138         print_bcqinfo_output_arrays_v1(flags)
139     elif model_version == -1:
140         # When BCQ information not found, print nothing.
141         f_metadata = open(flags.metadata_path, 'w')
142         f_arrays = open(flags.output_arrays_path, 'w')
143         f_metadata.close()
144         f_arrays.close()
145     else:
146         err_msg = "BCQ version of the model(v{}) ".format(model_version)
147         err_msg += "is higher than "
148         err_msg += "the version supported by this program(v{})".format(program_version)
149         raise SystemExit(err_msg)
150
151
152 def main():
153     # Parse argument.
154     parser = _get_parser()
155     flags = parser.parse_known_args(args=sys.argv[1:])
156
157     print_bcq_output_arrays(flags[0])
158
159
160 if __name__ == "__main__":
161     main()