/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/RequantizePass.h"
#include "QuantizationUtils.h"

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>

#include <oops/UserExn.h>

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>
35 // Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
36 bool is_bias(CircleConst *node)
41 auto succs = loco::succs(node);
42 if (succs.size() != 1) // assume bias is used by only one node
45 for (auto out : succs)
47 auto conv = dynamic_cast<CircleConv2D *>(out);
48 if (conv != nullptr && conv->bias() == node)
51 auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
52 if (dw_conv != nullptr && dw_conv->bias() == node)
55 auto fc = dynamic_cast<CircleFullyConnected *>(out);
56 if (fc != nullptr && fc->bias() == node)
59 // TODO: add TransposeConv when bias is supported in CircleTransposeConv
64 void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
66 assert(circle_node->dtype() == loco::DataType::S8);
68 auto quantparam = circle_node->quantparam();
69 assert(quantparam != nullptr);
70 for (size_t i = 0; i < quantparam->zerop.size(); ++i)
72 quantparam->zerop[i] += 128;
74 circle_node->dtype(loco::DataType::U8);
77 // Requantize CircleConst from symmetric int8 to asymmetric uint8
78 // Original values: -127 ~ 127
79 // After requantization: 1 ~ 255 (zp <- zp + 128)
80 void requant_const_int8_to_uint8(CircleConst *node)
82 assert(node->dtype() == loco::DataType::S8);
84 uint32_t size = node->size<loco::DataType::S8>();
85 std::vector<int32_t> requantized_values(size);
86 for (uint32_t i = 0; i < size; ++i)
88 int32_t data = node->at<loco::DataType::S8>(i);
89 requantized_values[i] = data + 128;
92 node->dtype(loco::DataType::U8); // change the type of tensor
93 node->size<loco::DataType::U8>(size);
94 for (uint32_t i = 0; i < size; ++i)
96 assert(1 <= requantized_values[i] && requantized_values[i] <= 255);
97 node->at<loco::DataType::U8>(i) = requantized_values[i];
100 auto quantparam = node->quantparam();
101 assert(quantparam != nullptr);
102 for (size_t i = 0; i < quantparam->zerop.size(); ++i)
104 quantparam->zerop[i] += 128;
109 * @brief RequantizeNonConst requantizes tensors for activations
111 struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
113 RequantizeNonConst(loco::DataType input, loco::DataType output)
114 : _input_type(input), _output_type(output)
118 loco::DataType _input_type;
119 loco::DataType _output_type;
121 // Requantize input tensors of each node
122 bool visit(luci::CircleNode *node)
125 INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl;
126 auto arity = node->arity();
127 for (uint32_t i = 0; i < arity; i++)
129 auto input_node = node->arg(i);
130 auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
132 // Check if this was quantized (only quantized tensors are requantized)
133 if (circle_node->quantparam() == nullptr)
136 // Check if this is already requantized
137 if (circle_node->dtype() == _output_type)
140 // Check if this is not const (only non-const is requantized in this function)
141 auto circle_const = dynamic_cast<CircleConst *>(circle_node);
142 if (circle_const != nullptr)
145 if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
146 requant_nonconst_int8_to_uint8(circle_node);
153 * @brief RequantizeConst requantizes tensors for weights
155 struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
157 RequantizeConst(loco::DataType input, loco::DataType output)
158 : _input_type(input), _output_type(output)
162 loco::DataType _input_type;
163 loco::DataType _output_type;
165 // Requantize input tensors of each node
166 bool visit(luci::CircleNode *node)
169 INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl;
170 auto arity = node->arity();
171 for (uint32_t i = 0; i < arity; i++)
173 auto input_node = node->arg(i);
174 auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
176 // Check if this was quantized (only quantized tensors are requantized)
177 if (circle_node->quantparam() == nullptr)
180 // Check if this is already requantized
181 if (circle_node->dtype() == _output_type)
184 // Check if this is const (only const is requantized in this function)
185 auto circle_const = dynamic_cast<CircleConst *>(circle_node);
186 if (circle_const == nullptr)
189 // Check if this is not bias
190 // bias is not requantized when int8 -> uint8
191 if (is_bias(circle_const))
194 if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
195 requant_const_int8_to_uint8(circle_const);
203 bool RequantizePass::run(loco::Graph *g)
206 INFO(l) << "RequantizePass Start" << std::endl;
208 // Requantize non-const (activations)
209 for (auto node : loco::active_nodes(loco::output_nodes(g)))
211 RequantizeNonConst rqnc(_input_dtype, _output_dtype);
212 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
213 circle_node->accept(&rqnc);
216 // Requantize const (including weights, constants)
217 for (auto node : loco::active_nodes(loco::output_nodes(g)))
219 RequantizeConst rqc(_input_dtype, _output_dtype);
220 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
221 circle_node->accept(&rqc);
224 // Update output dtype
225 auto graph_outputs = g->outputs();
226 for (auto node : loco::output_nodes(g))
228 auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
229 if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype)
231 circle_node->dtype(_output_dtype);
232 auto graph_output = graph_outputs->at(circle_node->index());
233 graph_output->dtype(_output_dtype);
237 INFO(l) << "RequantizePass End" << std::endl;
238 return false; // one time run