Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / RequantizePass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/RequantizePass.h"
18 #include "QuantizationUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Log.h>
23
24 #include <oops/UserExn.h>
25
26 #include <iostream>
27 #include <cmath>
28
29 namespace luci
30 {
31
32 namespace
33 {
34
35 // Requantize Non-const node from int8 to uint8
36 // Original values: -128 ~ 127
37 // After requantization: 0 ~ 255
38 void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
39 {
40   assert(circle_node->dtype() == loco::DataType::S8);
41
42   auto quantparam = circle_node->quantparam();
43   assert(quantparam != nullptr);
44   for (size_t i = 0; i < quantparam->zerop.size(); ++i)
45   {
46     quantparam->zerop[i] += 128;
47   }
48   circle_node->dtype(loco::DataType::U8);
49 }
50
51 // Requantize CircleConst from symmetric int8 to asymmetric uint8
52 // Original values: -127 ~ 127
53 // After requantization: 1 ~ 255 (zp <- zp + 128)
54 void requant_const_int8_to_uint8(CircleConst *node)
55 {
56   assert(node->dtype() == loco::DataType::S8);
57
58   uint32_t size = node->size<loco::DataType::S8>();
59   std::vector<int32_t> requantized_values(size);
60   for (uint32_t i = 0; i < size; ++i)
61   {
62     int32_t data = node->at<loco::DataType::S8>(i);
63     requantized_values[i] = data + 128;
64   }
65
66   node->dtype(loco::DataType::U8); // change the type of tensor
67   node->size<loco::DataType::U8>(size);
68   for (uint32_t i = 0; i < size; ++i)
69   {
70     assert(1 <= requantized_values[i] && requantized_values[i] <= 255);
71     node->at<loco::DataType::U8>(i) = requantized_values[i];
72   }
73
74   auto quantparam = node->quantparam();
75   assert(quantparam != nullptr);
76   for (size_t i = 0; i < quantparam->zerop.size(); ++i)
77   {
78     quantparam->zerop[i] += 128;
79   }
80 }
81
82 #define RETURN_UNLESS(cond) \
83   if (not(cond))            \
84     return;
85
86 /**
87  * @brief Requantize int8 quantized tensors to uint8 tensors
88  */
89 struct RequantizeS8ToU8 final : public luci::CircleNodeMutableVisitor<void>
90 {
91   // Requantize non-const tensors
92   void visit(luci::CircleNode *node)
93   {
94     LOGGER(l);
95     INFO(l) << "RequantizeS8ToU8 visit non-const node: " << node->name() << std::endl;
96
97     // Ignore non-quantized tensors
98     RETURN_UNLESS(node->quantparam() != nullptr);
99
100     // Check dtype is int8
101     RETURN_UNLESS(node->dtype() == loco::DataType::S8);
102
103     requant_nonconst_int8_to_uint8(node);
104   }
105
106   // Requantize const tensors
107   void visit(luci::CircleConst *node)
108   {
109     LOGGER(l);
110     INFO(l) << "RequantizeS8ToU8 visit const node: " << node->name() << std::endl;
111
112     // Ignore non-quantized tensors
113     RETURN_UNLESS(node->quantparam() != nullptr);
114
115     // Check dtype is int8
116     RETURN_UNLESS(node->dtype() == loco::DataType::S8);
117
118     requant_const_int8_to_uint8(node);
119   }
120 };
121
122 #undef RETURN_UNLESS
123
124 } // namespace
125
126 bool RequantizePass::run(loco::Graph *g)
127 {
128   LOGGER(l);
129   INFO(l) << "RequantizePass Start" << std::endl;
130
131   // Input: int8 model
132   // Output: uint8 model
133   if (_input_dtype == loco::DataType::S8 and _output_dtype == loco::DataType::U8)
134   {
135     for (auto node : loco::active_nodes(loco::output_nodes(g)))
136     {
137       RequantizeS8ToU8 rq;
138       auto circle_node = loco::must_cast<luci::CircleNode *>(node);
139       circle_node->accept(&rq);
140     }
141   }
142   else
143   {
144     // Ignore other cases
145     return false;
146   }
147
148   // Update output dtype
149   auto graph_outputs = g->outputs();
150   for (auto node : loco::output_nodes(g))
151   {
152     auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
153     auto from_node = loco::must_cast<luci::CircleNode *>(circle_node->from());
154     if (from_node->dtype() == _output_dtype)
155     {
156       circle_node->dtype(_output_dtype);
157       auto graph_output = graph_outputs->at(circle_node->index());
158       graph_output->dtype(_output_dtype);
159     }
160   }
161
162   INFO(l) << "RequantizePass End" << std::endl;
163   return false; // one time run
164 }
165
166 } // namespace luci