Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / PropagateQParamBackwardPass.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/PropagateQParamBackwardPass.h"
18 #include "QuantizationUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Service/Nodes/CircleConst.h>
23 #include <luci/Log.h>
24
25 #include <cmath>
26 #include <limits>
27
28 namespace
29 {
30
31 // Return true if node is a virtual node
32 bool virtual_op(const luci::CircleOpcode opcode)
33 {
34   switch (opcode)
35   {
36 #define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
37   case luci::CircleOpcode::OPCODE:        \
38     return false;
39 #define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
40   case luci::CircleOpcode::OPCODE:         \
41     return true;
42 #include <luci/IR/CircleNodes.lst>
43 #undef CIRCLE_NODE
44 #undef CIRCLE_VNODE
45     default:
46       throw std::runtime_error("Unknown opcode detected");
47   }
48 }
49
50 void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
51                         loco::DataType quant_type)
52 {
53   uint32_t size = const_node->size<loco::DataType::FLOAT32>();
54
55   const float scaling_factor_inv = 1.0 / scaling_factor;
56   std::vector<int32_t> quantized_values(size);
57   for (uint32_t i = 0; i < size; ++i)
58   {
59     auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
60     double quantized_data = std::round(data * scaling_factor_inv) + zerop;
61     constexpr double int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
62     constexpr double int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
63     quantized_data = std::min(int_max, std::max(int_min, quantized_data));
64
65     quantized_values[i] = static_cast<int32_t>(quantized_data);
66   }
67
68   switch (quant_type)
69   {
70     case loco::DataType::U8:
71       const_node->dtype(loco::DataType::U8);      // change the type of tensor
72       const_node->size<loco::DataType::U8>(size); // resize tensor
73       for (uint32_t i = 0; i < size; ++i)
74         const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
75       break;
76     case loco::DataType::S16:
77       assert(zerop == 0);
78       const_node->dtype(loco::DataType::S16);      // change the type of tensor
79       const_node->size<loco::DataType::S16>(size); // resize tensor
80       for (uint32_t i = 0; i < size; ++i)
81         const_node->at<loco::DataType::S16>(i) =
82           std::min(32767, std::max(-32767, quantized_values[i]));
83       break;
84     default:
85       throw std::runtime_error("Unsupported data type");
86   }
87 }
88
89 void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *target)
90 {
91   auto source_qparam = source->quantparam();
92   if (source_qparam == nullptr)
93     throw std::runtime_error("source quantparam is not found during overwrite");
94
95   auto target_qparam = target->quantparam();
96   if (target_qparam == nullptr)
97   {
98     auto quantparam = std::make_unique<luci::CircleQuantParam>();
99     target->quantparam(std::move(quantparam));
100     target_qparam = target->quantparam();
101
102     if (target_qparam == nullptr)
103       throw std::runtime_error("Creating new quant param failed");
104   }
105   target_qparam->min = source_qparam->min;
106   target_qparam->max = source_qparam->max;
107   target_qparam->scale = source_qparam->scale;
108   target_qparam->zerop = source_qparam->zerop;
109   target_qparam->quantized_dimension = source_qparam->quantized_dimension;
110 }
111
112 /**
113  * Tells if pad_v2 quantization should ignore padding value
114  * In that case padding const will be quantized with input parameters, and probably clipped
115  */
116 bool ignore_pad_v2_const_quantization(const luci::CirclePadV2 *pad)
117 {
118   // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
119   // TODO use metadata hints to detect this case
120   auto const_value_node = dynamic_cast<const luci::CircleConst *>(pad->arg(2));
121   if (!const_value_node)
122     return false;
123   if (const_value_node->dtype() == loco::DataType::FLOAT32)
124   {
125     float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
126     if (const_value == std::numeric_limits<float>::lowest())
127       return true;
128   }
129   return false;
130 }
131
132 /** EXAMPLE
133  *
134  * BEFORE
135  *
136  *         [CircleNode]       [CircleConst]
137  *           (qparam1)           (FP32)
138  *                   \            /
139  *                    \          /
140  *                    [CirclePack]
141  *                     (qparam2)
142  *
143  *  AFTER
144  *
145  *         [CircleNode]        [CircleConst]   [CircleConst] <- Dead node
146  *           (qparam2)           (qparam2)         (FP32)
147  *                   \            /
148  *                    \          /
149  *                    [CirclePack]
150  *                     (qparam2)
151  *
152  * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
153  */
154 void propagate_pack_quantparam(luci::CirclePack *pack)
155 {
156   assert(pack->quantparam() != nullptr);
157
158   const auto num_inputs = pack->values_count();
159
160   for (uint32_t i = 0; i < num_inputs; i++)
161   {
162     auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
163
164     // Quantize constant values
165     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
166     {
167       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
168       if (const_node->dtype() != loco::DataType::FLOAT32)
169         throw std::runtime_error("Unsupported data type for constant input of pack Op");
170
171       const auto pack_qparam = pack->quantparam();
172       if (pack_qparam == nullptr)
173         throw std::runtime_error("quantparam of pack is not found during propagation");
174
175       assert(pack_qparam->scale.size() == 1);
176       assert(pack_qparam->zerop.size() == 1);
177       const auto scaling_factor = pack_qparam->scale[0];
178       const auto zerop = pack_qparam->zerop[0];
179
180       auto new_const = luci::clone(const_node);
181       quant_const_values(new_const, scaling_factor, zerop, pack->dtype());
182       pack->values(i, new_const);
183       overwrite_quantparam(pack, new_const);
184     }
185     else
186     {
187       const auto succs = loco::succs(node);
188       if (succs.size() > 1)
189         continue;
190
191       // Non-const input must have been quantized
192       assert(node->quantparam() != nullptr);
193       overwrite_quantparam(pack, node);
194     }
195   }
196 }
197
198 /** EXAMPLE
199  *
200  *
201  *
202  * BEFORE
203  *
204  *      [CircleNode] [CircleConst] [CircleConst] [CircleNode]
205  *          (S32)        (S32)        (FP32)     (U8 qparam1)
206  *              \          \           /            /
207  *               \          \        /            /
208  *                \          \     /            /
209  *                 -------[CircleOneHot]-------
210  *                         (U8 qparam2)
211  *
212  *  AFTER
213  *
214  *      [CircleNode] [CircleConst] [CircleConst] [CircleNode]      [CircleConst] <- Dead node
215  *          (S32)        (S32)     (U8 qparam2)  (U8 qparam2)         (FP32)
216  *              \          \           /           /
217  *               \          \        /            /
218  *                \          \     /            /
219  *                 -------[CircleOneHot]-------
220  *                         (U8 qparam2)
221  *
222  * NOTE Quantization parameter of CircleOneHot (qparam2) is propagated to on_value/off_value.
223  */
224 void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot)
225 {
226   assert(one_hot->quantparam() != nullptr);
227
228   // Propagate quantization parameters from output to inputs,
229   // to fit both input and counstant_value in one quant range.
230   auto quant_input = [one_hot](void (luci::CircleOneHot::*arg_setter)(loco::Node *),
231                                loco::Node *(luci::CircleOneHot::*arg_getter)() const) {
232     auto node = loco::must_cast<luci::CircleNode *>((one_hot->*arg_getter)());
233
234     // Quantize constant values
235     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
236     {
237       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
238       if (is_quantized(const_node))
239         return;
240
241       if (const_node->dtype() != loco::DataType::FLOAT32)
242         throw std::runtime_error("Unsupported data type for constant input of OneHot Op");
243
244       const auto qparam = one_hot->quantparam();
245       if (qparam == nullptr)
246         throw std::runtime_error("quantparam of OneHot is not found during propagation");
247
248       assert(qparam->scale.size() == 1);
249       const auto scaling_factor = qparam->scale.at(0);
250       const auto zerop = qparam->zerop.at(0);
251
252       auto new_const = luci::clone(const_node);
253       quant_const_values(new_const, scaling_factor, zerop, one_hot->dtype());
254       overwrite_quantparam(one_hot, new_const);
255       (one_hot->*arg_setter)(new_const);
256     }
257     else
258     {
259       const auto succs = loco::succs(node);
260       if (succs.size() > 1)
261         return;
262
263       // Non-const input must have been quantized
264       assert(node->quantparam() != nullptr);
265       overwrite_quantparam(one_hot, node);
266     }
267   };
268
269   quant_input(&luci::CircleOneHot::on_value, &luci::CircleOneHot::on_value);
270   quant_input(&luci::CircleOneHot::off_value, &luci::CircleOneHot::off_value);
271 }
272
273 } // namespace
274
275 namespace luci
276 {
277
278 /** BEFORE
279  *
280  *         [CircleNode]             [CircleConst]
281  *         (U8 qparam1)                 (FP32)
282  *                   \                    /
283  *                    \                  /
284  *                    [CircleConcatenation]
285  *                        (U8 qparam2)
286  *
287  *  AFTER
288  *         [CircleNode]             [CircleConst]   [CircleConst] <- Dead node
289  *         (U8 qparam2)             (U8 qparam2)       (FP32)
290  *                   \                    /
291  *                    \                  /
292  *                    [CircleConcatenation]
293  *                        (U8 qparam2)
294  */
295 void propagate_concat_quantparam(luci::CircleConcatenation *concat)
296 {
297   assert(concat->quantparam() != nullptr);
298
299   const auto num_inputs = concat->numValues();
300
301   // Quantize const inputs using their values if concat has fused act function
302   if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
303   {
304     for (uint32_t i = 0; i < num_inputs; i++)
305     {
306       auto node = concat->arg(i);
307       auto const_node = dynamic_cast<luci::CircleConst *>(node);
308       if (const_node != nullptr)
309       {
310         auto new_const = luci::clone(const_node);
311         quant_const(new_const, concat->dtype());
312         concat->values(i, new_const);
313       }
314     }
315     return;
316   }
317
318   for (uint32_t i = 0; i < num_inputs; i++)
319   {
320     auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
321
322     // Quantize constant values
323     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
324     {
325       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
326
327       const auto concat_qparam = concat->quantparam();
328       assert(concat_qparam->scale.size() == 1);
329       const auto scaling_factor = concat_qparam->scale[0];
330       const auto zerop = concat_qparam->zerop[0];
331
332       auto new_const = luci::clone(const_node);
333       quant_const_values(new_const, scaling_factor, zerop, concat->dtype());
334       concat->values(i, new_const);
335       overwrite_quantparam(concat, new_const);
336     }
337     else
338     {
339       const auto succs = loco::succs(node);
340       if (succs.size() > 1)
341         continue;
342
343       // Non-const input must have been quantized
344       assert(node->quantparam() != nullptr);
345       overwrite_quantparam(concat, node);
346     }
347   }
348 }
349
350 /** BEFORE
351  *
352  *         [CircleNode] [CircleConst] [CircleConst]
353  *         (U8 qparam1)     (S32)       (FP32)
354  *                   \        |         /
355  *                    \       |        /
356  *                      [CirclePadV2]
357  *                       (U8 qparam2)
358  *
359  *  AFTER (case 1)
360  *
361  *  By default qparam is propagated from output to inputs to meet backend requirements.
362  *
363  *         [CircleNode] [CircleConst] [CircleConst]   [CircleConst] <- Dead node
364  *         (U8 qparam2)     (S32)      (U8 qparam2)       (FP32)
365  *                   \        |         /
366  *                    \       |        /
367  *                      [CirclePadV2]
368  *                       (U8 qparam2)
369  *
370  *  AFTER (case 2)
371  *
372  * In case padded value is the lowest float value
373  * Qparam is propagated from input to output and constant.
374  *
375  * This is a special case for optimization constructed pad, needed to guarantee that
376  * extremely large negative constant do not stretch output quantization range.
377  *
378  *         [CircleNode] [CircleConst] [CircleConst]   [CircleConst] <- Dead node
379  *         (U8 qparam1)     (S32)      (U8 qparam1)       (FP32)
380  *                   \        |         /
381  *                    \       |        /
382  *                      [CirclePadV2]
383  *                       (U8 qparam1)
384  */
385 void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2)
386 {
387   if (ignore_pad_v2_const_quantization(pad_v2))
388   {
389     // propagate input quantization paramters from input to output and padding const value
390     auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
391     overwrite_quantparam(pad_v2_input, pad_v2);
392
393     auto const_value_node = loco::must_cast<luci::CircleConst *>(
394       pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
395     auto new_const = luci::clone(const_value_node);
396
397     const auto pad_v2_input_qparam = pad_v2_input->quantparam();
398     assert(pad_v2_input_qparam != nullptr);
399     assert(pad_v2_input_qparam->scale.size() == 1);
400     const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
401     const auto zerop = pad_v2_input_qparam->zerop.at(0);
402
403     quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
404     overwrite_quantparam(pad_v2_input, new_const);
405     pad_v2->constant_values(new_const);
406     return;
407   }
408
409   // Propagate quantization paramters from output to inputs,
410   // to fit both input and counstant_value in one quant range.
411   auto quant_input = [pad_v2](void (CirclePadV2::*arg_setter)(loco::Node *), uint32_t arg) {
412     auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
413
414     // Quantize constant values
415     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
416     {
417       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
418       if (is_quantized(const_node))
419         return;
420
421       if (const_node->dtype() != loco::DataType::FLOAT32)
422         throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
423
424       const auto pad_v2_qparam = pad_v2->quantparam();
425       if (pad_v2_qparam == nullptr)
426         throw std::runtime_error("quantparam of PadV2 is not found during propagation");
427
428       assert(pad_v2_qparam->scale.size() == 1);
429       const auto scaling_factor = pad_v2_qparam->scale.at(0);
430       const auto zerop = pad_v2_qparam->zerop.at(0);
431
432       auto new_const = luci::clone(const_node);
433       quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
434       overwrite_quantparam(pad_v2, new_const);
435       (pad_v2->*arg_setter)(new_const);
436     }
437     else
438     {
439       const auto succs = loco::succs(node);
440       if (succs.size() > 1)
441         return;
442
443       // Non-const input must have been quantized
444       assert(node->quantparam() != nullptr);
445       overwrite_quantparam(pad_v2, node);
446     }
447   };
448
449   quant_input(&CirclePadV2::input, 0);
450   quant_input(&CirclePadV2::constant_values, 2);
451 }
452
453 } // namespace luci
454
455 namespace
456 {
457
458 // Visitor to propagate quantization parameters backwards
459 struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<void>
460 {
461   void visit(luci::CircleNode *) {}
462
463   void visit(luci::CircleConcatenation *node) { propagate_concat_quantparam(node); }
464
465   void visit(luci::CircleOneHot *node) { propagate_one_hot_quantparam(node); }
466
467   void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
468
469   void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
470
471   // Propagate qparam for non-value changing Ops
472   // (ex: Reshape, Transpose, etc.)
473   // TODO Add more Ops
474
475   void visit(luci::CircleReshape *node)
476   {
477     auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
478
479     // Do not propagate qparam if input node has multiple users
480     if (loco::succs(input_node).size() > 1)
481       return;
482
483     const auto input_opcode = input_node->opcode();
484
485     // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
486     // Why? It is not safe to propagate qparam to some virtual nodes. For example,
487     // const node, multi-out nodes. Let's block them for now.
488     // TODO Revisit this condition
489     if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
490       return;
491
492     overwrite_quantparam(node, input_node);
493   }
494
495   void visit(luci::CircleTranspose *node)
496   {
497     auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
498
499     // Do not propagate qparam if input node has multiple users
500     if (loco::succs(input_node).size() > 1)
501       return;
502
503     const auto input_opcode = input_node->opcode();
504
505     // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
506     // Why? It is not safe to propagate qparam to some virtual nodes. For example,
507     // const node, multi-out nodes. Let's block them for now.
508     // TODO Revisit this condition
509     if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
510       return;
511
512     overwrite_quantparam(node, input_node);
513   }
514 };
515
516 } // namespace
517
518 namespace luci
519 {
520
521 bool PropagateQParamBackwardPass::run(loco::Graph *g)
522 {
523   LOGGER(l);
524
525   // We use reverse post-order traversal as qparam is propagated backward
526   auto nodes = loco::postorder_traversal(loco::output_nodes(g));
527   std::reverse(nodes.begin(), nodes.end());
528   for (auto node : nodes)
529   {
530     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
531     INFO(l) << "PropagateQParamBackwardPass visit node: " << circle_node->name() << std::endl;
532
533     // We can't propagate non-existent qparam
534     if (circle_node->quantparam() == nullptr)
535       continue;
536
537     PropagateQParamBackward pqb;
538     circle_node->accept(&pqb);
539   }
540
541   // This pass is only run once, so return false
542   // TODO Refactoring not to return meaningless value
543   return false;
544 }
545
546 } // namespace luci