2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseBCQPass.h"
19 #include <luci/IR/CircleNodes.h>
29 * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied
30 * are connected with their name. And their names include common prefix.
31 * However, after pb file is converted to tflite file, some nodes' name are changed.
32 * Thus this function will return original common prefix.
34 * @note All the re-naming rule of TFLite converter is not figured out.
35 * Therefore, if new naming rule is detected, this function should be updated.
37 const std::string node_name_prefix(luci::NodeName node_name)
39 std::string prefix = node_name;
41 if (prefix.find("ReadVariableOp/resource/") != std::string::npos)
43 const auto start_index = prefix.find("ReadVariableOp/resource/");
45 const auto left_prefix = prefix.substr(0, start_index);
46 const auto right_prefix = prefix.substr(start_index + 24);
48 prefix = left_prefix + right_prefix;
51 if (prefix.find("Tensordot/") != std::string::npos)
53 const auto index = prefix.find("Tensordot/");
54 prefix = prefix.substr(0, index - 1);
56 else if (prefix.find("/MatMul") != std::string::npos)
58 const auto index = prefix.find("/MatMul");
59 prefix = prefix.substr(0, index);
61 else if (prefix.find("kernel/") != std::string::npos)
63 const auto index = prefix.find("kernel/");
64 prefix = prefix.substr(0, index - 1);
66 else if (prefix.find("/bcqinfo_") != std::string::npos)
68 const auto index = prefix.find("/bcqinfo_");
69 prefix = prefix.substr(0, index);
76 * @brief Create CircleOutputExclude operation, which has same shape and dtype with
77 * original circle_node.
79 luci::CircleOutputExclude *createNoOp(luci::CircleNode *circle_node)
81 auto graph = circle_node->graph();
82 auto noOp = graph->nodes()->create<luci::CircleOutputExclude>();
84 if (circle_node->shape_status() == luci::ShapeStatus::VALID)
86 noOp->dtype(circle_node->dtype());
87 noOp->rank(circle_node->rank());
88 for (uint32_t i = 0; i < circle_node->rank(); ++i)
89 noOp->dim(i) = circle_node->dim(i);
94 noOp->dtype(loco::DataType::FLOAT32);
105 // V means the version of BCQ.
// Primary template: only declared here. A supported BCQ version provides a
// specialization; BCQFuser<1> below implements version 1.
106 template <int32_t V> class BCQFuser;
// Version-1 fuser. It collects the bcqinfo_* constants of a graph, keyed by the
// common name prefix computed by node_name_prefix(). It then uses them to turn
// Gather / FullyConnected nodes into CircleBCQGather / CircleBCQFullyConnected.
108 template <> class BCQFuser<1>
// Fuses BCQ operations in graph `g`.
// Pass 1 registers every CircleConst through add_BCQ_info_node(). The method then
// returns early when is_bcqinfo_valid() fails. Pass 2 visits the nodes returned by
// loco::active_nodes(loco::output_nodes(g)) (presumably those reachable from the
// outputs). It replaces each CircleGather whose params, and each CircleFullyConnected
// whose weights, are a CircleConst with complete BCQ info (has_BCQ_info()).
// NOTE(review): the end of this method is outside this view; presumably it returns `changed`.
111 bool fuseBCQ(loco::Graph *g)
113 bool changed = false;
// Pass 1: let add_BCQ_info_node() pick out the bcqinfo_* constants by name.
115 for (auto node : loco::all_nodes(g))
117 if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
119 add_BCQ_info_node(circle_const);
123 if (!is_bcqinfo_valid())
// Pass 2: rewrite eligible Gather / FullyConnected nodes.
126 for (auto node : loco::active_nodes(loco::output_nodes(g)))
128 if (auto gather = dynamic_cast<luci::CircleGather *>(node))
130 auto params = dynamic_cast<luci::CircleConst *>(gather->params());
131 if (params != nullptr && has_BCQ_info(params))
133 auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
135 bcq_gather->op_version(1);
136 bcq_gather->input_scales(get_alpha(params));
137 bcq_gather->input_binary(get_packed_binary_code(params));
138 bcq_gather->indices(gather->indices());
139 bcq_gather->input_clusters(packed_clusters(params));
141 // input_binary shape : [output_size, hidden_size]
// NOTE(review): the factor 32 presumably means each element of the packed binary
// code holds 32 binary values; confirm against the BCQ format.
142 const auto binary_hidden_size =
143 loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32;
144 bcq_gather->input_hidden_size(binary_hidden_size);
// The two axis assignments below sit in branches whose condition is not visible here
// (presumably do_w_x(params)). The second form maps axis 0 to 1 and any other axis to 0.
148 bcq_gather->axis(gather->axis());
152 const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
153 bcq_gather->axis(axis_transpose);
156 loco::replace(gather).with(bcq_gather);
161 else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
163 auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights());
164 if (weights != nullptr && has_BCQ_info(weights))
166 auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
168 bcq_fc->op_version(1);
169 bcq_fc->weights_scales(get_alpha(weights));
170 bcq_fc->weights_binary(get_packed_binary_code(weights));
171 bcq_fc->bias(fully_connected->bias());
172 bcq_fc->weights_clusters(packed_clusters(weights));
173 bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
175 loco::Node *bcq_input = fully_connected->input();
// Number of input dimensions beyond rank 2 (stays 0 when no Reshape is inserted).
// `1 + batch_rank` then indexes the last dimension of the original input.
176 int32_t batch_rank = 0;
178 // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
179 const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
180 if (original_input->shape_status() == luci::ShapeStatus::VALID &&
181 original_input->rank() > 2)
// Build the 2-element S32 shape [product of all leading dims, last dim].
183 auto new_shape = g->nodes()->create<luci::CircleConst>();
184 new_shape->dtype(loco::DataType::S32);
185 new_shape->size<loco::DataType::S32>(2);
187 new_shape->dim(0) = 2;
// Fold every dimension except the last into batch_size (declared on a line not visible here).
190 for (uint32_t i = 0; i < original_input->rank() - 1; ++i)
191 batch_size *= original_input->dim(i).value();
193 new_shape->at<loco::DataType::S32>(0) = batch_size;
194 new_shape->at<loco::DataType::S32>(1) =
195 original_input->dim(original_input->rank() - 1).value();
196 new_shape->shape_status(luci::ShapeStatus::VALID);
198 auto reshape = g->nodes()->create<luci::CircleReshape>();
199 reshape->tensor(original_input);
200 reshape->shape(new_shape);
// NOTE(review): presumably bcq_input is redirected to `reshape` on a line not visible
// here; otherwise this Reshape would be left unused. Confirm.
203 batch_rank = original_input->rank() - 2;
206 // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
209 const auto binary_hidden_size =
210 loco::must_cast<luci::CircleNode *>(fully_connected->input())
213 bcq_fc->weights_hidden_size(binary_hidden_size);
214 bcq_fc->input(bcq_input);
215 loco::replace(fully_connected).with(bcq_fc);
// Transposed layout: the hidden size is taken from the last dimension of the original input.
219 const auto binary_hidden_size =
220 loco::must_cast<luci::CircleNode *>(fully_connected->input())
221 ->dim(1 + batch_rank)
223 bcq_fc->weights_hidden_size(binary_hidden_size);
// perm = [1, 0]. The same constant is shared by the input and the output Transpose.
225 auto perm = g->nodes()->create<luci::CircleConst>();
226 perm->dtype(loco::DataType::S32);
227 perm->size<loco::DataType::S32>(2);
230 perm->at<loco::DataType::S32>(0) = 1;
231 perm->at<loco::DataType::S32>(1) = 0;
232 perm->shape_status(luci::ShapeStatus::VALID);
234 auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
235 input_transpose->a(bcq_input);
236 input_transpose->perm(perm);
238 bcq_fc->input(input_transpose);
240 auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
241 output_transpose->a(bcq_fc);
242 output_transpose->perm(perm);
244 loco::replace(fully_connected).with(output_transpose);
// Registers `node` in the map matching the bcqinfo_* kind found in its name, keyed by
// node_name_prefix(). Names containing "_copy_shape" are ignored, and names matching
// no kind are not registered. A later node with the same prefix overwrites an earlier one.
// The checks run in order, so the first matching substring decides the kind.
259 void add_BCQ_info_node(luci::CircleConst *node)
261 const auto node_name = node->name();
262 const auto prefix = node_name_prefix(node_name);
264 // If bcqinfo_* nodes are held by Reshape operation,
265 // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation.
266 // Then the name becomes bcqinfo_*_copy_shape.
267 // We should prevent this node not to added to bcq information.
268 if (node_name.find("_copy_shape") != std::string::npos)
271 if (node_name.find("bcqinfo_do_w_x") != std::string::npos)
272 _do_w_x[prefix] = node;
273 else if (node_name.find("bcqinfo_alpha") != std::string::npos)
274 _alpha[prefix] = node;
275 else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos)
276 _packed_binary_code[prefix] = node;
277 else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos)
278 _number_of_clusters[prefix] = node;
279 else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos)
280 _size_of_clusters[prefix] = node;
281 else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos)
282 _qbits_of_clusters[prefix] = node;
283 else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos)
284 _dequant_weight[prefix] = node;
// Checks whether all mandatory bcqinfo_* kinds are registered for the prefix of `node`.
// The mandatory kinds are do_w_x, alpha, packed_binary_code, number/size/qbits_of_clusters.
// The return statement is outside this view (presumably `return has_info;`).
287 bool has_BCQ_info(luci::CircleConst *node)
289 const auto prefix = node_name_prefix(node->name());
290 bool has_info = true;
292 has_info &= (_do_w_x.find(prefix) != _do_w_x.end());
293 has_info &= (_alpha.find(prefix) != _alpha.end());
294 has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end());
295 has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end());
296 has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end());
297 has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end());
298 // bcqinfo_dequant_weight is just for validation, so not always exists.
304 * @brief Exclude BCQ information nodes which are used for fusing BCQ operations
305 * from graph output by using CircleOutputExclude
307 void clear_BCQ_nodes()
// For every registered constant:
//  - each CircleOutput consumer is re-pointed to a new CircleOutputExclude;
//  - for each CircleReshape consumer, every successor of that Reshape is re-pointed.
// NOTE(review): must_cast aborts if such a Reshape also feeds a node that is not a
// CircleOutput.
309 auto clear_nodes = [](std::map<std::string, luci::CircleConst *> &nodes) {
310 for (auto &n : nodes)
312 auto node = n.second;
314 for (auto s : loco::succs(node))
316 if (auto outnode = dynamic_cast<luci::CircleOutput *>(s))
318 outnode->from(createNoOp(node));
320 else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s))
322 for (auto o : loco::succs(reshape_node))
324 auto circle_output = loco::must_cast<luci::CircleOutput *>(o);
325 circle_output->from(createNoOp(reshape_node));
// NOTE(review): no clear_nodes(_alpha) call is visible in this view (there is a gap
// between the two calls below); verify that the alpha constants are handled.
332 clear_nodes(_do_w_x);
334 clear_nodes(_packed_binary_code);
335 clear_nodes(_number_of_clusters);
336 clear_nodes(_size_of_clusters);
337 clear_nodes(_qbits_of_clusters);
338 clear_nodes(_dequant_weight);
// Validates the collected bcqinfo_* constants. Only the do_w_x dtype check (BOOL or S32)
// is visible here; the remaining checks and the return are outside this view.
341 bool is_bcqinfo_valid()
343 // do_w_x should be int32 or bool type
// NOTE(review): `auto n` copies each map entry, including its std::string key.
// `const auto &n` would avoid the copy.
344 for (auto n : _do_w_x)
346 if (n.second->dtype() != loco::DataType::BOOL && n.second->dtype() != loco::DataType::S32)
// Reads the do_w_x flag for the prefix of `node`. An S32 constant is true iff element 0
// is 1; otherwise element 0 is read as BOOL.
// NOTE(review): std::map::operator[] inserts a nullptr entry for an unknown prefix, so
// this and the get_* accessors below rely on has_BCQ_info() having been checked first.
354 bool do_w_x(luci::CircleConst *node)
356 const auto prefix = node_name_prefix(node->name());
358 if (_do_w_x[prefix]->dtype() == loco::DataType::S32)
359 return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1;
361 return _do_w_x[prefix]->at<loco::DataType::BOOL>(0);
// Accessors: return the registered bcqinfo_* constant for the prefix of `node`.
364 luci::CircleConst *get_alpha(luci::CircleConst *node)
366 const auto prefix = node_name_prefix(node->name());
367 return _alpha[prefix];
370 luci::CircleConst *get_packed_binary_code(luci::CircleConst *node)
372 const auto prefix = node_name_prefix(node->name());
373 return _packed_binary_code[prefix];
376 luci::CircleConst *get_number_of_clusters(luci::CircleConst *node)
378 const auto prefix = node_name_prefix(node->name());
379 return _number_of_clusters[prefix];
382 luci::CircleConst *get_size_of_clusters(luci::CircleConst *node)
384 const auto prefix = node_name_prefix(node->name());
385 return _size_of_clusters[prefix];
388 luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node)
390 const auto prefix = node_name_prefix(node->name());
391 return _qbits_of_clusters[prefix];
// Creates a new S32 constant of shape [number_of_clusters, 2]. Row i holds
// (qbits_of_clusters[i], size_of_clusters[i]); number_of_clusters is element 0 of the
// bcqinfo_number_of_clusters constant.
// NOTE(review): nothing here checks that number_of_clusters is non-negative, or that the
// qbits/size constants hold at least that many S32 elements.
394 luci::CircleConst *packed_clusters(luci::CircleConst *node)
396 auto graph = node->graph();
397 auto qbits_of_clusters = get_qbits_of_clusters(node);
398 auto size_of_clusters = get_size_of_clusters(node);
399 const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0);
401 auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
402 packed_clusters->dtype(loco::DataType::S32);
403 packed_clusters->size<loco::DataType::S32>(number_of_clusters * 2);
404 packed_clusters->rank(2);
405 packed_clusters->dim(0) = number_of_clusters;
406 packed_clusters->dim(1) = 2;
407 packed_clusters->shape_status(luci::ShapeStatus::VALID);
409 for (int i = 0; i < number_of_clusters; ++i)
411 packed_clusters->at<loco::DataType::S32>(i * 2) =
412 qbits_of_clusters->at<loco::DataType::S32>(i);
413 packed_clusters->at<loco::DataType::S32>(i * 2 + 1) =
414 size_of_clusters->at<loco::DataType::S32>(i);
417 return packed_clusters;
// bcqinfo_* constants, keyed by the common prefix returned by node_name_prefix().
421 std::map<std::string, luci::CircleConst *> _do_w_x;
422 std::map<std::string, luci::CircleConst *> _alpha;
423 std::map<std::string, luci::CircleConst *> _packed_binary_code;
424 std::map<std::string, luci::CircleConst *> _number_of_clusters;
425 std::map<std::string, luci::CircleConst *> _size_of_clusters;
426 std::map<std::string, luci::CircleConst *> _qbits_of_clusters;
427 std::map<std::string, luci::CircleConst *> _dequant_weight;
// Fuses BCQ-quantized operations of graph `g`.
// 1. Finds the constant whose name contains "/bcqinfo_version"; at most one is expected.
// 2. Reads the BCQ version from its first S32 element, defaulting to 1 when no such
//    constant exists.
// 3. Runs BCQFuser<1> for version 1.
// 4. When fusion changed the graph and a version constant was found, replaces that
//    constant with a CircleOutputExclude.
// The end of this function is outside this view.
435 bool FuseBCQPass::run(loco::Graph *g)
437 bool changed = false;
439 // Find BCQ version information and check validity.
440 luci::CircleConst *version_node = nullptr;
441 for (auto node : loco::all_nodes(g))
443 if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
445 if (circle_const->name().find("/bcqinfo_version") != std::string::npos)
447 // There should be only one bcqinfo_version in the model
448 if (version_node != nullptr)
// NOTE(review): assert() is compiled out under NDEBUG. What happens on a duplicate in
// release builds depends on lines not visible here; confirm it is also rejected there.
450 assert(false && "Multiple version information found");
454 version_node = circle_const;
459 // If version node is not found, regard it as version 1.
// NOTE(review): at<S32>(0) assumes the version constant is S32 with at least one element.
// Neither its dtype nor its size is checked here.
460 int32_t bcq_version = (version_node != nullptr) ? version_node->at<loco::DataType::S32>(0) : 1;
// Only version 1 is handled. Other versions are expected to reach the assert below,
// which compiles away under NDEBUG.
462 if (bcq_version == 1)
463 changed = BCQFuser<1>().fuseBCQ(g);
465 assert(false && "Not supported BCQ version");
467 if (changed && version_node != nullptr)
469 // If BCQ is applied and version node was found, remove the node.
470 loco::replace(version_node).with(createNoOp(version_node));