2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FuseBCQPass.h"
19 #include <luci/IR/CircleNodes.h>
// Returns whether constant `after` holds the same FLOAT32 data as `before`,
// either in the same layout or transposed, so that `after` can be recognized as
// the (possibly transposed) form of the registered weight `before`.
//
// `after` must be a rank-2 FLOAT32 constant with the same element count as
// `before`. `before` is read as FLOAT32 without a dtype check here; the fusable_op
// dtype is checked by is_bcqinfo_valid before get_prefix_of_const calls this.
//
// - rank-2 `before`: `after` must either equal `before` element-wise with the
//   same dims, or be its transpose.
// - rank-3 `before`: `after` may be a transposed 2-D view of `before`
//   (the two accepted layouts are listed below).
//
// NOTE(review): the lines that branch on `do_w_x` and all return statements are
// not present in this listing. Presumably do_w_x selects the same-layout check for
// rank 2 and the "not found yet" path for rank 3 -- confirm. Also confirm what is
// returned for a rank-3 `before` whose dims match neither accepted layout; it
// should be false, otherwise any same-sized constant would be reported fusable.
28 bool is_fusable_const(luci::CircleConst *before, luci::CircleConst *after, bool do_w_x)
30 if (after->dtype() != loco::DataType::FLOAT32)
33 if (after->rank() != 2)
36 if (after->size<loco::DataType::FLOAT32>() != before->size<loco::DataType::FLOAT32>())
39 auto after_dim0 = after->dim(0).value();
40 auto after_dim1 = after->dim(1).value();
42 if (before->rank() == 2)
46 // Check for [dim0, dim1] --> [dim0, dim1]
47 if (!(after->dim(0) == before->dim(0) && after->dim(1) == before->dim(1)))
50 for (uint32_t i = 0; i < after->size<loco::DataType::FLOAT32>(); ++i)
51 if (after->at<loco::DataType::FLOAT32>(i) != before->at<loco::DataType::FLOAT32>(i))
56 // Check for [dim0, dim1] --> [dim1, dim0]
57 if (!(after->dim(0) == before->dim(1) && after->dim(1) == before->dim(0)))
// Row-major comparison: after[i][j] must equal before[j][i].
60 for (uint32_t i = 0; i < after_dim0; ++i)
61 for (uint32_t j = 0; j < after_dim1; ++j)
62 if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) !=
63 before->at<loco::DataType::FLOAT32>(j * after_dim0 + i))
69 else if (before->rank() == 3)
73 // This case is not found yet.
78 // When Einsum op is converted to FullyConnected, original rank can be 3.
79 auto before_dim0 = before->dim(0).value();
80 auto before_dim1 = before->dim(1).value();
81 auto before_dim2 = before->dim(2).value();
83 // Check if [dim0, dim1, dim2] --> [dim2, dim0 * dim1] or
84 // [dim0, dim1, dim2] --> [dim1 * dim2, dim0]
// First operand: `before` viewed as [d0, d1 * d2], then transposed.
// Second operand: `before` viewed as [d0 * d1, d2], then transposed.
85 if ((after_dim0 == before_dim1 * before_dim2 && after_dim1 == before_dim0) ||
86 (after_dim0 == before_dim2 && after_dim1 == before_dim0 * before_dim1))
// In both layouts `after` is the transpose of a 2-D view of `before`, so the same
// after[i][j] == before2d[j][i] index mapping applies.
88 for (uint32_t i = 0; i < after_dim0; ++i)
89 for (uint32_t j = 0; j < after_dim1; ++j)
90 if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) !=
91 before->at<loco::DataType::FLOAT32>(j * after_dim0 + i))
107 // V means the version of BCQ.
// The primary template is only declared here; each supported BCQ version is a
// full specialization.
108 template <int32_t V> class BCQFuser;
// Fuser for BCQ metadata version 1 (run() builds it when metadata element 1 == 1).
// register_bcq_info() collects BCQ information nodes from the model outputs and
// fuseBCQ() rewrites matching Gather / FullyConnected nodes.
110 template <> class BCQFuser<1>
// original_output_cnt : number of the model's own outputs that follow the BCQ
//                       metadata output (metadata element 2 in run()).
// bundle_cnt          : number of consecutive BCQ information outputs per prefix
//                       (metadata element 3 in run()).
// NOTE(review): naming the constructor with the template-id `BCQFuser<1>(...)` is not
// permitted since C++20 (CWG 2237); the plain injected name `BCQFuser(...)` is
// equivalent and portable.
113 BCQFuser<1>(int32_t original_output_cnt, int32_t bundle_cnt)
114 : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt}
120 void register_bcq_info(loco::Graph *g)
122 for (auto node : loco::output_nodes(g))
124 auto output_node = loco::must_cast<luci::CircleOutput *>(node);
127 * First output of model is metadata for BCQ. Please refer to following example.
129 * When original_output_cnt is 2,
130 * BCQ_METADATA, original_output_1, original_output_2, BCQ_INFO_1, ...
132 if ((int)output_node->index() > _original_output_cnt)
134 const auto prefix = (output_node->index() - (_original_output_cnt + 1)) / (_bundle_cnt);
135 const MetadataType metadata_type = static_cast<MetadataType>(
136 (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt));
137 const auto circle_node = loco::must_cast<luci::CircleNode *>(output_node->from());
138 add_BCQ_info_node(prefix, metadata_type, circle_node);
// Rewrites the nodes of graph g that consume a registered BCQ-quantized constant:
//  - Gather whose params match a fusable constant -> BCQGather (followed by a
//    Transpose of its result when do_w_x is false);
//  - FullyConnected whose weights match -> BCQFullyConnected placed between two
//    Transposes with perm [1, 0]; a rank-2 Reshape of the input is built when the
//    input shape is known and its rank exceeds 2;
//  - FullyConnected whose *input* matches -> BCQFullyConnected fed with the
//    transposed weights operand.
// The function first checks is_bcqinfo_valid(); the early-exit statement itself is
// not in this listing.
// NOTE(review): the `continue`/return statements are not present in this listing;
// run() presumably treats a true result as "graph changed" -- confirm.
143 bool fuseBCQ(loco::Graph *g)
145 if (!is_bcqinfo_valid())
148 for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
150 // Fuse Gather to BCQGather
151 if (auto gather = dynamic_cast<luci::CircleGather *>(node))
153 if (auto params = dynamic_cast<luci::CircleConst *>(gather->params()))
155 auto prefix = get_prefix_of_const(params);
// prefix -1: params matched no registered constant; is_valid_prefix: all BCQ info
// required for fusing is registered for this prefix.
156 if (prefix == -1 || !is_valid_prefix(prefix))
159 auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
161 bcq_gather->op_version(1);
162 bcq_gather->input_scales(alpha(g, prefix));
163 bcq_gather->input_binary(packed_binary_code(g, prefix));
164 bcq_gather->indices(gather->indices());
165 bcq_gather->input_clusters(packed_clusters(g, prefix));
167 if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
169 bcq_gather->input_hidden_size(params->dim(1).value());
170 bcq_gather->axis(gather->axis());
171 loco::replace(gather).with(bcq_gather);
// do_w_x == false: hidden size comes from dim 0, the gather axis is swapped
// between 0 and 1, and the BCQGather result is transposed back below.
175 bcq_gather->input_hidden_size(params->dim(0).value());
176 const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
177 bcq_gather->axis(axis_transpose);
179 const auto indices_rank =
180 loco::must_cast<luci::CircleNode *>(gather->indices())->rank();
// perm = [1, 2, ..., indices_rank, 0]: moves the leading axis of the BCQGather
// output to the end.
182 auto perm = g->nodes()->create<luci::CircleConst>();
183 perm->dtype(loco::DataType::S32);
184 perm->size<loco::DataType::S32>(1 + indices_rank);
186 perm->dim(0) = 1 + indices_rank;
187 for (uint32_t idx = 0; idx < indices_rank; ++idx)
188 perm->at<loco::DataType::S32>(idx) = idx + 1;
189 perm->at<loco::DataType::S32>(indices_rank) = 0;
190 perm->shape_status(luci::ShapeStatus::VALID);
192 auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
193 output_transpose->a(bcq_gather);
194 output_transpose->perm(perm);
196 loco::replace(gather).with(output_transpose);
203 // Fuse FullyConnected to BCQFullyConnected
204 if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
206 if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()))
208 auto prefix = get_prefix_of_const(weights);
209 if (prefix == -1 || !is_valid_prefix(prefix))
212 auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
214 bcq_fc->op_version(1);
215 bcq_fc->weights_scales(alpha(g, prefix));
216 bcq_fc->weights_binary(packed_binary_code(g, prefix));
217 bcq_fc->bias(fully_connected->bias());
218 bcq_fc->weights_clusters(packed_clusters(g, prefix));
219 bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
221 loco::Node *bcq_input = fully_connected->input();
223 // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
224 const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
225 if (original_input->shape_status() == luci::ShapeStatus::VALID &&
226 original_input->rank() > 2)
// New shape = [product of all dims but the last, last dim].
// NOTE(review): the declaration of batch_size is not in this listing; the product is
// stored into an S32 constant, so very large inputs could overflow -- confirm.
228 auto new_shape = g->nodes()->create<luci::CircleConst>();
229 new_shape->dtype(loco::DataType::S32);
230 new_shape->size<loco::DataType::S32>(2);
232 new_shape->dim(0) = 2;
235 for (uint32_t i = 0; i < original_input->rank() - 1; ++i)
236 batch_size *= original_input->dim(i).value();
238 new_shape->at<loco::DataType::S32>(0) = batch_size;
239 new_shape->at<loco::DataType::S32>(1) =
240 original_input->dim(original_input->rank() - 1).value();
241 new_shape->shape_status(luci::ShapeStatus::VALID);
// NOTE(review): the statement that routes `reshape` into `bcq_input` is not in this
// listing -- confirm it exists, otherwise this Reshape is left unused.
243 auto reshape = g->nodes()->create<luci::CircleReshape>();
244 reshape->tensor(original_input);
245 reshape->shape(new_shape);
250 // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
251 bcq_fc->weights_hidden_size(weights->dim(1).value());
// perm = [1, 0], shared by the input and output Transposes.
253 auto perm = g->nodes()->create<luci::CircleConst>();
254 perm->dtype(loco::DataType::S32);
255 perm->size<loco::DataType::S32>(2);
258 perm->at<loco::DataType::S32>(0) = 1;
259 perm->at<loco::DataType::S32>(1) = 0;
260 perm->shape_status(luci::ShapeStatus::VALID);
262 auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
263 input_transpose->a(bcq_input);
264 input_transpose->perm(perm);
266 bcq_fc->input(input_transpose);
268 auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
269 output_transpose->a(bcq_fc);
270 output_transpose->perm(perm);
272 loco::replace(fully_connected).with(output_transpose);
// The registered constant is the FullyConnected *input* here, so the weights operand
// (transposed) becomes the BCQFullyConnected input.
276 else if (auto weights_as_input =
277 dynamic_cast<luci::CircleConst *>(fully_connected->input()))
279 auto prefix = get_prefix_of_const(weights_as_input);
280 if (prefix == -1 || !is_valid_prefix(prefix))
// NOTE(review): this and the bias assert below check properties of the input model
// but are compiled out with NDEBUG; a model violating them would still be rewritten.
// Consider skipping the node instead.
283 assert(_do_w_x[prefix]->at<loco::DataType::BOOL>(0) == true);
285 auto perm = g->nodes()->create<luci::CircleConst>();
286 perm->dtype(loco::DataType::S32);
287 perm->size<loco::DataType::S32>(2);
290 perm->at<loco::DataType::S32>(0) = 1;
291 perm->at<loco::DataType::S32>(1) = 0;
292 perm->shape_status(luci::ShapeStatus::VALID);
294 auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
295 input_transpose->a(fully_connected->weights());
296 input_transpose->perm(perm);
298 auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
300 assert(dynamic_cast<luci::CircleOutputExclude *>(fully_connected->bias()) != nullptr);
302 bcq_fc->op_version(1);
303 bcq_fc->weights_scales(alpha(g, prefix));
304 bcq_fc->weights_binary(packed_binary_code(g, prefix));
305 bcq_fc->bias(fully_connected->bias());
306 bcq_fc->weights_clusters(packed_clusters(g, prefix));
307 bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
309 bcq_fc->weights_hidden_size(weights_as_input->dim(1).value());
310 bcq_fc->input(input_transpose);
311 loco::replace(fully_connected).with(bcq_fc);
// Stores `node` as the BCQ information of kind `metadata_type` for `prefix`.
// FUSABLE_OP entries keep the node itself. Every other kind must be a CircleConst,
// or a CircleReshape whose tensor() is a CircleConst (must_cast rejects anything
// else). The constant is filed into the map for its kind; registering the same
// prefix and kind again overwrites the earlier entry.
// NOTE(review): the statement after the FUSABLE_OP store and the keyword before the
// final `_dequant_weight` store are not in this listing -- presumably `return;` and
// `else`. If so, any metadata type not matched above ends up in _dequant_weight.
334 void add_BCQ_info_node(int32_t prefix, MetadataType metadata_type, luci::CircleNode *node)
336 if (metadata_type == MetadataType::FUSABLE_OP)
338 _fusable_op[prefix] = node;
342 luci::CircleConst *const_node;
344 // Converter in TensorFlow v1.x sometimes generate Reshape op
345 if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
346 const_node = loco::must_cast<luci::CircleConst *>(reshape->tensor());
348 const_node = loco::must_cast<luci::CircleConst *>(node);
350 if (metadata_type == MetadataType::DO_W_X)
351 _do_w_x[prefix] = const_node;
352 else if (metadata_type == MetadataType::ALPHA)
353 _alpha[prefix] = const_node;
354 else if (metadata_type == MetadataType::BINARY_CODE)
355 _packed_binary_code[prefix] = const_node;
356 else if (metadata_type == MetadataType::NUM_OF_CLUSTERS)
357 _number_of_clusters[prefix] = const_node;
358 else if (metadata_type == MetadataType::SIZE_OF_CLUSTERS)
359 _size_of_clusters[prefix] = const_node;
360 else if (metadata_type == MetadataType::QBITS_OF_CLUSTERS)
361 _qbits_of_clusters[prefix] = const_node;
363 _dequant_weight[prefix] = const_node;
366 int32_t get_prefix_of_const(luci::CircleConst *w_after)
368 for (auto n : _fusable_op)
370 auto prefix = n.first;
371 auto w_before = loco::must_cast<luci::CircleConst *>(n.second);
372 if (is_fusable_const(w_before, w_after, _do_w_x[prefix]->at<loco::DataType::BOOL>(0)))
// Type-checks every registered BCQ information entry:
//   do_w_x -> BOOL, alpha -> FLOAT32,
//   packed_binary_code / number_of_clusters / size_of_clusters / qbits_of_clusters -> S32,
//   fusable_op -> FLOAT32.
// A mismatch is reported with WARN. The corresponding return statements are not in
// this listing (presumably `return false` per mismatch and `return true` at the end).
// NOTE(review): only dtypes are checked. Element counts are not checked, although
// element 0 of do_w_x / number_of_clusters is read later and packed_clusters reads
// number_of_clusters elements of qbits/size_of_clusters. fusable_op nodes are also
// not checked to be CircleConst, although get_prefix_of_const must_casts them.
379 bool is_bcqinfo_valid()
383 for (auto n : _do_w_x)
385 // do_w_x should be BOOL type
386 if (n.second->dtype() != loco::DataType::BOOL)
388 WARN(l) << "FuseBCQPass : do_w_x has wrong type" << std::endl;
393 for (auto n : _alpha)
395 // alpha should be FLOAT32 type
396 if (n.second->dtype() != loco::DataType::FLOAT32)
398 WARN(l) << "FuseBCQPass : alpha has wrong type" << std::endl;
403 for (auto n : _packed_binary_code)
405 // packed_binary_code should be INT32 type
406 if (n.second->dtype() != loco::DataType::S32)
408 WARN(l) << "FuseBCQPass : packed_binary_code has wrong type" << std::endl;
413 for (auto n : _number_of_clusters)
415 // number_of_clusters should be INT32 type
416 if (n.second->dtype() != loco::DataType::S32)
418 WARN(l) << "FuseBCQPass : number_of_clusters has wrong type" << std::endl;
423 for (auto n : _size_of_clusters)
425 // size_of_clusters should be INT32 type
426 if (n.second->dtype() != loco::DataType::S32)
428 WARN(l) << "FuseBCQPass : size_of_clusters has wrong type" << std::endl;
433 for (auto n : _qbits_of_clusters)
435 // qbits_of_clusters should be INT32 type
436 if (n.second->dtype() != loco::DataType::S32)
438 WARN(l) << "FuseBCQPass : qbits_of_clusters has wrong type" << std::endl;
443 for (auto n : _fusable_op)
445 // fusable_op should be FLOAT32 type
446 if (n.second->dtype() != loco::DataType::FLOAT32)
448 WARN(l) << "FuseBCQPass : fusable_op has wrong type" << std::endl;
453 // As dequant_weight is not used for fusing, skip validation.
// Checks that every kind of BCQ information needed for fusing has been registered
// for `prefix`: do_w_x, alpha, packed_binary_code, number/size/qbits of clusters
// and fusable_op. A missing kind is reported with WARN. The return statements are
// not in this listing -- presumably false on a missing entry, true otherwise.
458 bool is_valid_prefix(int32_t prefix)
462 if (_do_w_x.find(prefix) == _do_w_x.end())
464 WARN(l) << "do_w_x is not found" << std::endl;
468 if (_alpha.find(prefix) == _alpha.end())
470 WARN(l) << "alpha is not found" << std::endl;
474 if (_packed_binary_code.find(prefix) == _packed_binary_code.end())
476 WARN(l) << "packed_binary_code is not found" << std::endl;
480 if (_number_of_clusters.find(prefix) == _number_of_clusters.end())
482 WARN(l) << "number_of_clusters is not found" << std::endl;
486 if (_size_of_clusters.find(prefix) == _size_of_clusters.end())
488 WARN(l) << "size_of_clusters is not found" << std::endl;
492 if (_qbits_of_clusters.find(prefix) == _qbits_of_clusters.end())
494 WARN(l) << "qbits_of_clusters is not found" << std::endl;
498 if (_fusable_op.find(prefix) == _fusable_op.end())
500 WARN(l) << "fusable_op is not found" << std::endl;
504 // As dequant_weight is not used for fusing, skip validation.
// Creates a new FLOAT32 constant in `graph` holding a copy of the alpha values
// registered for `prefix`. fuseBCQ passes it as the BCQ op's scales. Only dim(0)
// of the source shape is copied.
// NOTE(review): the rank-setting line and the return statement are not in this
// listing -- presumably rank(1) and `return new_alpha;`. A source alpha of rank > 1
// would not have its shape reproduced.
510 luci::CircleConst *alpha(loco::Graph *graph, int32_t prefix)
512 auto new_alpha = graph->nodes()->create<luci::CircleConst>();
514 new_alpha->dtype(loco::DataType::FLOAT32);
515 new_alpha->size<loco::DataType::FLOAT32>(_alpha[prefix]->size<loco::DataType::FLOAT32>());
517 new_alpha->dim(0) = _alpha[prefix]->dim(0);
518 for (uint32_t i = 0; i < _alpha[prefix]->size<loco::DataType::FLOAT32>(); ++i)
519 new_alpha->at<loco::DataType::FLOAT32>(i) = _alpha[prefix]->at<loco::DataType::FLOAT32>(i);
520 new_alpha->shape_status(luci::ShapeStatus::VALID);
// Creates a new S32 constant in `graph` holding a copy of the packed binary code
// registered for `prefix`. fuseBCQ passes it as the BCQ op's binary input/weights.
// Dims 0 and 1 of the source shape are copied.
// NOTE(review): the rank-setting line and the return statement are not in this
// listing -- presumably rank(2) and `return new_beta;`.
525 luci::CircleConst *packed_binary_code(loco::Graph *graph, int32_t prefix)
527 auto new_beta = graph->nodes()->create<luci::CircleConst>();
529 new_beta->dtype(loco::DataType::S32);
530 new_beta->size<loco::DataType::S32>(_packed_binary_code[prefix]->size<loco::DataType::S32>());
532 new_beta->dim(0) = _packed_binary_code[prefix]->dim(0);
533 new_beta->dim(1) = _packed_binary_code[prefix]->dim(1);
534 for (uint32_t i = 0; i < _packed_binary_code[prefix]->size<loco::DataType::S32>(); ++i)
535 new_beta->at<loco::DataType::S32>(i) =
536 _packed_binary_code[prefix]->at<loco::DataType::S32>(i);
537 new_beta->shape_status(luci::ShapeStatus::VALID);
542 luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix)
544 auto qbits_of_clusters = _qbits_of_clusters[prefix];
545 auto size_of_clusters = _size_of_clusters[prefix];
546 const auto number_of_clusters = _number_of_clusters[prefix]->at<loco::DataType::S32>(0);
548 auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
549 packed_clusters->dtype(loco::DataType::S32);
550 packed_clusters->size<loco::DataType::S32>(number_of_clusters * 2);
551 packed_clusters->rank(2);
552 packed_clusters->dim(0) = number_of_clusters;
553 packed_clusters->dim(1) = 2;
554 packed_clusters->shape_status(luci::ShapeStatus::VALID);
556 for (int i = 0; i < number_of_clusters; ++i)
558 packed_clusters->at<loco::DataType::S32>(i * 2) =
559 qbits_of_clusters->at<loco::DataType::S32>(i);
560 packed_clusters->at<loco::DataType::S32>(i * 2 + 1) =
561 size_of_clusters->at<loco::DataType::S32>(i);
564 return packed_clusters;
// BCQ information constants filed by add_BCQ_info_node, keyed by prefix (the index
// of the BCQ information bundle, computed in register_bcq_info).
568 std::map<int32_t, luci::CircleConst *> _do_w_x;
569 std::map<int32_t, luci::CircleConst *> _alpha;
570 std::map<int32_t, luci::CircleConst *> _packed_binary_code;
571 std::map<int32_t, luci::CircleConst *> _number_of_clusters;
572 std::map<int32_t, luci::CircleConst *> _size_of_clusters;
573 std::map<int32_t, luci::CircleConst *> _qbits_of_clusters;
// Registered but not used for fusing (validation skips it as well).
574 std::map<int32_t, luci::CircleConst *> _dequant_weight;
// FUSABLE_OP nodes; get_prefix_of_const compares candidate constants against them.
575 std::map<int32_t, luci::CircleNode *> _fusable_op;
// Counts read from the BCQ metadata by run() and passed to the constructor.
578 int32_t _original_output_cnt = 0;
579 int32_t _bundle_cnt = 0;
// Detects BCQ metadata in the module's main graph (graph 0). For BCQ version 1, it
// fuses BCQ-quantized Gather / FullyConnected nodes in every graph of the module.
//
// The metadata is output 0 of the main graph: an S32 CircleConst whose first and
// last elements are the magic numbers below. Element 1 is the BCQ version and
// element 2 the count of original model outputs. For version 1, element 3 is the
// number of BCQ information outputs per bundle.
//
// When nothing changed, the metadata output and every BCQ information output
// (index > original_output_cnt) are detached by feeding them a CircleOutputExclude.
// NOTE(review): the element-count check, the `changed` updates and the return
// statements are not in this listing.
587 bool FuseBCQPass::run(luci::Module *m)
589 bool changed = false;
// Magic numbers expected as the first and last metadata elements.
591 const int32_t start_magicnum = -2e9 + 27;
592 const int32_t end_magicnum = 2e9 - 27;
594 loco::Graph *main_graph = m->graph(0);
596 luci::CircleConst *metadata_node = nullptr;
597 for (auto node : loco::output_nodes(main_graph))
599 auto output_node = loco::must_cast<luci::CircleOutput *>(node);
601 // Metadata node should be first output
602 if (output_node->index() != 0)
605 // Metadata should be constant and dtype should be S32
606 auto const_node = dynamic_cast<luci::CircleConst *>(output_node->from());
607 if (const_node == nullptr || const_node->dtype() != loco::DataType::S32)
610 // Metadata has at least four elements
611 const auto element_cnt = const_node->size<loco::DataType::S32>();
615 // Metadata has magic numbers at first and at last
616 const auto start_value = const_node->at<loco::DataType::S32>(0);
617 const auto end_value = const_node->at<loco::DataType::S32>(element_cnt - 1);
618 if (start_value == start_magicnum && end_value == end_magicnum)
620 metadata_node = const_node;
625 if (metadata_node != nullptr)
627 const auto bcq_version = metadata_node->at<loco::DataType::S32>(1);
628 const auto original_output_cnt = metadata_node->at<loco::DataType::S32>(2);
630 if (bcq_version == 1)
632 const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3);
// NOTE(review): bundle_cnt and original_output_cnt come from the model file and are
// handed to the fuser without a range check here.
634 BCQFuser<1> fuser{original_output_cnt, bundle_cnt};
// BCQ information is collected from the main graph only, then applied to every graph.
635 fuser.register_bcq_info(main_graph);
637 for (size_t g = 0; g < m->size(); ++g)
638 if (fuser.fuseBCQ(m->graph(g)))
644 WARN(l) << "Not supported BCQ version is found." << std::endl;
647 // Remove all of BCQ information nodes iff there is no change
648 if (changed == false)
650 for (auto node : loco::output_nodes(main_graph))
652 auto output_node = loco::must_cast<luci::CircleOutput *>(node);
653 if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt)
655 auto noOp = main_graph->nodes()->create<luci::CircleOutputExclude>();
656 noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting
657 output_node->from(noOp);
// Graph-level entry point. FuseBCQPass needs module-wide information (the metadata
// in graph 0 drives fusion in every graph), so this overload does nothing.
667 bool FuseBCQPass::run(loco::Graph *)
669 // Do nothing for graph