2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseBCQPass.h"
19 #include <luci/IR/CircleNodes.h>
29 * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied
30 * are connected with their name. And their names include common prefix.
31 * However, after pb file is converted to tflite file, some nodes' name are changed.
32 * Thus this function will return original common prefix.
34 * @note All the re-naming rule of TFLite converter is not figured out.
35 * Therefore, if new naming rule is detected, this function should be updated.
37 const std::string node_name_prefix(luci::NodeName node_name)
39 std::string prefix = node_name;
41 if (prefix.find("ReadVariableOp/resource/") != std::string::npos)
43 const auto start_index = prefix.find("ReadVariableOp/resource/");
45 const auto left_prefix = prefix.substr(0, start_index);
46 const auto right_prefix = prefix.substr(start_index + 24);
48 prefix = left_prefix + right_prefix;
51 if (prefix.find("Tensordot/") != std::string::npos)
53 const auto index = prefix.find("Tensordot/");
54 prefix = prefix.substr(0, index - 1);
56 else if (prefix.find("kernel/") != std::string::npos)
58 const auto index = prefix.find("kernel/");
59 prefix = prefix.substr(0, index - 1);
61 else if (prefix.find("/bcqinfo_") != std::string::npos)
63 const auto index = prefix.find("/bcqinfo_");
64 prefix = prefix.substr(0, index);
75 class BCQConverter final
78 void add_BCQ_info_node(luci::CircleConst *node)
80 const auto node_name = node->name();
81 const auto prefix = node_name_prefix(node_name);
83 // If bcqinfo_* nodes are held by Reshape operation,
84 // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation.
85 // Then the name becomes bcqinfo_*_copy_shape.
86 // We should prevent this node not to added to bcq information.
87 if (node_name.find("_copy_shape") != std::string::npos)
90 if (node_name.find("bcqinfo_do_w_x") != std::string::npos)
91 _do_w_x[prefix] = node;
92 else if (node_name.find("bcqinfo_alpha") != std::string::npos)
93 _alpha[prefix] = node;
94 else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos)
95 _packed_binary_code[prefix] = node;
96 else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos)
97 _number_of_clusters[prefix] = node;
98 else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos)
99 _size_of_clusters[prefix] = node;
100 else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos)
101 _qbits_of_clusters[prefix] = node;
102 else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos)
103 _dequant_weight[prefix] = node;
106 bool has_BCQ_info(luci::CircleConst *node)
108 const auto prefix = node_name_prefix(node->name());
109 bool has_info = true;
111 has_info &= (_do_w_x.find(prefix) != _do_w_x.end());
112 has_info &= (_alpha.find(prefix) != _alpha.end());
113 has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end());
114 has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end());
115 has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end());
116 has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end());
117 // bcqinfo_dequant_weight is just for validation, so not always exists.
122 bool do_w_x(luci::CircleConst *node)
124 const auto prefix = node_name_prefix(node->name());
126 if (_do_w_x[prefix]->dtype() == loco::DataType::S32)
127 return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1;
128 else if (_do_w_x[prefix]->dtype() == loco::DataType::BOOL)
129 return _do_w_x[prefix]->at<loco::DataType::BOOL>(0);
131 throw std::runtime_error("do_w_x should be int or bool");
134 luci::CircleConst *get_alpha(luci::CircleConst *node)
136 const auto prefix = node_name_prefix(node->name());
137 return _alpha[prefix];
140 luci::CircleConst *get_packed_binary_code(luci::CircleConst *node)
142 const auto prefix = node_name_prefix(node->name());
143 return _packed_binary_code[prefix];
146 luci::CircleConst *get_number_of_clusters(luci::CircleConst *node)
148 const auto prefix = node_name_prefix(node->name());
149 return _number_of_clusters[prefix];
152 luci::CircleConst *get_size_of_clusters(luci::CircleConst *node)
154 const auto prefix = node_name_prefix(node->name());
155 return _size_of_clusters[prefix];
158 luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node)
160 const auto prefix = node_name_prefix(node->name());
161 return _qbits_of_clusters[prefix];
164 luci::CircleConst *packed_clusters(luci::CircleConst *node)
166 auto graph = node->graph();
167 auto qbits_of_clusters = get_qbits_of_clusters(node);
168 auto size_of_clusters = get_size_of_clusters(node);
169 const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0);
171 auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
172 packed_clusters->dtype(loco::DataType::S32);
173 packed_clusters->size<loco::DataType::S32>(number_of_clusters * 2);
174 packed_clusters->rank(2);
175 packed_clusters->dim(0) = number_of_clusters;
176 packed_clusters->dim(1) = 2;
177 packed_clusters->shape_status(luci::ShapeStatus::VALID);
179 for (int i = 0; i < number_of_clusters; ++i)
181 packed_clusters->at<loco::DataType::S32>(i * 2) =
182 qbits_of_clusters->at<loco::DataType::S32>(i);
183 packed_clusters->at<loco::DataType::S32>(i * 2 + 1) =
184 size_of_clusters->at<loco::DataType::S32>(i);
187 return packed_clusters;
191 * @brief Exclude BCQ information nodes which are used for fusing BCQ operations
192 * from graph output by using CircleOutputExclude
194 void clear_BCQ_nodes()
196 auto createNoOp = [](luci::CircleNode *circle_node) {
197 auto graph = circle_node->graph();
198 auto noOp = graph->nodes()->create<luci::CircleOutputExclude>();
200 if (circle_node->shape_status() == luci::ShapeStatus::VALID)
202 noOp->dtype(circle_node->dtype());
203 noOp->rank(circle_node->rank());
204 for (uint32_t i = 0; i < circle_node->rank(); ++i)
205 noOp->dim(i) = circle_node->dim(i);
209 // For type inference
210 noOp->dtype(loco::DataType::FLOAT32);
216 auto clear_nodes = [createNoOp](std::map<std::string, luci::CircleConst *> &nodes) {
217 for (auto &n : nodes)
219 auto node = n.second;
221 for (auto s : loco::succs(node))
223 if (auto outnode = dynamic_cast<luci::CircleOutput *>(s))
225 outnode->from(createNoOp(node));
227 else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s))
229 for (auto o : loco::succs(reshape_node))
231 auto circle_output = loco::must_cast<luci::CircleOutput *>(o);
232 circle_output->from(createNoOp(reshape_node));
239 clear_nodes(_do_w_x);
241 clear_nodes(_packed_binary_code);
242 clear_nodes(_number_of_clusters);
243 clear_nodes(_size_of_clusters);
244 clear_nodes(_qbits_of_clusters);
245 clear_nodes(_dequant_weight);
249 std::map<std::string, luci::CircleConst *> _do_w_x;
250 std::map<std::string, luci::CircleConst *> _alpha;
251 std::map<std::string, luci::CircleConst *> _packed_binary_code;
252 std::map<std::string, luci::CircleConst *> _number_of_clusters;
253 std::map<std::string, luci::CircleConst *> _size_of_clusters;
254 std::map<std::string, luci::CircleConst *> _qbits_of_clusters;
255 std::map<std::string, luci::CircleConst *> _dequant_weight;
// Replaces Gather and FullyConnected nodes of graph `g` whose constant operand has
// BCQ information with CircleBCQGather / CircleBCQFullyConnected nodes.
// NOTE(review): this view of the function is incomplete. Statements it depends on are
// not visible here (`batch_size` is multiplied without a visible declaration, `reshape`
// is never visibly routed into `bcq_input`, `changed` is never visibly updated or
// returned) - confirm against the complete file before editing.
263 bool FuseBCQPass::run(loco::Graph *g)
265 BCQConverter converter;
267 bool changed = false;
// Step 1: offer every CircleConst of the graph to the converter, which keeps the
// ones whose name marks them as bcqinfo_* constants.
269 for (auto node : loco::all_nodes(g))
271 if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
273 converter.add_BCQ_info_node(circle_const);
// Step 2: walk the nodes loco::active_nodes() reports for the graph outputs and
// fuse the Gather / FullyConnected candidates among them.
277 for (auto node : loco::active_nodes(loco::output_nodes(g)))
279 if (auto gather = dynamic_cast<luci::CircleGather *>(node))
// Only a Gather whose `params` is a constant with a full set of BCQ information
// is fused.
281 auto params = dynamic_cast<luci::CircleConst *>(gather->params());
282 if (params != nullptr && converter.has_BCQ_info(params))
284 auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
286 bcq_gather->input_scales(converter.get_alpha(params));
287 bcq_gather->input_binary(converter.get_packed_binary_code(params));
288 bcq_gather->indices(gather->indices());
289 bcq_gather->input_clusters(converter.packed_clusters(params));
// Hidden size is dim(1) of the packed binary code constant times 32.
// NOTE(review): presumably each S32 element packs 32 binary codes - confirm.
291 const auto binary_hidden_size =
292 loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32;
293 bcq_gather->input_hidden_size(binary_hidden_size);
// The axis is either copied from the Gather or swapped between 0 and 1.
// NOTE(review): the branch structure separating the two assignments is not
// visible in this view.
295 if (converter.do_w_x(params))
297 bcq_gather->axis(gather->axis());
301 const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
302 bcq_gather->axis(axis_transpose);
// Substitute the BCQ node for the original Gather.
305 loco::replace(gather).with(bcq_gather);
310 else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
// Only a FullyConnected whose `weights` is a constant with a full set of BCQ
// information is fused.
312 auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights());
313 if (weights != nullptr && converter.has_BCQ_info(weights))
315 auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
317 bcq_fc->weights_scales(converter.get_alpha(weights));
318 bcq_fc->weights_binary(converter.get_packed_binary_code(weights));
319 bcq_fc->bias(fully_connected->bias());
320 bcq_fc->weights_clusters(converter.packed_clusters(weights));
321 bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
// `bcq_input` is what will be fed to the BCQ node; `batch_rank` stays 0 unless the
// input has a valid shape of rank > 2 (see below).
323 loco::Node *bcq_input = fully_connected->input();
324 int32_t batch_rank = 0;
326 // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
327 const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
328 if (original_input->shape_status() == ShapeStatus::VALID && original_input->rank() > 2)
// Two-element S32 constant { batch_size, last dimension } used as the Reshape shape.
330 auto new_shape = g->nodes()->create<luci::CircleConst>();
331 new_shape->dtype(loco::DataType::S32);
332 new_shape->size<loco::DataType::S32>(2);
334 new_shape->dim(0) = 2;
// batch_size accumulates the product of all dimensions except the last one.
// NOTE(review): its declaration / initial value is not visible in this view.
337 for (uint32_t i = 0; i < original_input->rank() - 1; ++i)
338 batch_size *= original_input->dim(i).value();
340 new_shape->at<loco::DataType::S32>(0) = batch_size;
341 new_shape->at<loco::DataType::S32>(1) =
342 original_input->dim(original_input->rank() - 1).value();
343 new_shape->shape_status(ShapeStatus::VALID);
// Reshape of the original input to the shape built above.
// NOTE(review): no visible statement assigns `reshape` to `bcq_input` - confirm.
345 auto reshape = g->nodes()->create<luci::CircleReshape>();
346 reshape->tensor(original_input);
347 reshape->shape(new_shape);
// Offset (rank - 2) later added to the dimension index that supplies the hidden size.
350 batch_rank = original_input->rank() - 2;
// NOTE(review): the statements directly under `do_w_x(weights)` feed `bcq_input`
// to the BCQ node as is; the Transpose nodes are created further below (perm,
// input_transpose, output_transpose). Confirm which formation the comment refers to.
353 // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
354 if (converter.do_w_x(weights))
// NOTE(review): the dimension expression of this initializer is cut off in this view.
356 const auto binary_hidden_size =
357 loco::must_cast<luci::CircleNode *>(fully_connected->input())
360 bcq_fc->weights_hidden_size(binary_hidden_size);
361 bcq_fc->input(bcq_input);
362 loco::replace(fully_connected).with(bcq_fc);
// Here the hidden size is read from dimension (1 + batch_rank) of the original input.
366 const auto binary_hidden_size =
367 loco::must_cast<luci::CircleNode *>(fully_connected->input())
368 ->dim(1 + batch_rank)
370 bcq_fc->weights_hidden_size(binary_hidden_size);
// Permutation { 1, 0 }, shared by the Transpose in front of and behind the BCQ node.
372 auto perm = g->nodes()->create<luci::CircleConst>();
373 perm->dtype(loco::DataType::S32);
374 perm->size<loco::DataType::S32>(2);
377 perm->at<loco::DataType::S32>(0) = 1;
378 perm->at<loco::DataType::S32>(1) = 0;
379 perm->shape_status(ShapeStatus::VALID);
// Transpose the (possibly reshaped) input before it enters the BCQ node.
381 auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
382 input_transpose->a(bcq_input);
383 input_transpose->perm(perm);
385 bcq_fc->input(input_transpose);
// Transpose the BCQ result with the same permutation; this Transpose takes the
// place of the original FullyConnected.
387 auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
388 output_transpose->a(bcq_fc);
389 output_transpose->perm(perm);
391 loco::replace(fully_connected).with(output_transpose);
// Detach the registered bcqinfo_* constants from the graph outputs.
// NOTE(review): whether this call is conditional is not visible in this view.
400 converter.clear_BCQ_nodes();