Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / QuantizeWithMinMaxPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/QuantizeWithMinMaxPass.h"
18 #include "QuantizationUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Service/Nodes/CircleConst.h>
23 #include <luci/Log.h>
24
25 #include <oops/UserExn.h>
26
27 #include <iostream>
28 #include <cmath>
29 #include <functional>
30
31 namespace
32 {
33
34 using namespace luci;
35 using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
36
37 void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
38 {
39   loco::TensorShape dimension;
40   dimension.rank(4);
41   uint32_t indices[4] = {
42     0,
43   };
44
45   if (!get_channel_dim_index(node, dimension, channel_dim_index))
46   {
47     assert(false);
48     return;
49   }
50
51   for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
52   {
53     for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
54     {
55       for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
56       {
57         for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
58         {
59           func(indices, dimension, channel_dim_index);
60         }
61       }
62     }
63   }
64 }
65
66 } // namespace
67
68 namespace luci
69 {
70
71 namespace
72 {
73
74 // Create a new const node from an existing node.
75 // The new node has the following characteristics
76 // type: T
77 // shape: same with 'node' (given as an argument)
78 // buffer size: 'size' (given as an argument)
79 // Note that contents are not filled in this function.
80 template <loco::DataType T>
81 luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
82 {
83   auto new_node = node->graph()->nodes()->create<CircleConst>();
84   // TODO: We don't have any naming convention for quantized nodes yet.
85   //       Fix this when we have one.
86   new_node->name(node->name());
87   new_node->dtype(T);
88   new_node->rank(node->rank());
89   for (uint32_t i = 0; i < node->rank(); i++)
90     new_node->dim(i).set(node->dim(i).value());
91
92   new_node->size<T>(size);
93   new_node->shape_status(luci::ShapeStatus::VALID);
94
95   return new_node;
96 }
97
98 void overwrite_quantparam(luci::CircleNode *source, luci::CircleNode *target)
99 {
100   auto source_qparam = source->quantparam();
101   if (source_qparam == nullptr)
102     throw std::runtime_error("source quantparam is not found during overwrite");
103
104   auto target_qparam = target->quantparam();
105   if (target_qparam == nullptr)
106   {
107     auto quantparam = std::make_unique<CircleQuantParam>();
108     target->quantparam(std::move(quantparam));
109     target_qparam = target->quantparam();
110
111     if (target_qparam == nullptr)
112       throw std::runtime_error("Creating new quant param failed");
113   }
114   target_qparam->min = source_qparam->min;
115   target_qparam->max = source_qparam->max;
116   target_qparam->scale = source_qparam->scale;
117   target_qparam->zerop = source_qparam->zerop;
118   target_qparam->quantized_dimension = source_qparam->quantized_dimension;
119 }
120
121 void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
122                         loco::DataType quant_type)
123 {
124   uint32_t size = const_node->size<loco::DataType::FLOAT32>();
125
126   const float scaling_factor_inv = 1.0 / scaling_factor;
127   std::vector<int32_t> quantized_values(size);
128   for (uint32_t i = 0; i < size; ++i)
129   {
130     auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
131     double quantized_float = std::round(data * scaling_factor_inv) + zerop;
132     constexpr auto int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
133     constexpr auto int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
134     quantized_float = std::min(int_max, std::max(int_min, quantized_float));
135
136     quantized_values[i] = static_cast<int32_t>(quantized_float);
137   }
138
139   switch (quant_type)
140   {
141     case loco::DataType::U8:
142       const_node->dtype(loco::DataType::U8);      // change the type of tensor
143       const_node->size<loco::DataType::U8>(size); // resize tensor
144       for (uint32_t i = 0; i < size; ++i)
145         const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
146       break;
147     case loco::DataType::S16:
148       assert(zerop == 0);
149       const_node->dtype(loco::DataType::S16);      // change the type of tensor
150       const_node->size<loco::DataType::S16>(size); // resize tensor
151       for (uint32_t i = 0; i < size; ++i)
152         const_node->at<loco::DataType::S16>(i) =
153           std::min(32767, std::max(-32767, quantized_values[i]));
154       break;
155     default:
156       throw std::runtime_error("Unsupported data type");
157   }
158 }
159
160 // Quantize const per channel
161 //
162 // The last dimension of const is the same as the dimension of channel
163 // And the rest of the const dimensions should be 1
164 // So, a 'single value' is quantized per channel
165 //
166 // Quantization spec (f: fp value, q: quantized value)
167 //
168 // uint8
169 //   Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
170 //   Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
171 //
172 // int16
173 //   Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
174 //   Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
175 void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
176 {
177   assert(node->dtype() == loco::DataType::FLOAT32);
178   assert(node->rank() > 0);
179
180   for (uint32_t i = 0; i < node->rank() - 1; i++)
181   {
182     // Caller should call this function when the below condition is satisfied
183     if (node->dim(i).value() != 1)
184       throw std::runtime_error("Non-channel dimension of const node must be 1");
185   }
186
187   uint32_t size = node->size<loco::DataType::FLOAT32>();
188   assert(size == node->dim(node->rank() - 1).value());
189
190   auto quantparam = std::make_unique<CircleQuantParam>();
191   quantparam->quantized_dimension = node->rank() - 1;
192   std::vector<int32_t> quantized_data(size);
193
194   for (uint32_t i = 0; i < size; ++i)
195   {
196     auto data = node->at<loco::DataType::FLOAT32>(i);
197     if (quant_type == loco::DataType::U8)
198     {
199       if (data >= 0)
200       {
201         quantparam->scale.push_back(data);
202         quantparam->zerop.push_back(0);
203         quantized_data[i] = 1;
204       }
205       else
206       {
207         quantparam->scale.push_back(-data);
208         quantparam->zerop.push_back(1);
209         quantized_data[i] = 0;
210       }
211     }
212     else if (quant_type == loco::DataType::S16)
213     {
214       if (data >= 0)
215       {
216         quantparam->scale.push_back(data);
217         quantized_data[i] = 1;
218       }
219       else
220       {
221         quantparam->scale.push_back(-data);
222         quantized_data[i] = -1;
223       }
224       quantparam->zerop.push_back(0);
225     }
226   }
227   node->quantparam(std::move(quantparam));
228
229   switch (quant_type)
230   {
231     case loco::DataType::U8:
232       node->dtype(loco::DataType::U8);
233       node->size<loco::DataType::U8>(size);
234       for (uint32_t i = 0; i < size; ++i)
235       {
236         assert(quantized_data[i] == 0 || quantized_data[i] == 1);
237         node->at<loco::DataType::U8>(i) = quantized_data[i];
238       }
239       break;
240     case loco::DataType::S16:
241       node->dtype(loco::DataType::S16);
242       node->size<loco::DataType::S16>(size);
243       for (uint32_t i = 0; i < size; ++i)
244       {
245         assert(quantized_data[i] == -1 || quantized_data[i] == 1);
246         node->at<loco::DataType::S16>(i) = quantized_data[i];
247       }
248       break;
249     default:
250       throw std::runtime_error("Unsupported data type");
251   }
252 }
253
254 void quant_const(CircleConst *node, loco::DataType quant_type)
255 {
256   assert(node->dtype() == loco::DataType::FLOAT32);
257
258   float min = std::numeric_limits<float>::max();
259   float max = std::numeric_limits<float>::lowest();
260   for (uint32_t i = 0; i < node->size<loco::DataType::FLOAT32>(); i++)
261   {
262     auto data = node->at<loco::DataType::FLOAT32>(i);
263     min = data < min ? data : min;
264     max = data > max ? data : max;
265   }
266
267   float scaling_factor{0.0};
268   int64_t zp{0};
269   float nudged_min{0.0};
270   float nudged_max{0.0};
271
272   switch (quant_type)
273   {
274     case loco::DataType::U8:
275       asymmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
276                                               nudged_max);
277       break;
278     case loco::DataType::S16:
279       symmetric_wquant_with_minmax_per_layer(node, min, max, scaling_factor, zp, nudged_min,
280                                              nudged_max);
281       break;
282     default:
283       throw std::runtime_error("Unsupported data type");
284   }
285
286   auto quantparam = std::make_unique<CircleQuantParam>();
287   quantparam->scale.push_back(scaling_factor);
288   quantparam->zerop.push_back(zp);
289   node->quantparam(std::move(quantparam));
290 }
291
292 // Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer
293 // Returns a list of <input, weights, output> vectors for the above operators.
294 // Note that it returns a 'list' because bias can be used by multiple operators.
295 std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node)
296 {
297   std::vector<std::vector<loco::Node *>> result;
298   auto circle_const = dynamic_cast<CircleConst *>(node);
299   if (circle_const == nullptr)
300     return result;
301
302   auto succs = loco::succs(node);
303
304   for (auto out : succs)
305   {
306     auto conv = dynamic_cast<CircleConv2D *>(out);
307     if (conv != nullptr && conv->bias() == circle_const)
308     {
309       assert(conv->input() != nullptr);
310       assert(conv->filter() != nullptr);
311       result.push_back({conv->input(), conv->filter(), conv});
312       continue;
313     }
314     auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
315     if (dw_conv != nullptr && dw_conv->bias() == circle_const)
316     {
317       assert(dw_conv->input() != nullptr);
318       assert(dw_conv->filter() != nullptr);
319       result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv});
320       continue;
321     }
322     auto fc = dynamic_cast<CircleFullyConnected *>(out);
323     if (fc != nullptr && fc->bias() == circle_const)
324     {
325       assert(fc->input() != nullptr);
326       assert(fc->weights() != nullptr);
327       result.push_back({fc->input(), fc->weights(), fc});
328       continue;
329     }
330     auto tconv = dynamic_cast<CircleTransposeConv *>(out);
331     if (tconv != nullptr && tconv->bias() == circle_const)
332     {
333       assert(tconv->outBackprop() != nullptr);
334       assert(tconv->filter() != nullptr);
335       result.push_back({tconv->outBackprop(), tconv->filter(), tconv});
336       continue;
337     }
338   }
339   return result;
340 }
341
342 CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
343                                        float *scaling_factor, int64_t *zp)
344 {
345   float scale = input_scale * weight_scale;
346   const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
347
348   uint32_t size = node->size<loco::DataType::FLOAT32>();
349   std::vector<int32_t> quantized_values(size);
350   for (uint32_t i = 0; i < size; ++i)
351   {
352     quantized_values[i] =
353       static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
354   }
355
356   auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
357
358   const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
359   const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
360   for (uint32_t i = 0; i < size; ++i)
361   {
362     new_bias->at<loco::DataType::S32>(i) =
363       std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
364   }
365   *scaling_factor = scale;
366   *zp = 0;
367
368   return new_bias;
369 }
370
371 CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
372                                     std::vector<float> &weight_scale,
373                                     std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
374 {
375   float scaling_factor_inv{0};
376
377   uint32_t size = node->size<loco::DataType::FLOAT32>();
378   std::vector<int32_t> quantized_values(size);
379
380   for (uint32_t i = 0; i < size; ++i)
381   {
382     scaling_factor[i] = input_scale * weight_scale[i];
383     scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
384     quantized_values[i] =
385       static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
386     zp[i] = 0;
387   }
388
389   auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
390
391   const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
392   const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
393   for (uint32_t i = 0; i < size; ++i)
394   {
395     new_bias->at<loco::DataType::S32>(i) =
396       std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
397   }
398
399   return new_bias;
400 }
401
402 CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
403                                           std::vector<float> &weight_scale,
404                                           std::vector<float> &scaling_factor,
405                                           std::vector<int64_t> &zp)
406 {
407   float scaling_factor_inv{0};
408
409   uint32_t size = node->size<loco::DataType::FLOAT32>();
410   std::vector<int64_t> quantized_values(size);
411
412   for (uint32_t i = 0; i < size; ++i)
413   {
414     scaling_factor[i] = input_scale * weight_scale[i];
415     scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
416     quantized_values[i] =
417       static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
418     zp[i] = 0;
419   }
420
421   auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
422
423   for (uint32_t i = 0; i < size; ++i)
424   {
425     new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
426   }
427
428   return new_bias;
429 }
430
431 bool has_min_max(const CircleNode *node)
432 {
433   return node->quantparam() && !node->quantparam()->min.empty() && !node->quantparam()->max.empty();
434 }
435
436 void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
437                             int32_t &channel_dim_index)
438 {
439   assert(node->dtype() == loco::DataType::FLOAT32);
440
441   const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
442   const int32_t kMinScale = -kMaxScale;
443
444   uint32_t size = node->size<loco::DataType::FLOAT32>();
445   std::vector<int32_t> quantized_values(size);
446
447   auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
448     int channel_idx = indices[channel_dim_index];
449     const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
450     auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
451     quantized_values[cal_offset(dimension, indices)] =
452       static_cast<int32_t>(std::round(data * scaling_factor_inv));
453   };
454
455   iterate_per_channel(node, channel_dim_index, quantize);
456
457   node->dtype(loco::DataType::S16);      // change the type of tensor
458   node->size<loco::DataType::S16>(size); // resize tensor
459   for (uint32_t i = 0; i < size; ++i)
460   {
461     node->at<loco::DataType::S16>(i) =
462       std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
463   }
464 }
465
466 void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
467                              std::vector<float> &scaling_factor, int32_t &channel_dim_index)
468 {
469   assert(node->dtype() == loco::DataType::FLOAT32);
470
471   const int32_t kMinScale = 0;
472   const int32_t kMaxScale = 255;
473
474   uint32_t size = node->size<loco::DataType::FLOAT32>();
475   std::vector<int32_t> quantized_values(size);
476
477   auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
478     int channel_idx = indices[channel_dim_index];
479     const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
480     auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
481     quantized_values[cal_offset(dimension, indices)] =
482       static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
483   };
484
485   iterate_per_channel(node, channel_dim_index, quantize);
486
487   node->dtype(loco::DataType::U8);      // change the type of tensor
488   node->size<loco::DataType::U8>(size); // resize tensor
489   for (uint32_t i = 0; i < size; ++i)
490   {
491     node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
492   }
493 }
494
495 void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
496 {
497   const int32_t kMinScale = 0;
498   const int32_t kMaxScale = 255;
499
500   uint32_t size = node->size<loco::DataType::FLOAT32>();
501
502   const float scaling_factor_inv = 1.0 / scaling_factor;
503   std::vector<int32_t> quantized_values(size);
504   for (uint32_t i = 0; i < size; ++i)
505   {
506     auto data = node->at<loco::DataType::FLOAT32>(i);
507     quantized_values[i] = static_cast<int32_t>(std::round((data - min) * scaling_factor_inv));
508   }
509
510   node->dtype(loco::DataType::U8);      // change the type of tensor
511   node->size<loco::DataType::U8>(size); // resize tensor
512   for (uint32_t i = 0; i < size; ++i)
513   {
514     node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
515   }
516 }
517
518 void set_bias(luci::CircleNode *node, luci::CircleConst *bias)
519 {
520   if (auto conv = dynamic_cast<CircleConv2D *>(node))
521     conv->bias(bias);
522   else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node))
523     dconv->bias(bias);
524   else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node))
525     tconv->bias(bias);
526   else if (auto fc = dynamic_cast<CircleFullyConnected *>(node))
527     fc->bias(bias);
528   else
529     throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and "
530                              "fully-connected layer have bias");
531 }
532
533 void set_act_qparam(luci::CircleNode *node, float scale, int64_t zp)
534 {
535   assert(node);               // FIX_CALLER_UNLESS
536   assert(node->quantparam()); // FIX_CALLER_UNLESS
537
538   auto qparam = node->quantparam();
539   assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
540   assert(qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
541   qparam->scale[0] = scale;
542   qparam->zerop[0] = zp;
543 }
544
545 /**
546  * @brief Manually set scale/zp of output tensor of special Ops
547  */
548 struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
549 {
550   QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
551     : input_type(input), output_type(output)
552   {
553   }
554
555   loco::DataType input_type;
556   loco::DataType output_type;
557
558   void visit(luci::CircleNode *)
559   {
560     // Do nothing by default
561   }
562
563   void visit(luci::CircleLogistic *node)
564   {
565     if (output_type == loco::DataType::U8)
566       set_act_qparam(node, 1.0f / 256.0f, 0);
567     else
568     {
569       assert(output_type == loco::DataType::S16);
570       set_act_qparam(node, 1.0f / 32768.0f, 0);
571     }
572   }
573
574   void visit(luci::CircleTanh *node)
575   {
576     if (output_type == loco::DataType::U8)
577       set_act_qparam(node, 2.0f / 256.0f, 128);
578     else
579     {
580       assert(output_type == loco::DataType::S16);
581       set_act_qparam(node, 1.0f / 32768.0f, 0);
582     }
583   }
584
585   void visit(luci::CircleStridedSlice *node)
586   {
587     auto input = loco::must_cast<luci::CircleNode *>(node->input());
588     auto i_qparam = input->quantparam();
589     assert(i_qparam);
590     assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
591     assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
592     auto i_scale = i_qparam->scale[0];
593     auto i_zp = i_qparam->zerop[0];
594
595     set_act_qparam(node, i_scale, i_zp);
596   }
597
598   void visit(luci::CircleSplitOut *node)
599   {
600     auto split = loco::must_cast<luci::CircleSplit *>(node->input());
601     auto input = loco::must_cast<luci::CircleNode *>(split->input());
602     auto i_qparam = input->quantparam();
603     assert(i_qparam);
604     assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
605     assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
606     auto i_scale = i_qparam->scale[0];
607     auto i_zp = i_qparam->zerop[0];
608
609     set_act_qparam(node, i_scale, i_zp);
610   }
611
612   void visit(luci::CircleSplitVOut *node)
613   {
614     auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
615     auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
616     auto i_qparam = input->quantparam();
617     assert(i_qparam);
618     assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
619     assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
620     auto i_scale = i_qparam->scale[0];
621     auto i_zp = i_qparam->zerop[0];
622
623     set_act_qparam(node, i_scale, i_zp);
624   }
625
626   void visit(luci::CircleUnpackOut *node)
627   {
628     auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
629     auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
630     auto i_qparam = input->quantparam();
631     assert(i_qparam);
632     assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
633     assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
634     auto i_scale = i_qparam->scale[0];
635     auto i_zp = i_qparam->zerop[0];
636
637     set_act_qparam(node, i_scale, i_zp);
638   }
639
640   // TODO Move Softmax, Floor, Ceil from QuantizeActivation to here
641 };
642
643 /**
644  * @brief QuantizeActivation quantizes tensors for activations
645  * @details Quantize using recorded min/max values
646  */
647 struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
648 {
649   QuantizeActivation(loco::DataType input, loco::DataType output)
650     : input_type(input), output_type(output)
651   {
652   }
653
654   loco::DataType input_type;
655   loco::DataType output_type;
656
657   // Quantize input tensors of each node
658   bool visit(luci::CircleNode *node)
659   {
660     LOGGER(l);
661     INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
662     auto arity = node->arity();
663     for (uint32_t i = 0; i < arity; i++)
664     {
665       auto input_node = node->arg(i);
666       auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
667
668       // Check if this is already quantized
669       if (is_quantized(circle_node))
670         continue;
671
672       // Check if this is bias (bias is quantized later)
673       auto iwo = get_input_weight_output_of_bias(circle_node);
674       if (iwo.size() > 0)
675         continue;
676
677       // Check if this is bool type (bool type is not quantized)
678       if (circle_node->dtype() == loco::DataType::BOOL)
679         continue;
680
681       // Check if this is activation
682       // We assume min/max are recorded only for activations
683       if (has_min_max(circle_node) && !is_weights(circle_node))
684       {
685         // Quantize using recorded min/max
686         auto quantparam = circle_node->quantparam();
687         assert(quantparam);
688         assert(quantparam->min.size() == 1); // only support layer-wise quant
689         assert(quantparam->max.size() == 1); // only support layer-wise quant
690         auto min = quantparam->min[0];
691         auto max = quantparam->max[0];
692
693         // Special values
694         if (circle_node->opcode() == luci::CircleOpcode::SOFTMAX)
695         {
696           min = 0.0f;
697           max = 1.0f;
698         }
699
700         float scaling_factor{0};
701         int64_t zp{0};
702         float nudged_min{0};
703         float nudged_max{0};
704
705         if (output_type == loco::DataType::U8)
706         {
707           compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
708           circle_node->dtype(loco::DataType::U8);
709         }
710         else
711         {
712           compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
713           circle_node->dtype(loco::DataType::S16);
714         }
715
716         // Nodes fused with activation functions which need special quantization
717         auto fused_act_node =
718           dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(circle_node);
719         if (fused_act_node != nullptr &&
720             fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
721         {
722           if (output_type == loco::DataType::U8)
723           {
724             scaling_factor = 2.0f / 256.0f;
725             zp = 128;
726           }
727           else
728           {
729             assert(output_type == loco::DataType::S16);
730             scaling_factor = 1.0f / 32768.0f;
731             zp = 0;
732           }
733         }
734
735         // The output of these Ops should be integer, so scale should be integer
736         // TODO Handle cases where the integer scale needs to be propagated
737         if (circle_node->opcode() == CircleOpcode::FLOOR ||
738             circle_node->opcode() == CircleOpcode::FLOOR_DIV ||
739             circle_node->opcode() == CircleOpcode::FLOOR_MOD ||
740             circle_node->opcode() == CircleOpcode::CEIL)
741         {
742           assert(scaling_factor >= 0); // FIX_ME_UNLESS
743           scaling_factor = scaling_factor < 1 ? 1.0f : std::round(scaling_factor);
744         }
745
746         circle_node->quantparam()->min.clear();
747         circle_node->quantparam()->max.clear();
748         circle_node->quantparam()->scale.push_back(scaling_factor);
749         circle_node->quantparam()->zerop.push_back(zp);
750       }
751       // Fix special attributes
752       if (circle_node->opcode() == luci::CircleOpcode::CAST)
753       {
754         auto *cast = loco::must_cast<luci::CircleCast *>(circle_node);
755         auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
756
757         // make sure that cast_input is already quantized
758         assert(cast_input->dtype() != loco::DataType::FLOAT32);
759         cast->in_data_type(cast_input->dtype());
760         cast->out_data_type(cast->dtype());
761       }
762     }
763     return false;
764   }
765 };
766
767 struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
768 {
769   QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
770     : input_type(input), output_type(output), granularity(gr)
771   {
772   }
773
774   loco::DataType input_type;
775   loco::DataType output_type;
776   QuantizationGranularity granularity;
777
778   // Quantize bias node
779   bool visit(luci::CircleNode *node)
780   {
781     // Check if this is already quantized
782     if (is_quantized(node))
783       return false;
784
785     auto iwo_list = get_input_weight_output_of_bias(node);
786
787     for (auto iwo : iwo_list)
788     {
789       assert(iwo.size() == 3);
790
791       auto input = loco::must_cast<luci::CircleNode *>(iwo[0]);
792       auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]);
793       auto output = loco::must_cast<luci::CircleNode *>(iwo[2]);
794
795       auto const_bias = loco::must_cast<luci::CircleConst *>(node);
796       assert(const_bias->dtype() == loco::DataType::FLOAT32);
797
798       // If input is const, it is quantized here, not in QuantizeActivation
799       if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
800       {
801         quant_const(const_input, output_type);
802       }
803
804       CircleConst *new_bias = nullptr;
805
806       if (granularity == QuantizationGranularity::ChannelWise)
807       {
808         auto input_q = input->quantparam();
809         assert(input_q);
810         assert(input_q->scale.size() == 1); // input scale's layer-wise
811         auto input_scale = input_q->scale[0];
812
813         assert(weight->quantparam() != nullptr); // weight scale's channel-wise
814         auto weight_scale = weight->quantparam()->scale;
815
816         uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
817         assert(size == weight_scale.size());
818         std::vector<float> scaling_factor(size);
819         std::vector<int64_t> zp(size);
820
821         if (output_type == loco::DataType::U8)
822         {
823           new_bias =
824             quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
825         }
826         else if (output_type == loco::DataType::S16)
827         {
828           new_bias =
829             int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
830         }
831         else
832         {
833           throw std::runtime_error("Unsupported quantization type.");
834         }
835
836         auto quantparam = std::make_unique<CircleQuantParam>();
837         quantparam->scale = scaling_factor;
838         quantparam->zerop = zp;
839         assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
840         new_bias->quantparam(std::move(quantparam));
841
842         set_bias(output, new_bias);
843       }
844       else
845       {
846         auto input_q = input->quantparam();
847         assert(input_q);
848         assert(input_q->scale.size() == 1); // Only support per-layer quant
849         auto input_scale = input_q->scale[0];
850
851         auto weight_q = weight->quantparam();
852         assert(weight_q);
853         assert(weight_q->scale.size() == 1); // Only support per-layer quant
854         auto weight_scale = weight_q->scale[0];
855
856         float scaling_factor{0};
857         int64_t zp{0};
858         new_bias =
859           asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
860         auto quantparam = std::make_unique<CircleQuantParam>();
861         quantparam->scale.push_back(scaling_factor);
862         quantparam->zerop.push_back(zp);
863         assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
864         new_bias->quantparam(std::move(quantparam));
865
866         set_bias(output, new_bias);
867       }
868     }
869     return false;
870   }
871 };
872
873 /**
874  * @brief QuantizeWeights quantizes tensors for weights
875  * @details Find min/max values on the fly and then quantize
876  */
877 struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
878 {
879   QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
880     : input_type(input), output_type(output), granularity(gr)
881   {
882   }
883
884   loco::DataType input_type;
885   loco::DataType output_type;
886   QuantizationGranularity granularity;
887
888 private:
889   void quantize_weights(luci::CircleConst *weights)
890   {
891     // Find min/max per channel-wise
892     if (granularity == QuantizationGranularity::ChannelWise)
893     {
894       auto quantparam = weights->quantparam();
895       if (quantparam == nullptr)
896       {
897         assert(false && "quantparam is nullptr");
898         return;
899       }
900
901       auto min = quantparam->min;
902       auto scaling_factor = quantparam->scale;
903       int32_t channel_dim_index = 0;
904
905       if (output_type == loco::DataType::U8)
906       {
907         asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
908       }
909       else
910       {
911         sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
912       }
913       quantparam->min.clear();
914       quantparam->max.clear();
915       quantparam->quantized_dimension = channel_dim_index;
916     }
917     // Find min/max per layer-wise
918     else
919     {
920       // Quantize using recorded quantparam
921       auto quantparam = weights->quantparam();
922       assert(quantparam != nullptr);
923       assert(quantparam->min.size() == 1);   // only support layer-wise quant
924       assert(quantparam->scale.size() == 1); // only support layer-wise quant
925       auto min = quantparam->min[0];
926       auto scaling_factor = quantparam->scale[0];
927       asym_wquant_per_layer(weights, min, scaling_factor);
928       quantparam->min.clear();
929       quantparam->max.clear();
930     }
931   }
932
933   bool visit(luci::CircleConv2D *node)
934   {
935     LOGGER(l);
936     INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
937
938     auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
939     if (!is_quantized(weights))
940     {
941       auto new_weights = luci::clone(weights);
942       node->filter(new_weights);
943       quantize_weights(new_weights);
944       return true;
945     }
946     return false;
947   }
948
949   bool visit(luci::CircleDepthwiseConv2D *node)
950   {
951     LOGGER(l);
952     INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
953
954     auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
955     if (!is_quantized(weights))
956     {
957       auto new_weights = luci::clone(weights);
958       node->filter(new_weights);
959       quantize_weights(new_weights);
960       return true;
961     }
962     return false;
963   }
964
965   bool visit(luci::CircleInstanceNorm *node)
966   {
967     LOGGER(l);
968     INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
969
970     auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
971     auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
972
973     bool changed = false;
974     if (!is_quantized(gamma))
975     {
976       assert(gamma->dtype() == loco::DataType::FLOAT32);
977       auto new_gamma = luci::clone(gamma);
978       if (granularity == QuantizationGranularity::LayerWise)
979         quant_const(new_gamma, output_type);
980       else if (granularity == QuantizationGranularity::ChannelWise)
981         quant_const_per_channel(new_gamma, output_type);
982       node->gamma(new_gamma);
983       changed = true;
984     }
985     if (!is_quantized(beta))
986     {
987       assert(beta->dtype() == loco::DataType::FLOAT32);
988       auto new_beta = luci::clone(beta);
989       if (granularity == QuantizationGranularity::LayerWise)
990         quant_const(new_beta, output_type);
991       else if (granularity == QuantizationGranularity::ChannelWise)
992         quant_const_per_channel(new_beta, output_type);
993       node->beta(new_beta);
994       changed = true;
995     }
996
997     return changed;
998   }
999
1000   bool visit(luci::CirclePRelu *node)
1001   {
1002     LOGGER(l);
1003     INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
1004
1005     auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
1006
1007     if (!is_quantized(alpha))
1008     {
1009       assert(alpha->dtype() == loco::DataType::FLOAT32);
1010       auto new_alpha = luci::clone(alpha);
1011       if (granularity == QuantizationGranularity::LayerWise)
1012         quant_const(new_alpha, output_type);
1013       else if (granularity == QuantizationGranularity::ChannelWise)
1014         quant_const_per_channel(new_alpha, output_type);
1015       node->alpha(new_alpha);
1016       return true;
1017     }
1018
1019     return false;
1020   }
1021
1022   bool visit(luci::CircleTransposeConv *node)
1023   {
1024     LOGGER(l);
1025     INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
1026
1027     auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
1028     if (!is_quantized(weights))
1029     {
1030       auto new_weights = luci::clone(weights);
1031       node->filter(new_weights);
1032       quantize_weights(new_weights);
1033       return true;
1034     }
1035     return false;
1036   }
1037
1038   bool visit(luci::CircleFullyConnected *node)
1039   {
1040     LOGGER(l);
1041     INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
1042
1043     auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
1044     if (!is_quantized(weights))
1045     {
1046       auto new_weights = luci::clone(weights);
1047       node->weights(new_weights);
1048       quantize_weights(new_weights);
1049       return true;
1050     }
1051     return false;
1052   }
1053
1054   bool visit(luci::CircleNode *) { return false; }
1055 };
1056
1057 /** EXAMPLE
1058  *
1059  * BEFORE
1060  *
1061  *         [CircleNode]       [CircleConst]
1062  *           (qparam1)           (FP32)
1063  *                   \            /
1064  *                    \          /
1065  *                    [CirclePack]
1066  *                     (qparam2)
1067  *
1068  *  AFTER
1069  *
1070  *         [CircleNode]        [CircleConst]   [CircleConst] <- Dead node
1071  *           (qparam2)           (qparam2)         (FP32)
1072  *                   \            /
1073  *                    \          /
1074  *                    [CirclePack]
1075  *                     (qparam2)
1076  *
1077  * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
1078  */
1079 void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type)
1080 {
1081   assert(pack->quantparam() != nullptr);
1082
1083   const auto num_inputs = pack->values_count();
1084
1085   for (uint32_t i = 0; i < num_inputs; i++)
1086   {
1087     auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
1088
1089     // Skip if this input is PACK Op
1090     if (node->opcode() == luci::CircleOpcode::PACK)
1091       continue;
1092
1093     // Quantize constant values
1094     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
1095     {
1096       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
1097       if (const_node->dtype() != loco::DataType::FLOAT32)
1098         throw std::runtime_error("Unsupported data type for constant input of pack Op");
1099
1100       const auto pack_qparam = pack->quantparam();
1101       if (pack_qparam == nullptr)
1102         throw std::runtime_error("quantparam of pack is not found during propagation");
1103
1104       assert(pack_qparam->scale.size() == 1);
1105       assert(pack_qparam->zerop.size() == 1);
1106       const auto scaling_factor = pack_qparam->scale[0];
1107       const auto zerop = pack_qparam->zerop[0];
1108
1109       auto new_const = luci::clone(const_node);
1110       quant_const_values(new_const, scaling_factor, zerop, quant_type);
1111       pack->values(i, new_const);
1112       overwrite_quantparam(pack, new_const);
1113     }
1114     else
1115     {
1116       const auto succs = loco::succs(node);
1117       if (succs.size() > 1)
1118         continue;
1119
1120       // Non-const input must have been quantized
1121       assert(node->quantparam() != nullptr);
1122       overwrite_quantparam(pack, node);
1123     }
1124   }
1125 }
1126
1127 /**
1128  * @brief Quantize const input tensors using min/max of const values
1129  */
1130 void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
1131 {
1132   auto opcode = node->opcode();
1133   auto arity = node->arity();
1134
1135   loco::Node *input_node{nullptr};
1136   luci::CircleConst *const_node{nullptr};
1137
1138   switch (opcode)
1139   {
1140     case luci::CircleOpcode::CONV_2D:
1141     case luci::CircleOpcode::DEPTHWISE_CONV_2D:
1142     case luci::CircleOpcode::FULLY_CONNECTED:
1143     case luci::CircleOpcode::INSTANCE_NORM:
1144     case luci::CircleOpcode::PRELU:
1145     case luci::CircleOpcode::TRANSPOSE_CONV:
1146       // Handled in QuantizeWeights and QuantizeBias
1147       break;
1148
1149     case luci::CircleOpcode::CONCATENATION:
1150       // Handled in propagate_concat_quantparam
1151       break;
1152
1153     case luci::CircleOpcode::LOGICAL_OR:
1154       // Inputs of logical Ops are bool, thus not quantized
1155       break;
1156
1157     case luci::CircleOpcode::ARG_MAX:
1158     case luci::CircleOpcode::ARG_MIN:
1159     case luci::CircleOpcode::BATCH_TO_SPACE_ND:
1160     case luci::CircleOpcode::LOCAL_RESPONSE_NORMALIZATION:
1161     case luci::CircleOpcode::MEAN:
1162     case luci::CircleOpcode::MIRROR_PAD:
1163     case luci::CircleOpcode::PAD:
1164     case luci::CircleOpcode::REDUCE_ANY:
1165     case luci::CircleOpcode::REDUCE_PROD:
1166     case luci::CircleOpcode::REDUCE_MAX:
1167     case luci::CircleOpcode::REDUCE_MIN:
1168     case luci::CircleOpcode::RESHAPE:
1169     case luci::CircleOpcode::RESIZE_BILINEAR:
1170     case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR:
1171     case luci::CircleOpcode::REVERSE_SEQUENCE:
1172     case luci::CircleOpcode::SLICE:
1173     case luci::CircleOpcode::SPACE_TO_BATCH_ND:
1174     case luci::CircleOpcode::SPLIT_V:
1175     case luci::CircleOpcode::STRIDED_SLICE:
1176     case luci::CircleOpcode::SUM:
1177     case luci::CircleOpcode::TILE:
1178     case luci::CircleOpcode::TOPK_V2:
1179     case luci::CircleOpcode::TRANSPOSE:
1180       // The second input of these Ops should not be quantized
1181       // Ex: axis, paddings
1182       input_node = node->arg(0);
1183       const_node = dynamic_cast<luci::CircleConst *>(input_node);
1184       if (const_node != nullptr && !is_quantized(const_node))
1185         quant_const(const_node, output_type);
1186       break;
1187
1188     case luci::CircleOpcode::ADD:
1189     case luci::CircleOpcode::ADD_N:
1190     case luci::CircleOpcode::DEPTH_TO_SPACE:
1191     case luci::CircleOpcode::DIV:
1192     case luci::CircleOpcode::ELU:
1193     case luci::CircleOpcode::EQUAL:
1194     case luci::CircleOpcode::EXP:
1195     case luci::CircleOpcode::FLOOR:
1196     case luci::CircleOpcode::FLOOR_DIV:
1197     case luci::CircleOpcode::GREATER:
1198     case luci::CircleOpcode::GREATER_EQUAL:
1199     case luci::CircleOpcode::LESS:
1200     case luci::CircleOpcode::LESS_EQUAL:
1201     case luci::CircleOpcode::LOGISTIC:
1202     case luci::CircleOpcode::MAXIMUM:
1203     case luci::CircleOpcode::MINIMUM:
1204     case luci::CircleOpcode::MUL:
1205     case luci::CircleOpcode::NOT_EQUAL:
1206     case luci::CircleOpcode::POW:
1207     case luci::CircleOpcode::RSQRT:
1208     case luci::CircleOpcode::SOFTMAX:
1209     case luci::CircleOpcode::SPACE_TO_DEPTH:
1210     case luci::CircleOpcode::SQRT:
1211     case luci::CircleOpcode::SUB:
1212     case luci::CircleOpcode::TANH:
1213     case luci::CircleOpcode::UNPACK:
1214       // Quantize all const inputs using their values
1215       for (uint32_t i = 0; i < arity; i++)
1216       {
1217         input_node = node->arg(i);
1218         const_node = dynamic_cast<luci::CircleConst *>(input_node);
1219         if (const_node != nullptr && !is_quantized(const_node))
1220           quant_const(const_node, output_type);
1221       }
1222       break;
1223
1224     case luci::CircleOpcode::SPLIT:
1225       // Only the second input is quantized
1226       // First input should not be quantized (e.g., split_dim)
1227       input_node = node->arg(1);
1228       const_node = dynamic_cast<luci::CircleConst *>(input_node);
1229       if (const_node != nullptr && !is_quantized(const_node))
1230         quant_const(const_node, output_type);
1231       break;
1232
1233     case luci::CircleOpcode::PADV2:
1234       // First and third constant inputs are quantized
1235       // Second input should not be quantized (e.g., paddings)
1236       // Quant params are propagated either from output range to the non-constant input
1237       // or from input to output and constant values
1238       propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type);
1239       break;
1240
1241     case luci::CircleOpcode::PACK:
1242       // Quant param is propagated from output to inputs
1243       propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type);
1244       break;
1245
1246     default:
1247       for (uint32_t i = 0; i < arity; i++)
1248       {
1249         input_node = node->arg(i);
1250         const_node = dynamic_cast<luci::CircleConst *>(input_node);
1251         if (const_node != nullptr)
1252           throw std::runtime_error("Unsupported Op for const inputs");
1253       }
1254       break;
1255   }
1256 }
1257
1258 } // namespace
1259
1260 /** BEFORE
1261  *
1262  *         [CircleNode]             [CircleConst]
1263  *         (U8 qparam1)                 (FP32)
1264  *                   \                    /
1265  *                    \                  /
1266  *                    [CircleConcatenation]
1267  *                        (U8 qparam2)
1268  *
1269  *  AFTER
1270  *         [CircleNode]             [CircleConst]   [CircleConst] <- Dead node
1271  *         (U8 qparam2)             (U8 qparam2)       (FP32)
1272  *                   \                    /
1273  *                    \                  /
1274  *                    [CircleConcatenation]
1275  *                        (U8 qparam2)
1276  */
1277 void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type)
1278 {
1279   assert(concat->quantparam() != nullptr);
1280
1281   const auto num_inputs = concat->numValues();
1282
1283   // Quantize const inputs using their values if concat has fused act function
1284   if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
1285   {
1286     for (uint32_t i = 0; i < num_inputs; i++)
1287     {
1288       auto node = concat->arg(i);
1289       auto const_node = dynamic_cast<luci::CircleConst *>(node);
1290       if (const_node != nullptr)
1291       {
1292         auto new_const = luci::clone(const_node);
1293         quant_const(new_const, quant_type);
1294         concat->values(i, new_const);
1295       }
1296     }
1297     return;
1298   }
1299
1300   for (uint32_t i = 0; i < num_inputs; i++)
1301   {
1302     auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
1303
1304     // Skip if this input is CONCAT Op
1305     if (node->opcode() == luci::CircleOpcode::CONCATENATION)
1306       continue;
1307
1308     // Quantize constant values
1309     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
1310     {
1311       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
1312       if (const_node->dtype() != loco::DataType::FLOAT32)
1313         throw std::runtime_error("Unsupported data type for constant input of concatenation Op");
1314
1315       const auto concat_qparam = concat->quantparam();
1316       if (concat_qparam == nullptr)
1317         throw std::runtime_error("quantparam of concat is not found during propagation");
1318
1319       assert(concat_qparam->scale.size() == 1);
1320       const auto scaling_factor = concat_qparam->scale[0];
1321       const auto zerop = concat_qparam->zerop[0];
1322
1323       auto new_const = luci::clone(const_node);
1324       quant_const_values(new_const, scaling_factor, zerop, quant_type);
1325       concat->values(i, new_const);
1326       overwrite_quantparam(concat, new_const);
1327     }
1328     else
1329     {
1330       const auto succs = loco::succs(node);
1331       if (succs.size() > 1)
1332         continue;
1333
1334       // Non-const input must have been quantized
1335       assert(node->quantparam() != nullptr);
1336       overwrite_quantparam(concat, node);
1337     }
1338   }
1339 }
1340
1341 /**
1342  * tells if pad_v2 quantization should ignore padding value
1343  * In that case padding const will be quantized with input parameters, and probably clipped
1344  */
1345 bool ignore_pad_v2_const_quantization(luci::CirclePadV2 *pad)
1346 {
1347   // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
1348   // TODO use metadata hints to detect this case
1349   auto const_value_node = dynamic_cast<luci::CircleConst *>(pad->arg(2));
1350   if (!const_value_node)
1351     return false;
1352   if (const_value_node->dtype() == loco::DataType::FLOAT32)
1353   {
1354     float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
1355     if (const_value == std::numeric_limits<float>::lowest())
1356       return true;
1357   }
1358   return false;
1359 }
1360
1361 /** BEFORE
1362  *
1363  *         [CircleNode] [CircleConst] [CircleConst]
1364  *         (U8 qparam1)     (S32)       (FP32)
1365  *                   \        |         /
1366  *                    \       |        /
1367  *                      [CirclePadV2]
1368  *                       (U8 qparam2)
1369  *
1370  *  AFTER (case 1)
1371  *
1372  *  By default qparam is propagated from output to inputs to meet backend requirements.
1373  *
1374  *         [CircleNode] [CircleConst] [CircleConst]   [CircleConst] <- Dead node
1375  *         (U8 qparam2)     (S32)      (U8 qparam2)       (FP32)
1376  *                   \        |         /
1377  *                    \       |        /
1378  *                      [CirclePadV2]
1379  *                       (U8 qparam2)
1380  *
1381  *  AFTER (case 2)
1382  *
1383  * In case padded value is the lowest float value
1384  * Qparam is propagated from input to output and constant.
1385  *
1386  * This is a special case for optimization constructed pad, needed to guarantee that
1387  * extremely large negative constant do not stretch output quantization range.
1388  *
1389  *         [CircleNode] [CircleConst] [CircleConst]   [CircleConst] <- Dead node
1390  *         (U8 qparam1)     (S32)      (U8 qparam1)       (FP32)
1391  *                   \        |         /
1392  *                    \       |        /
1393  *                      [CirclePadV2]
1394  *                       (U8 qparam1)
1395  */
1396 void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type)
1397 {
1398   if (ignore_pad_v2_const_quantization(pad_v2))
1399   {
1400     // propagate input quantization paramters from input to output and padding const value
1401     auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
1402     overwrite_quantparam(pad_v2_input, pad_v2);
1403
1404     auto const_value_node = loco::must_cast<luci::CircleConst *>(
1405       pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
1406     auto new_const = luci::clone(const_value_node);
1407
1408     const auto pad_v2_input_qparam = pad_v2_input->quantparam();
1409     assert(pad_v2_input_qparam != nullptr);
1410     assert(pad_v2_input_qparam->scale.size() == 1);
1411     const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
1412     const auto zerop = pad_v2_input_qparam->zerop.at(0);
1413
1414     quant_const_values(new_const, scaling_factor, zerop, quant_type);
1415     overwrite_quantparam(pad_v2_input, new_const);
1416     pad_v2->constant_values(new_const);
1417     return;
1418   }
1419
1420   // Propagate quantization paramters from output to inputs,
1421   // to fit both input and counstant_value in one quant range.
1422   auto quant_input = [pad_v2, quant_type](void (CirclePadV2::*arg_setter)(loco::Node *),
1423                                           uint32_t arg) {
1424     auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
1425
1426     // Quantize constant values
1427     if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
1428     {
1429       luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
1430       if (is_quantized(const_node))
1431         return;
1432
1433       if (const_node->dtype() != loco::DataType::FLOAT32)
1434         throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
1435
1436       const auto pad_v2_qparam = pad_v2->quantparam();
1437       if (pad_v2_qparam == nullptr)
1438         throw std::runtime_error("quantparam of PadV2 is not found during propagation");
1439
1440       assert(pad_v2_qparam->scale.size() == 1);
1441       const auto scaling_factor = pad_v2_qparam->scale.at(0);
1442       const auto zerop = pad_v2_qparam->zerop.at(0);
1443
1444       auto new_const = luci::clone(const_node);
1445       quant_const_values(new_const, scaling_factor, zerop, quant_type);
1446       overwrite_quantparam(pad_v2, new_const);
1447       (pad_v2->*arg_setter)(new_const);
1448     }
1449     // Subsequent PadV2 Ops quant params are not propagated
1450     else if (node->opcode() == luci::CircleOpcode::PADV2)
1451     {
1452       return;
1453     }
1454     else
1455     {
1456       const auto succs = loco::succs(node);
1457       if (succs.size() > 1)
1458         return;
1459
1460       // Non-const input must have been quantized
1461       assert(node->quantparam() != nullptr);
1462       overwrite_quantparam(pad_v2, node);
1463     }
1464   };
1465
1466   quant_input(&CirclePadV2::input, 0);
1467   quant_input(&CirclePadV2::constant_values, 2);
1468 }
1469
1470 bool QuantizeWithMinMaxPass::run(loco::Graph *g)
1471 {
1472   LOGGER(l);
1473   INFO(l) << "QuantizeWithMinMaxPass Start" << std::endl;
1474
1475   // Quantize activation
1476   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1477   {
1478     QuantizeActivation qa(_input_model_dtype, _output_model_dtype);
1479     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1480     circle_node->accept(&qa);
1481   }
1482
1483   // Quantize weights
1484   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1485   {
1486     QuantizeWeights qw(_input_model_dtype, _output_model_dtype, _granularity);
1487     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1488     circle_node->accept(&qw);
1489   }
1490
1491   // Quantize bias
1492   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1493   {
1494     QuantizeBias qb(_input_model_dtype, _output_model_dtype, _granularity);
1495     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1496     circle_node->accept(&qb);
1497   }
1498
1499   // Propagate quantization parameters of concat Op
1500   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1501   {
1502     auto concat = dynamic_cast<luci::CircleConcatenation *>(node);
1503     if (not concat)
1504       continue;
1505
1506     // Propagate qparam of concat to its inputs if
1507     // (1) concat is uint8-quantized
1508     // (2) concat has no fused activation function
1509     // (3) the input is not concatenation Op
1510     // (4) the input is not produced to Ops other than concat
1511     propagate_concat_quantparam(concat, _output_model_dtype);
1512   }
1513
1514   // Quantize const inputs other than weights and bias
1515   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1516   {
1517     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1518     quantize_const_inputs(circle_node, _output_model_dtype);
1519   }
1520
1521   // Update qparam of output of special Ops
1522   for (auto node : loco::active_nodes(loco::output_nodes(g)))
1523   {
1524     QuantizeSpecialActivation qsa(_input_model_dtype, _output_model_dtype);
1525     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
1526     circle_node->accept(&qsa);
1527   }
1528
1529   // Update output dtype
1530   auto graph_outputs = g->outputs();
1531   for (auto node : loco::output_nodes(g))
1532   {
1533     auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
1534     if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_model_dtype)
1535     {
1536       circle_node->dtype(_output_model_dtype);
1537       auto graph_output = graph_outputs->at(circle_node->index());
1538       graph_output->dtype(_output_model_dtype);
1539     }
1540   }
1541
1542   INFO(l) << "QuantizeWithMinMaxPass End" << std::endl;
1543   return false; // one time run
1544 }
1545
1546 } // namespace luci