Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / RequantizePass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/RequantizePass.h"
18 #include "QuantizationUtils.h"
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Log.h>
23
24 #include <oops/UserExn.h>
25
26 #include <iostream>
27 #include <cmath>
28
29 namespace luci
30 {
31
32 namespace
33 {
34
35 // Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
36 bool is_bias(CircleConst *node)
37 {
38   if (node == nullptr)
39     return false;
40
41   auto succs = loco::succs(node);
42   if (succs.size() != 1) // assume bias is used by only one node
43     return false;
44
45   for (auto out : succs)
46   {
47     auto conv = dynamic_cast<CircleConv2D *>(out);
48     if (conv != nullptr && conv->bias() == node)
49       return true;
50
51     auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
52     if (dw_conv != nullptr && dw_conv->bias() == node)
53       return true;
54
55     auto fc = dynamic_cast<CircleFullyConnected *>(out);
56     if (fc != nullptr && fc->bias() == node)
57       return true;
58
59     // TODO: add TransposeConv when bias is supported in CircleTransposeConv
60   }
61   return false;
62 }
63
64 void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
65 {
66   assert(circle_node->dtype() == loco::DataType::S8);
67
68   auto quantparam = circle_node->quantparam();
69   assert(quantparam != nullptr);
70   for (size_t i = 0; i < quantparam->zerop.size(); ++i)
71   {
72     quantparam->zerop[i] += 128;
73   }
74   circle_node->dtype(loco::DataType::U8);
75 }
76
77 // Requantize CircleConst from symmetric int8 to asymmetric uint8
78 // Original values: -127 ~ 127
79 // After requantization: 1 ~ 255 (zp <- zp + 128)
80 void requant_const_int8_to_uint8(CircleConst *node)
81 {
82   assert(node->dtype() == loco::DataType::S8);
83
84   uint32_t size = node->size<loco::DataType::S8>();
85   std::vector<int32_t> requantized_values(size);
86   for (uint32_t i = 0; i < size; ++i)
87   {
88     int32_t data = node->at<loco::DataType::S8>(i);
89     requantized_values[i] = data + 128;
90   }
91
92   node->dtype(loco::DataType::U8); // change the type of tensor
93   node->size<loco::DataType::U8>(size);
94   for (uint32_t i = 0; i < size; ++i)
95   {
96     assert(1 <= requantized_values[i] && requantized_values[i] <= 255);
97     node->at<loco::DataType::U8>(i) = requantized_values[i];
98   }
99
100   auto quantparam = node->quantparam();
101   assert(quantparam != nullptr);
102   for (size_t i = 0; i < quantparam->zerop.size(); ++i)
103   {
104     quantparam->zerop[i] += 128;
105   }
106 }
107
108 /**
109  * @brief RequantizeNonConst requantizes tensors for activations
110  */
111 struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
112 {
113   RequantizeNonConst(loco::DataType input, loco::DataType output)
114       : _input_type(input), _output_type(output)
115   {
116   }
117
118   loco::DataType _input_type;
119   loco::DataType _output_type;
120
121   // Requantize input tensors of each node
122   bool visit(luci::CircleNode *node)
123   {
124     LOGGER(l);
125     INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl;
126     auto arity = node->arity();
127     for (uint32_t i = 0; i < arity; i++)
128     {
129       auto input_node = node->arg(i);
130       auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
131
132       // Check if this was quantized (only quantized tensors are requantized)
133       if (circle_node->quantparam() == nullptr)
134         continue;
135
136       // Check if this is already requantized
137       if (circle_node->dtype() == _output_type)
138         continue;
139
140       // Check if this is not const (only non-const is requantized in this function)
141       auto circle_const = dynamic_cast<CircleConst *>(circle_node);
142       if (circle_const != nullptr)
143         continue;
144
145       if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
146         requant_nonconst_int8_to_uint8(circle_node);
147     }
148     return false;
149   }
150 };
151
152 /**
153  * @brief RequantizeConst requantizes tensors for weights
154  */
155 struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
156 {
157   RequantizeConst(loco::DataType input, loco::DataType output)
158       : _input_type(input), _output_type(output)
159   {
160   }
161
162   loco::DataType _input_type;
163   loco::DataType _output_type;
164
165   // Requantize input tensors of each node
166   bool visit(luci::CircleNode *node)
167   {
168     LOGGER(l);
169     INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl;
170     auto arity = node->arity();
171     for (uint32_t i = 0; i < arity; i++)
172     {
173       auto input_node = node->arg(i);
174       auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
175
176       // Check if this was quantized (only quantized tensors are requantized)
177       if (circle_node->quantparam() == nullptr)
178         continue;
179
180       // Check if this is already requantized
181       if (circle_node->dtype() == _output_type)
182         continue;
183
184       // Check if this is const (only const is requantized in this function)
185       auto circle_const = dynamic_cast<CircleConst *>(circle_node);
186       if (circle_const == nullptr)
187         continue;
188
189       // Check if this is not bias
190       // bias is not requantized when int8 -> uint8
191       if (is_bias(circle_const))
192         continue;
193
194       if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
195         requant_const_int8_to_uint8(circle_const);
196     }
197     return false;
198   }
199 };
200
201 } // namespace
202
203 bool RequantizePass::run(loco::Graph *g)
204 {
205   LOGGER(l);
206   INFO(l) << "RequantizePass Start" << std::endl;
207
208   // Requantize non-const (activations)
209   for (auto node : loco::active_nodes(loco::output_nodes(g)))
210   {
211     RequantizeNonConst rqnc(_input_dtype, _output_dtype);
212     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
213     circle_node->accept(&rqnc);
214   }
215
216   // Requantize const (including weights, constants)
217   for (auto node : loco::active_nodes(loco::output_nodes(g)))
218   {
219     RequantizeConst rqc(_input_dtype, _output_dtype);
220     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
221     circle_node->accept(&rqc);
222   }
223
224   // Update output dtype
225   auto graph_outputs = g->outputs();
226   for (auto node : loco::output_nodes(g))
227   {
228     auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
229     if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype)
230     {
231       circle_node->dtype(_output_dtype);
232       auto graph_output = graph_outputs->at(circle_node->index());
233       graph_output->dtype(_output_dtype);
234     }
235   }
236
237   INFO(l) << "RequantizePass End" << std::endl;
238   return false; // one time run
239 }
240
241 } // namespace luci