Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseBCQPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseBCQPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <cassert>
22 #include <string>
23 #include <set>
24
25 namespace
26 {
27
28 /**
29  * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied
30  *        are connected with their name. And their names include common prefix.
31  *        However, after pb file is converted to tflite file, some nodes' name are changed.
32  *        Thus this function will return original common prefix.
33  *
34  * @note  All the re-naming rule of TFLite converter is not figured out.
35  *        Therefore, if new naming rule is detected, this function should be updated.
36  */
37 const std::string node_name_prefix(luci::NodeName node_name)
38 {
39   std::string prefix = node_name;
40
41   if (prefix.find("ReadVariableOp/resource/") != std::string::npos)
42   {
43     const auto start_index = prefix.find("ReadVariableOp/resource/");
44
45     const auto left_prefix = prefix.substr(0, start_index);
46     const auto right_prefix = prefix.substr(start_index + 24);
47
48     prefix = left_prefix + right_prefix;
49   }
50
51   if (prefix.find("Tensordot/") != std::string::npos)
52   {
53     const auto index = prefix.find("Tensordot/");
54     prefix = prefix.substr(0, index - 1);
55   }
56   else if (prefix.find("/MatMul") != std::string::npos)
57   {
58     const auto index = prefix.find("/MatMul");
59     prefix = prefix.substr(0, index);
60   }
61   else if (prefix.find("kernel/") != std::string::npos)
62   {
63     const auto index = prefix.find("kernel/");
64     prefix = prefix.substr(0, index - 1);
65   }
66   else if (prefix.find("/bcqinfo_") != std::string::npos)
67   {
68     const auto index = prefix.find("/bcqinfo_");
69     prefix = prefix.substr(0, index);
70   }
71
72   return prefix;
73 }
74
75 /**
76  * @brief Create CircleOutputExclude operation, which has same shape and dtype with
77  *        original circle_node.
78  */
79 luci::CircleOutputExclude *createNoOp(luci::CircleNode *circle_node)
80 {
81   auto graph = circle_node->graph();
82   auto noOp = graph->nodes()->create<luci::CircleOutputExclude>();
83
84   if (circle_node->shape_status() == luci::ShapeStatus::VALID)
85   {
86     noOp->dtype(circle_node->dtype());
87     noOp->rank(circle_node->rank());
88     for (uint32_t i = 0; i < circle_node->rank(); ++i)
89       noOp->dim(i) = circle_node->dim(i);
90   }
91   else
92   {
93     // For type inference
94     noOp->dtype(loco::DataType::FLOAT32);
95   }
96
97   return noOp;
98 };
99
100 } // namespace
101
102 namespace
103 {
104
105 // V means the version of BCQ.
106 template <int32_t V> class BCQFuser;
107
108 template <> class BCQFuser<1>
109 {
110 public:
111   bool fuseBCQ(loco::Graph *g)
112   {
113     bool changed = false;
114
115     for (auto node : loco::all_nodes(g))
116     {
117       if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
118       {
119         add_BCQ_info_node(circle_const);
120       }
121     }
122
123     if (!is_bcqinfo_valid())
124       return false;
125
126     for (auto node : loco::active_nodes(loco::output_nodes(g)))
127     {
128       if (auto gather = dynamic_cast<luci::CircleGather *>(node))
129       {
130         auto params = dynamic_cast<luci::CircleConst *>(gather->params());
131         if (params != nullptr && has_BCQ_info(params))
132         {
133           auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
134
135           bcq_gather->op_version(1);
136           bcq_gather->input_scales(get_alpha(params));
137           bcq_gather->input_binary(get_packed_binary_code(params));
138           bcq_gather->indices(gather->indices());
139           bcq_gather->input_clusters(packed_clusters(params));
140
141           // input_binary shape : [output_size, hidden_size]
142           const auto binary_hidden_size =
143               loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32;
144           bcq_gather->input_hidden_size(binary_hidden_size);
145
146           if (do_w_x(params))
147           {
148             bcq_gather->axis(gather->axis());
149           }
150           else
151           {
152             const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
153             bcq_gather->axis(axis_transpose);
154           }
155
156           loco::replace(gather).with(bcq_gather);
157
158           changed = true;
159         }
160       }
161       else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
162       {
163         auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights());
164         if (weights != nullptr && has_BCQ_info(weights))
165         {
166           auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
167
168           bcq_fc->op_version(1);
169           bcq_fc->weights_scales(get_alpha(weights));
170           bcq_fc->weights_binary(get_packed_binary_code(weights));
171           bcq_fc->bias(fully_connected->bias());
172           bcq_fc->weights_clusters(packed_clusters(weights));
173           bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
174
175           loco::Node *bcq_input = fully_connected->input();
176           int32_t batch_rank = 0;
177
178           // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
179           const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
180           if (original_input->shape_status() == luci::ShapeStatus::VALID &&
181               original_input->rank() > 2)
182           {
183             auto new_shape = g->nodes()->create<luci::CircleConst>();
184             new_shape->dtype(loco::DataType::S32);
185             new_shape->size<loco::DataType::S32>(2);
186             new_shape->rank(1);
187             new_shape->dim(0) = 2;
188
189             auto batch_size = 1;
190             for (uint32_t i = 0; i < original_input->rank() - 1; ++i)
191               batch_size *= original_input->dim(i).value();
192
193             new_shape->at<loco::DataType::S32>(0) = batch_size;
194             new_shape->at<loco::DataType::S32>(1) =
195                 original_input->dim(original_input->rank() - 1).value();
196             new_shape->shape_status(luci::ShapeStatus::VALID);
197
198             auto reshape = g->nodes()->create<luci::CircleReshape>();
199             reshape->tensor(original_input);
200             reshape->shape(new_shape);
201
202             bcq_input = reshape;
203             batch_rank = original_input->rank() - 2;
204           }
205
206           // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
207           if (do_w_x(weights))
208           {
209             const auto binary_hidden_size =
210                 loco::must_cast<luci::CircleNode *>(fully_connected->input())
211                     ->dim(batch_rank)
212                     .value();
213             bcq_fc->weights_hidden_size(binary_hidden_size);
214             bcq_fc->input(bcq_input);
215             loco::replace(fully_connected).with(bcq_fc);
216           }
217           else
218           {
219             const auto binary_hidden_size =
220                 loco::must_cast<luci::CircleNode *>(fully_connected->input())
221                     ->dim(1 + batch_rank)
222                     .value();
223             bcq_fc->weights_hidden_size(binary_hidden_size);
224
225             auto perm = g->nodes()->create<luci::CircleConst>();
226             perm->dtype(loco::DataType::S32);
227             perm->size<loco::DataType::S32>(2);
228             perm->rank(1);
229             perm->dim(0) = 2;
230             perm->at<loco::DataType::S32>(0) = 1;
231             perm->at<loco::DataType::S32>(1) = 0;
232             perm->shape_status(luci::ShapeStatus::VALID);
233
234             auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
235             input_transpose->a(bcq_input);
236             input_transpose->perm(perm);
237
238             bcq_fc->input(input_transpose);
239
240             auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
241             output_transpose->a(bcq_fc);
242             output_transpose->perm(perm);
243
244             loco::replace(fully_connected).with(output_transpose);
245           }
246
247           changed = true;
248         }
249       }
250     }
251
252     if (changed)
253       clear_BCQ_nodes();
254
255     return changed;
256   }
257
258 private:
259   void add_BCQ_info_node(luci::CircleConst *node)
260   {
261     const auto node_name = node->name();
262     const auto prefix = node_name_prefix(node_name);
263
264     // If bcqinfo_* nodes are held by Reshape operation,
265     // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation.
266     // Then the name becomes bcqinfo_*_copy_shape.
267     // We should prevent this node not to added to bcq information.
268     if (node_name.find("_copy_shape") != std::string::npos)
269       return;
270
271     if (node_name.find("bcqinfo_do_w_x") != std::string::npos)
272       _do_w_x[prefix] = node;
273     else if (node_name.find("bcqinfo_alpha") != std::string::npos)
274       _alpha[prefix] = node;
275     else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos)
276       _packed_binary_code[prefix] = node;
277     else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos)
278       _number_of_clusters[prefix] = node;
279     else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos)
280       _size_of_clusters[prefix] = node;
281     else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos)
282       _qbits_of_clusters[prefix] = node;
283     else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos)
284       _dequant_weight[prefix] = node;
285   }
286
287   bool has_BCQ_info(luci::CircleConst *node)
288   {
289     const auto prefix = node_name_prefix(node->name());
290     bool has_info = true;
291
292     has_info &= (_do_w_x.find(prefix) != _do_w_x.end());
293     has_info &= (_alpha.find(prefix) != _alpha.end());
294     has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end());
295     has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end());
296     has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end());
297     has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end());
298     // bcqinfo_dequant_weight is just for validation, so not always exists.
299
300     return has_info;
301   }
302
303   /**
304    * @brief Exclude BCQ information nodes which are used for fusing BCQ operations
305    *        from graph output by using CircleOutputExclude
306    */
307   void clear_BCQ_nodes()
308   {
309     auto clear_nodes = [](std::map<std::string, luci::CircleConst *> &nodes) {
310       for (auto &n : nodes)
311       {
312         auto node = n.second;
313
314         for (auto s : loco::succs(node))
315         {
316           if (auto outnode = dynamic_cast<luci::CircleOutput *>(s))
317           {
318             outnode->from(createNoOp(node));
319           }
320           else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s))
321           {
322             for (auto o : loco::succs(reshape_node))
323             {
324               auto circle_output = loco::must_cast<luci::CircleOutput *>(o);
325               circle_output->from(createNoOp(reshape_node));
326             }
327           }
328         }
329       }
330     };
331
332     clear_nodes(_do_w_x);
333     clear_nodes(_alpha);
334     clear_nodes(_packed_binary_code);
335     clear_nodes(_number_of_clusters);
336     clear_nodes(_size_of_clusters);
337     clear_nodes(_qbits_of_clusters);
338     clear_nodes(_dequant_weight);
339   }
340
341   bool is_bcqinfo_valid()
342   {
343     // do_w_x should be int32 or bool type
344     for (auto n : _do_w_x)
345     {
346       if (n.second->dtype() != loco::DataType::BOOL && n.second->dtype() != loco::DataType::S32)
347         return false;
348     }
349
350     return true;
351   }
352
353 private:
354   bool do_w_x(luci::CircleConst *node)
355   {
356     const auto prefix = node_name_prefix(node->name());
357
358     if (_do_w_x[prefix]->dtype() == loco::DataType::S32)
359       return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1;
360     else
361       return _do_w_x[prefix]->at<loco::DataType::BOOL>(0);
362   }
363
364   luci::CircleConst *get_alpha(luci::CircleConst *node)
365   {
366     const auto prefix = node_name_prefix(node->name());
367     return _alpha[prefix];
368   }
369
370   luci::CircleConst *get_packed_binary_code(luci::CircleConst *node)
371   {
372     const auto prefix = node_name_prefix(node->name());
373     return _packed_binary_code[prefix];
374   }
375
376   luci::CircleConst *get_number_of_clusters(luci::CircleConst *node)
377   {
378     const auto prefix = node_name_prefix(node->name());
379     return _number_of_clusters[prefix];
380   }
381
382   luci::CircleConst *get_size_of_clusters(luci::CircleConst *node)
383   {
384     const auto prefix = node_name_prefix(node->name());
385     return _size_of_clusters[prefix];
386   }
387
388   luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node)
389   {
390     const auto prefix = node_name_prefix(node->name());
391     return _qbits_of_clusters[prefix];
392   }
393
394   luci::CircleConst *packed_clusters(luci::CircleConst *node)
395   {
396     auto graph = node->graph();
397     auto qbits_of_clusters = get_qbits_of_clusters(node);
398     auto size_of_clusters = get_size_of_clusters(node);
399     const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0);
400
401     auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
402     packed_clusters->dtype(loco::DataType::S32);
403     packed_clusters->size<loco::DataType::S32>(number_of_clusters * 2);
404     packed_clusters->rank(2);
405     packed_clusters->dim(0) = number_of_clusters;
406     packed_clusters->dim(1) = 2;
407     packed_clusters->shape_status(luci::ShapeStatus::VALID);
408
409     for (int i = 0; i < number_of_clusters; ++i)
410     {
411       packed_clusters->at<loco::DataType::S32>(i * 2) =
412           qbits_of_clusters->at<loco::DataType::S32>(i);
413       packed_clusters->at<loco::DataType::S32>(i * 2 + 1) =
414           size_of_clusters->at<loco::DataType::S32>(i);
415     }
416
417     return packed_clusters;
418   }
419
420 private:
421   std::map<std::string, luci::CircleConst *> _do_w_x;
422   std::map<std::string, luci::CircleConst *> _alpha;
423   std::map<std::string, luci::CircleConst *> _packed_binary_code;
424   std::map<std::string, luci::CircleConst *> _number_of_clusters;
425   std::map<std::string, luci::CircleConst *> _size_of_clusters;
426   std::map<std::string, luci::CircleConst *> _qbits_of_clusters;
427   std::map<std::string, luci::CircleConst *> _dequant_weight;
428 };
429
430 } // namespace
431
432 namespace luci
433 {
434
435 bool FuseBCQPass::run(loco::Graph *g)
436 {
437   bool changed = false;
438
439   // Find BCQ version information and check validity.
440   luci::CircleConst *version_node = nullptr;
441   for (auto node : loco::all_nodes(g))
442   {
443     if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
444     {
445       if (circle_const->name().find("/bcqinfo_version") != std::string::npos)
446       {
447         // There should be only one bcqinfo_version in the model
448         if (version_node != nullptr)
449         {
450           assert(false && "Multiple version information found");
451           return false;
452         }
453
454         version_node = circle_const;
455       }
456     }
457   }
458
459   // If version node is not found, regard it as version 1.
460   int32_t bcq_version = (version_node != nullptr) ? version_node->at<loco::DataType::S32>(0) : 1;
461
462   if (bcq_version == 1)
463     changed = BCQFuser<1>().fuseBCQ(g);
464   else
465     assert(false && "Not supported BCQ version");
466
467   if (changed && version_node != nullptr)
468   {
469     // If BCQ is applied and version node was found, remove the node.
470     loco::replace(version_node).with(createNoOp(version_node));
471   }
472
473   return changed;
474 }
475
476 } // namespace luci