3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
18 import tensorflow as tf
24 # This function is copied from
25 # https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
def load_graph(model_file):
    """Load a serialized GraphDef file into a new ``tf.Graph``.

    Args:
        model_file: Path of a binary (protobuf-serialized) GraphDef file.

    Returns:
        A ``tf.Graph`` into which every node of the file has been imported
        with an empty name prefix (``name=""``), so node names are unchanged.
    """
    # Import into a dedicated graph rather than the process-wide default one.
    graph = tf.Graph()
    graph_def = tf.compat.v1.GraphDef()

    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")

    return graph
def get_bcq_version(input_path):
    """Return the BCQ version recorded in the model's metadata node.

    The metadata is looked up as a ``Const`` node whose name contains
    ``one_compiler/bcqinfo_one_metadata``. If it exists, the BCQ version is
    the second element of its tensor value.

    Args:
        input_path: Path of the GraphDef file, as accepted by ``load_graph``.

    Returns:
        The second element of the metadata tensor, or -1 when the metadata
        node is not found.
    """
    graph = load_graph(input_path)
    graph_def = graph.as_graph_def()
    for node in graph_def.node:
        if node.op == "Const" and "one_compiler/bcqinfo_one_metadata" in node.name:
            metadata_tensor = tf.make_ndarray(node.attr["value"].tensor)
            return metadata_tensor[1]
    # No metadata node: report the sentinel that get_bcq_output_arrays checks.
    return -1
def _reads_tensor(op, tensor_name, first_n=None):
    """Return True if one of ``op``'s first ``first_n`` inputs (all when None) is ``tensor_name``."""
    inputs = list(op.inputs)
    if first_n is not None:
        # Slicing (instead of indexing) keeps ops with fewer inputs from raising.
        inputs = inputs[:first_n]
    return any(tensor.name == tensor_name for tensor in inputs)


def _bcqinfo_bundle(prefix, user_op, has_dequant_weight):
    """Return the output-array names of one BCQ v1 bundle, in emission order."""
    bundle = [
        prefix + '/bcqinfo_do_w_x',
        prefix + '/bcqinfo_alpha',
        prefix + '/bcqinfo_packed_binary_code',
        prefix + '/bcqinfo_number_of_clusters',
        prefix + '/bcqinfo_size_of_clusters',
        prefix + '/bcqinfo_qbits_of_clusters',
        user_op,
    ]
    if has_dequant_weight:
        bundle.append(prefix + '/bcqinfo_dequant_weight')
    return bundle


def get_bcqinfo_output_arrays_v1(input_path, output_arrays):
    """Collect the output arrays of the BCQ v1 information bundles of a model.

    Each bundle consists of one candidate operation (an operation to which
    BCQ may be applied) and the BCQ constant nodes related to that operation.

    Args:
        input_path: Path of the GraphDef file, as accepted by ``load_graph``.
        output_arrays: Output arrays given by the user. The value is appended
            to the result as a single element, unmodified.

    Returns:
        A list that starts with the metadata node name, followed by
        ``output_arrays``, followed by the flattened bundles.
    """
    from collections import deque

    graph = load_graph(input_path)
    ops = graph.get_operations()

    # If there is a constant node named PREFIX_1/bcqinfo_alpha,
    # it is used for applying BCQ to the constant node named PREFIX_1.
    # Collected prefixes are used for connecting
    # bcqinfo nodes and user operations of prefix nodes.
    prefix_set = set()
    has_dequant_weight = False
    for op in ops:
        if op.type != "Const":
            continue
        tensor_name = op.outputs[0].name
        if "/bcqinfo_" not in tensor_name:
            continue
        # Metadata does not have a prefix
        if "one_compiler/bcqinfo_one_metadata" in tensor_name:
            continue

        prefix_index = tensor_name.index("/bcqinfo_")
        prefix_set.add(tensor_name[:prefix_index])

        # Usually, output name of op is like "outputname:0"
        # -2 is for removing ":0"
        if tensor_name[prefix_index + 1:-2] == "bcqinfo_dequant_weight":
            has_dequant_weight = True

    # Ideally the user nodes of BCQ applicable constant nodes are BCQ
    # applicable operations such as MatMul, GatherV2, etc.
    # However, operations which do not change the original values, such as
    # Identity or Transpose, can exist between them. In view of TensorFlow Lite,
    # the real user nodes of BCQ applicable constant nodes must be found first.
    # This is done by a breadth-first search over (prefix, tensor name) pairs.
    # NOTE(review): this search keeps no visited set, so a tensor reachable
    # through several paths yields repeated candidates - confirm this is intended.

    # key : prefix / value : list of candidates
    prefix_node_dict = {prefix: [] for prefix in prefix_set}
    matmul_node_prefix_dict = {}  # key : Name of MatMul node / value : prefix

    queue = deque((prefix, prefix + ":0") for prefix in prefix_set)
    while queue:
        prefix, nodename = queue.popleft()

        for op in ops:
            # Usually, output name of op is like "outputname:0"
            # -2 is for removing ":0"
            if op.type == "MatMul" and _reads_tensor(op, nodename, 2):
                matmul_name = op.outputs[0].name[:-2]
                prefix_node_dict[prefix].append(matmul_name)
                matmul_node_prefix_dict[matmul_name] = prefix
            elif op.type == "Einsum" and _reads_tensor(op, nodename, 2):
                # Only the first two operands are considered; a single-operand
                # Einsum is handled without indexing a missing second input.
                prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
            elif op.type == "GatherV2" and _reads_tensor(op, nodename, 1):
                prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
            elif len(op.outputs) == 1:
                # Any other single-output consumer is followed through.
                for tensor in op.inputs:
                    if tensor.name == nodename:
                        queue.append((prefix, op.outputs[0].name))

    # When a TensorFlow model is converted to a TensorFlow Lite model,
    # more than one operation can be fused into one.
    # For example, MatMul + BiasAdd + ReLU in TensorFlow can be fused into
    # one FullyConnected in TensorFlow Lite.
    # So the operations following each found MatMul are collected as
    # additional candidates as well, again by breadth-first search.

    fuseop_prefix_dict = {}  # key : Candidate operation / value : set of prefixes

    # These ops can be candidates. However, other candidates may exist after these ops.
    mark_type = ["Add", "AddV2", "BiasAdd", "Reshape", "Transpose"]

    # These ops can be candidates. And no more candidates will be found after these ops.
    mark_and_stop_type = ["Relu", "Relu6", "Tanh"]

    # These ops cannot be candidates but other candidates may exist after these ops.
    # NOTE : Some of following ops may be removed from the list but not sure for now.
    pass_type = [
        "BatchToSpaceND", "Cast", "DepthToSpace", "ExpandDims", "ResizeBilinear",
        "ResizeNearestNeighbor", "ScatterNd", "SpaceToBatchND", "SpaceToDepth", "Squeeze",
        "Identity", "Pack", "Unpack", "Stack"
    ]

    queue = deque(
        (prefix, matmul + ":0") for matmul, prefix in matmul_node_prefix_dict.items())
    visited_nodes = {nodename for _, nodename in queue}
    while queue:
        prefix, nodename = queue.popleft()

        for op in ops:
            for tensor in op.inputs:
                if tensor.name != nodename:
                    continue
                # op.outputs[0] is read only inside the branches below so that
                # consumers of other types are never indexed.
                # Usually, output name of op is like "outputname:0"
                # -2 is for removing ":0"
                if op.type in mark_type:
                    out_name = op.outputs[0].name
                    fuseop_prefix_dict.setdefault(out_name[:-2], set()).add(prefix)
                    if out_name not in visited_nodes:
                        queue.append((prefix, out_name))
                        visited_nodes.add(out_name)
                elif op.type in mark_and_stop_type:
                    out_name = op.outputs[0].name
                    fuseop_prefix_dict.setdefault(out_name[:-2], set()).add(prefix)
                elif op.type in pass_type and op.outputs[0].name not in visited_nodes:
                    out_name = op.outputs[0].name
                    queue.append((prefix, out_name))
                    visited_nodes.add(out_name)

    # the name of metadata node
    ret_output_arrays = ['one_compiler/bcqinfo_one_metadata']

    # given node from user
    ret_output_arrays.append(output_arrays)

    # all pairs of candidate operations and related BCQ information nodes
    for prefix in prefix_set:
        for fusable_op in prefix_node_dict[prefix]:
            ret_output_arrays.extend(
                _bcqinfo_bundle(prefix, fusable_op, has_dequant_weight))
    for fuseop, prefixes in fuseop_prefix_dict.items():
        # A fused candidate is emitted only when exactly one prefix reaches it.
        if len(prefixes) == 1:
            prefix = next(iter(prefixes))
            ret_output_arrays.extend(
                _bcqinfo_bundle(prefix, fuseop, has_dequant_weight))

    return ret_output_arrays
def get_bcq_output_arrays(input_path, output_arrays):
    """Returns BCQ output arrays that the model from input_path has.

    Args:
        input_path: Path of the GraphDef file, as accepted by ``load_graph``.
        output_arrays: Output arrays given by the user; forwarded unchanged
            to the version-specific collector.

    Returns:
        The list built by ``get_bcqinfo_output_arrays_v1`` for a BCQ v1
        model, or None when the model carries no BCQ metadata.

    Raises:
        SystemExit: If the model reports any other BCQ version.
    """
    # Highest BCQ version this program can handle (only a v1 collector exists).
    program_version = 1
    model_version = get_bcq_version(input_path)

    if model_version == 1:
        return get_bcqinfo_output_arrays_v1(input_path, output_arrays)
    if model_version == -1:
        # get_bcq_version found no metadata node: there is nothing to add.
        # NOTE(review): None is assumed as the "no BCQ" result - confirm with callers.
        return None

    err_msg = "BCQ version of the model(v{}) ".format(model_version)
    err_msg += "is higher than "
    err_msg += "the version supported by this program(v{})".format(program_version)
    raise SystemExit(err_msg)