Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseBCQPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseBCQPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <cassert>
22 #include <string>
23 #include <set>
24
25 namespace
26 {
27
28 /**
29  * @brief Circle nodes including BCQ information and a circle node to which BCQ will be applied
30  *        are connected with their name. And their names include common prefix.
31  *        However, after pb file is converted to tflite file, some nodes' name are changed.
32  *        Thus this function will return original common prefix.
33  *
34  * @note  All the re-naming rule of TFLite converter is not figured out.
35  *        Therefore, if new naming rule is detected, this function should be updated.
36  */
37 const std::string node_name_prefix(luci::NodeName node_name)
38 {
39   std::string prefix = node_name;
40
41   if (prefix.find("ReadVariableOp/resource/") != std::string::npos)
42   {
43     const auto start_index = prefix.find("ReadVariableOp/resource/");
44
45     const auto left_prefix = prefix.substr(0, start_index);
46     const auto right_prefix = prefix.substr(start_index + 24);
47
48     prefix = left_prefix + right_prefix;
49   }
50
51   if (prefix.find("Tensordot/") != std::string::npos)
52   {
53     const auto index = prefix.find("Tensordot/");
54     prefix = prefix.substr(0, index - 1);
55   }
56   else if (prefix.find("kernel/") != std::string::npos)
57   {
58     const auto index = prefix.find("kernel/");
59     prefix = prefix.substr(0, index - 1);
60   }
61   else if (prefix.find("/bcqinfo_") != std::string::npos)
62   {
63     const auto index = prefix.find("/bcqinfo_");
64     prefix = prefix.substr(0, index);
65   }
66
67   return prefix;
68 }
69
70 } // namespace
71
72 namespace
73 {
74
75 class BCQConverter final
76 {
77 public:
78   void add_BCQ_info_node(luci::CircleConst *node)
79   {
80     const auto node_name = node->name();
81     const auto prefix = node_name_prefix(node_name);
82
83     // If bcqinfo_* nodes are held by Reshape operation,
84     // shape of bcqinfo_* nodes are copied to `shape` input of Reshape operation.
85     // Then the name becomes bcqinfo_*_copy_shape.
86     // We should prevent this node not to added to bcq information.
87     if (node_name.find("_copy_shape") != std::string::npos)
88       return;
89
90     if (node_name.find("bcqinfo_do_w_x") != std::string::npos)
91       _do_w_x[prefix] = node;
92     else if (node_name.find("bcqinfo_alpha") != std::string::npos)
93       _alpha[prefix] = node;
94     else if (node_name.find("bcqinfo_packed_binary_code") != std::string::npos)
95       _packed_binary_code[prefix] = node;
96     else if (node_name.find("bcqinfo_number_of_clusters") != std::string::npos)
97       _number_of_clusters[prefix] = node;
98     else if (node_name.find("bcqinfo_size_of_clusters") != std::string::npos)
99       _size_of_clusters[prefix] = node;
100     else if (node_name.find("bcqinfo_qbits_of_clusters") != std::string::npos)
101       _qbits_of_clusters[prefix] = node;
102     else if (node_name.find("bcqinfo_dequant_weight") != std::string::npos)
103       _dequant_weight[prefix] = node;
104   }
105
106   bool has_BCQ_info(luci::CircleConst *node)
107   {
108     const auto prefix = node_name_prefix(node->name());
109     bool has_info = true;
110
111     has_info &= (_do_w_x.find(prefix) != _do_w_x.end());
112     has_info &= (_alpha.find(prefix) != _alpha.end());
113     has_info &= (_packed_binary_code.find(prefix) != _packed_binary_code.end());
114     has_info &= (_number_of_clusters.find(prefix) != _number_of_clusters.end());
115     has_info &= (_size_of_clusters.find(prefix) != _size_of_clusters.end());
116     has_info &= (_qbits_of_clusters.find(prefix) != _qbits_of_clusters.end());
117     // bcqinfo_dequant_weight is just for validation, so not always exists.
118
119     return has_info;
120   }
121
122   bool do_w_x(luci::CircleConst *node)
123   {
124     const auto prefix = node_name_prefix(node->name());
125
126     if (_do_w_x[prefix]->dtype() == loco::DataType::S32)
127       return _do_w_x[prefix]->at<loco::DataType::S32>(0) == 1;
128     else if (_do_w_x[prefix]->dtype() == loco::DataType::BOOL)
129       return _do_w_x[prefix]->at<loco::DataType::BOOL>(0);
130     else
131       throw std::runtime_error("do_w_x should be int or bool");
132   }
133
134   luci::CircleConst *get_alpha(luci::CircleConst *node)
135   {
136     const auto prefix = node_name_prefix(node->name());
137     return _alpha[prefix];
138   }
139
140   luci::CircleConst *get_packed_binary_code(luci::CircleConst *node)
141   {
142     const auto prefix = node_name_prefix(node->name());
143     return _packed_binary_code[prefix];
144   }
145
146   luci::CircleConst *get_number_of_clusters(luci::CircleConst *node)
147   {
148     const auto prefix = node_name_prefix(node->name());
149     return _number_of_clusters[prefix];
150   }
151
152   luci::CircleConst *get_size_of_clusters(luci::CircleConst *node)
153   {
154     const auto prefix = node_name_prefix(node->name());
155     return _size_of_clusters[prefix];
156   }
157
158   luci::CircleConst *get_qbits_of_clusters(luci::CircleConst *node)
159   {
160     const auto prefix = node_name_prefix(node->name());
161     return _qbits_of_clusters[prefix];
162   }
163
164   luci::CircleConst *packed_clusters(luci::CircleConst *node)
165   {
166     auto graph = node->graph();
167     auto qbits_of_clusters = get_qbits_of_clusters(node);
168     auto size_of_clusters = get_size_of_clusters(node);
169     const auto number_of_clusters = get_number_of_clusters(node)->at<loco::DataType::S32>(0);
170
171     auto packed_clusters = graph->nodes()->create<luci::CircleConst>();
172     packed_clusters->dtype(loco::DataType::S32);
173     packed_clusters->size<loco::DataType::S32>(number_of_clusters * 2);
174     packed_clusters->rank(2);
175     packed_clusters->dim(0) = number_of_clusters;
176     packed_clusters->dim(1) = 2;
177     packed_clusters->shape_status(luci::ShapeStatus::VALID);
178
179     for (int i = 0; i < number_of_clusters; ++i)
180     {
181       packed_clusters->at<loco::DataType::S32>(i * 2) =
182           qbits_of_clusters->at<loco::DataType::S32>(i);
183       packed_clusters->at<loco::DataType::S32>(i * 2 + 1) =
184           size_of_clusters->at<loco::DataType::S32>(i);
185     }
186
187     return packed_clusters;
188   }
189
190   /**
191    * @brief Exclude BCQ information nodes which are used for fusing BCQ operations
192    *        from graph output by using CircleOutputExclude
193    */
194   void clear_BCQ_nodes()
195   {
196     auto createNoOp = [](luci::CircleNode *circle_node) {
197       auto graph = circle_node->graph();
198       auto noOp = graph->nodes()->create<luci::CircleOutputExclude>();
199
200       if (circle_node->shape_status() == luci::ShapeStatus::VALID)
201       {
202         noOp->dtype(circle_node->dtype());
203         noOp->rank(circle_node->rank());
204         for (uint32_t i = 0; i < circle_node->rank(); ++i)
205           noOp->dim(i) = circle_node->dim(i);
206       }
207       else
208       {
209         // For type inference
210         noOp->dtype(loco::DataType::FLOAT32);
211       }
212
213       return noOp;
214     };
215
216     auto clear_nodes = [createNoOp](std::map<std::string, luci::CircleConst *> &nodes) {
217       for (auto &n : nodes)
218       {
219         auto node = n.second;
220
221         for (auto s : loco::succs(node))
222         {
223           if (auto outnode = dynamic_cast<luci::CircleOutput *>(s))
224           {
225             outnode->from(createNoOp(node));
226           }
227           else if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(s))
228           {
229             for (auto o : loco::succs(reshape_node))
230             {
231               auto circle_output = loco::must_cast<luci::CircleOutput *>(o);
232               circle_output->from(createNoOp(reshape_node));
233             }
234           }
235         }
236       }
237     };
238
239     clear_nodes(_do_w_x);
240     clear_nodes(_alpha);
241     clear_nodes(_packed_binary_code);
242     clear_nodes(_number_of_clusters);
243     clear_nodes(_size_of_clusters);
244     clear_nodes(_qbits_of_clusters);
245     clear_nodes(_dequant_weight);
246   }
247
248 private:
249   std::map<std::string, luci::CircleConst *> _do_w_x;
250   std::map<std::string, luci::CircleConst *> _alpha;
251   std::map<std::string, luci::CircleConst *> _packed_binary_code;
252   std::map<std::string, luci::CircleConst *> _number_of_clusters;
253   std::map<std::string, luci::CircleConst *> _size_of_clusters;
254   std::map<std::string, luci::CircleConst *> _qbits_of_clusters;
255   std::map<std::string, luci::CircleConst *> _dequant_weight;
256 };
257
258 } // namespace
259
260 namespace luci
261 {
262
263 bool FuseBCQPass::run(loco::Graph *g)
264 {
265   BCQConverter converter;
266
267   bool changed = false;
268
269   for (auto node : loco::all_nodes(g))
270   {
271     if (auto circle_const = dynamic_cast<luci::CircleConst *>(node))
272     {
273       converter.add_BCQ_info_node(circle_const);
274     }
275   }
276
277   for (auto node : loco::active_nodes(loco::output_nodes(g)))
278   {
279     if (auto gather = dynamic_cast<luci::CircleGather *>(node))
280     {
281       auto params = dynamic_cast<luci::CircleConst *>(gather->params());
282       if (params != nullptr && converter.has_BCQ_info(params))
283       {
284         auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
285
286         bcq_gather->input_scales(converter.get_alpha(params));
287         bcq_gather->input_binary(converter.get_packed_binary_code(params));
288         bcq_gather->indices(gather->indices());
289         bcq_gather->input_clusters(converter.packed_clusters(params));
290
291         const auto binary_hidden_size =
292             loco::must_cast<luci::CircleConst *>(bcq_gather->input_binary())->dim(1).value() * 32;
293         bcq_gather->input_hidden_size(binary_hidden_size);
294
295         if (converter.do_w_x(params))
296         {
297           bcq_gather->axis(gather->axis());
298         }
299         else
300         {
301           const auto axis_transpose = (gather->axis() == 0) ? 1 : 0;
302           bcq_gather->axis(axis_transpose);
303         }
304
305         loco::replace(gather).with(bcq_gather);
306
307         changed = true;
308       }
309     }
310     else if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
311     {
312       auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights());
313       if (weights != nullptr && converter.has_BCQ_info(weights))
314       {
315         auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
316
317         bcq_fc->weights_scales(converter.get_alpha(weights));
318         bcq_fc->weights_binary(converter.get_packed_binary_code(weights));
319         bcq_fc->bias(fully_connected->bias());
320         bcq_fc->weights_clusters(converter.packed_clusters(weights));
321         bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
322
323         loco::Node *bcq_input = fully_connected->input();
324         int32_t batch_rank = 0;
325
326         // If input of BCQFullyConnected has more than rank 2, we should reshape it as rank 2
327         const auto original_input = loco::must_cast<luci::CircleNode *>(fully_connected->input());
328         if (original_input->shape_status() == ShapeStatus::VALID && original_input->rank() > 2)
329         {
330           auto new_shape = g->nodes()->create<luci::CircleConst>();
331           new_shape->dtype(loco::DataType::S32);
332           new_shape->size<loco::DataType::S32>(2);
333           new_shape->rank(1);
334           new_shape->dim(0) = 2;
335
336           auto batch_size = 1;
337           for (uint32_t i = 0; i < original_input->rank() - 1; ++i)
338             batch_size *= original_input->dim(i).value();
339
340           new_shape->at<loco::DataType::S32>(0) = batch_size;
341           new_shape->at<loco::DataType::S32>(1) =
342               original_input->dim(original_input->rank() - 1).value();
343           new_shape->shape_status(ShapeStatus::VALID);
344
345           auto reshape = g->nodes()->create<luci::CircleReshape>();
346           reshape->tensor(original_input);
347           reshape->shape(new_shape);
348
349           bcq_input = reshape;
350           batch_rank = original_input->rank() - 2;
351         }
352
353         // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
354         if (converter.do_w_x(weights))
355         {
356           const auto binary_hidden_size =
357               loco::must_cast<luci::CircleNode *>(fully_connected->input())
358                   ->dim(batch_rank)
359                   .value();
360           bcq_fc->weights_hidden_size(binary_hidden_size);
361           bcq_fc->input(bcq_input);
362           loco::replace(fully_connected).with(bcq_fc);
363         }
364         else
365         {
366           const auto binary_hidden_size =
367               loco::must_cast<luci::CircleNode *>(fully_connected->input())
368                   ->dim(1 + batch_rank)
369                   .value();
370           bcq_fc->weights_hidden_size(binary_hidden_size);
371
372           auto perm = g->nodes()->create<luci::CircleConst>();
373           perm->dtype(loco::DataType::S32);
374           perm->size<loco::DataType::S32>(2);
375           perm->rank(1);
376           perm->dim(0) = 2;
377           perm->at<loco::DataType::S32>(0) = 1;
378           perm->at<loco::DataType::S32>(1) = 0;
379           perm->shape_status(ShapeStatus::VALID);
380
381           auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
382           input_transpose->a(bcq_input);
383           input_transpose->perm(perm);
384
385           bcq_fc->input(input_transpose);
386
387           auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
388           output_transpose->a(bcq_fc);
389           output_transpose->perm(perm);
390
391           loco::replace(fully_connected).with(output_transpose);
392         }
393
394         changed = true;
395       }
396     }
397   }
398
399   if (changed)
400     converter.clear_BCQ_nodes();
401
402   return changed;
403 }
404
405 } // namespace luci