0cc13188075f2b1695bffac8731bbed06d13e43c
[platform/core/ml/nnfw.git] / compiler / bcq-tools / generate_bcq_output_arrays.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 #
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
9 #
10 #    http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
17
18 import tensorflow as tf
19
20 import argparse
21 import sys
22
23
24 # This function is copied from
25 # https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
26 def load_graph(model_file):
27     graph = tf.Graph()
28     graph_def = tf.compat.v1.GraphDef()
29
30     with open(model_file, "rb") as f:
31         graph_def.ParseFromString(f.read())
32     with graph.as_default():
33         tf.import_graph_def(graph_def, name="")
34
35     return graph
36
37
38 def get_bcq_version(input_path):
39     """
40     If BCQ metadata exists, BCQ version is in the second element.
41     Return -1 when the metadata is not found.
42     """
43     graph = load_graph(input_path)
44     graph_def = graph.as_graph_def()
45     for node in graph_def.node:
46         if node.op == "Const" and "one_compiler/bcqinfo_one_metadata" in node.name:
47             metadata_tensor = tf.make_ndarray(node.attr["value"].tensor)
48             return metadata_tensor[1]
49     return -1
50
51
52 def get_bcqinfo_output_arrays_v1(input_path, output_arrays):
53     """
54     This function generates a file which includes output arrays of BCQ v1
55     information bundles. Each bundle is consisted with one of candidate
56     operations (BCQ may be applied) and BCQ constant nodes related with
57     the operation.
58     """
59     graph = load_graph(input_path)
60     ops = graph.get_operations()
61
62     # If there is a constant node named PREFIX_1/bcqinfo_alpha,
63     # it is used for applying BCQ to constant node named PREFIX_1.
64     # Collected prefixes will be used for connecting
65     # bcqinfo nodes and user operations of prefix nodes.
66     prefix_set = set()
67     has_dequant_weight = False
68     for op in ops:
69         if op.type == "Const" and "/bcqinfo_" in op.outputs[0].name:
70             # Metadata do not have prefix
71             if "one_compiler/bcqinfo_one_metadata" in op.outputs[0].name:
72                 continue
73
74             prefix_index = op.outputs[0].name.index("/bcqinfo_")
75             prefix = op.outputs[0].name[:prefix_index]
76             prefix_set.add(prefix)
77
78             # Usually, output name of op is like "outputname:0"
79             # -2 is for removing ":0"
80             infoname = op.outputs[0].name[prefix_index + 1:-2]
81             if infoname == "bcqinfo_dequant_weight":
82                 has_dequant_weight = True
83
84     # Ideal situation is that the user nodes of BCQ applicable constant nodes
85     # are BCQ applicable operations such as MatMul, GatherV2, etc.
86     # However, operations which do not change original values such as
87     # Ideneity or Transpose can exist between them. In view of TensorFlow Lite,
88     # real user nodes of BCQ applicable constant nodes must be found first.
89     # This work is done by BFS search with queue.
90
91     prefix_node_dict = {}  # key : prefix / value : list of candidates
92     matmul_node_prefix_dict = {}  # key : Name of MatMul node / value : prefix
93
94     queue_prefix = list(prefix_set)
95     queue_nodename = [queue_prefix[idx] + ":0" for idx in range(len(queue_prefix))]
96
97     while len(queue_prefix) > 0:
98         prefix = queue_prefix.pop(0)
99         nodename = queue_nodename.pop(0)
100         if prefix not in prefix_node_dict.keys():
101             prefix_node_dict[prefix] = []
102
103         # Usually, output name of op is like "outputname:0"
104         # -2 is for removing ":0"
105         for op in ops:
106             if op.type == "MatMul" and (op.inputs[0].name == nodename
107                                         or op.inputs[1].name == nodename):
108                 prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
109                 matmul_node_prefix_dict[op.outputs[0].name[:-2]] = prefix
110             elif op.type == "Einsum" and (op.inputs[0].name == nodename
111                                           or op.inputs[1].name == nodename):
112                 prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
113             elif op.type == "GatherV2" and op.inputs[0].name == nodename:
114                 prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
115             elif len(op.outputs) == 1:
116                 for i in range(len(op.inputs)):
117                     if op.inputs[i].name == nodename:
118                         queue_prefix.append(prefix)
119                         queue_nodename.append(op.outputs[0].name)
120                         break
121
122     # When TensorFlow model is converted to TensorFlow Lite model,
123     # more than one operation can be fused as one.
124     # For example, MatMul + BiasAdd + ReLU in TensorFlow can be fused as
125     # one FullyConnected in TensorFlow Lite.
126     # It means that even real user nodes of BCQ applicable constant nodes
127     # in TensorFlow are found, they may be real user nodes in TensorFlow Lite.
128     # Therefore additional candidates of real user nodes should be found either.
129     # Finding additional candidates is done by BFS search with queue.
130
131     fuseop_prefix_dict = {}  # key : Candidate operation / Value : prefix
132
133     # These ops can be candidate. However other candidates may exists after these ops.
134     mark_type = ["Add", "AddV2", "BiasAdd", "Reshape", "Transpose"]
135
136     # These ops can be candidate. And no more candidates will be found after these ops.
137     mark_and_stop_type = ["Relu", "Relu6", "Tanh"]
138
139     # These ops cannot be candidates but other candidates may exists after these ops.
140     # NOTE : Some of following ops may be removed from the list but not sure for now.
141     pass_type = [
142         "BatchToSpaceND", "Cast", "DepthToSpace", "ExpandDims", "ResizeBilinear",
143         "ResizeNearestNeighbor", "ScatterNd", "SpaceToBatchND", "SpaceToDepth", "Squeeze",
144         "Identity", "Pack", "Unpack", "Stack"
145     ]
146
147     queue_prefix = list(matmul_node_prefix_dict.values())
148     queue_nodename = [matmul + ":0" for matmul in matmul_node_prefix_dict.keys()]
149
150     visited_nodes = set(queue_nodename)
151     while len(queue_prefix) > 0:
152         prefix = queue_prefix.pop(0)
153         nodename = queue_nodename.pop(0)
154
155         # Usually, output name of op is like "outputname:0"
156         # -2 is for removing ":0"
157         for op in ops:
158             for i in range(len(op.inputs)):
159                 if nodename == op.inputs[i].name:
160                     if op.type in mark_type:
161                         if op.outputs[0].name[:-2] not in fuseop_prefix_dict.keys():
162                             fuseop_prefix_dict[op.outputs[0].name[:-2]] = set()
163                         fuseop_prefix_dict[op.outputs[0].name[:-2]].add(prefix)
164                         if op.outputs[0].name not in visited_nodes:
165                             queue_prefix.append(prefix)
166                             queue_nodename.append(op.outputs[0].name)
167                             visited_nodes.add(op.outputs[0].name)
168                     elif op.type in mark_and_stop_type:
169                         if op.outputs[0].name[:-2] not in fuseop_prefix_dict.keys():
170                             fuseop_prefix_dict[op.outputs[0].name[:-2]] = set()
171                         fuseop_prefix_dict[op.outputs[0].name[:-2]].add(prefix)
172                     elif op.type in pass_type and op.outputs[0].name not in visited_nodes:
173                         queue_prefix.append(prefix)
174                         queue_nodename.append(op.outputs[0].name)
175                         visited_nodes.add(op.outputs[0].name)
176
177     # the name of metadata node
178     ret_output_arrays = ['one_compiler/bcqinfo_one_metadata']
179
180     # given node from user
181     ret_output_arrays.append(output_arrays)
182
183     # all pairs of candidate operations and related BCQ information nodes
184     for prefix in prefix_set:
185         for fusable_op in prefix_node_dict[prefix]:
186             ret_output_arrays.append(prefix + '/bcqinfo_do_w_x')
187             ret_output_arrays.append(prefix + '/bcqinfo_alpha')
188             ret_output_arrays.append(prefix + '/bcqinfo_packed_binary_code')
189             ret_output_arrays.append(prefix + '/bcqinfo_number_of_clusters')
190             ret_output_arrays.append(prefix + '/bcqinfo_size_of_clusters')
191             ret_output_arrays.append(prefix + '/bcqinfo_qbits_of_clusters')
192             ret_output_arrays.append(fusable_op)
193             if has_dequant_weight:
194                 ret_output_arrays.append(prefix + '/bcqinfo_dequant_weight')
195     for fuseop in fuseop_prefix_dict.keys():
196         if len(fuseop_prefix_dict[fuseop]) == 1:
197             prefix = fuseop_prefix_dict[fuseop].pop()
198             ret_output_arrays.append(prefix + '/bcqinfo_do_w_x')
199             ret_output_arrays.append(prefix + '/bcqinfo_alpha')
200             ret_output_arrays.append(prefix + '/bcqinfo_packed_binary_code')
201             ret_output_arrays.append(prefix + '/bcqinfo_number_of_clusters')
202             ret_output_arrays.append(prefix + '/bcqinfo_size_of_clusters')
203             ret_output_arrays.append(prefix + '/bcqinfo_qbits_of_clusters')
204             ret_output_arrays.append(fuseop)
205             if has_dequant_weight:
206                 ret_output_arrays.append(prefix + '/bcqinfo_dequant_weight')
207
208     return ret_output_arrays
209
210
211 def get_bcq_output_arrays(input_path, output_arrays):
212     """Returns BCQ output arrays that the model from input_path has"""
213     program_version = 1
214     model_version = get_bcq_version(input_path)
215
216     if model_version == 1:
217         return get_bcqinfo_output_arrays_v1(input_path, output_arrays)
218     elif model_version == -1:
219         return None
220     else:
221         err_msg = "BCQ version of the model(v{}) ".format(model_version)
222         err_msg += "is higher than "
223         err_msg += "the version supported by this program(v{})".format(program_version)
224         raise SystemExit(err_msg)