/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
#include "luci/Pass/QuantizationParameters.h"

#include "QuantizationUtils.h"

#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Log.h>

#include <stdexcept>
30 // Create Quantize Op whose dtype/shape/qparam are the same with node
31 luci::CircleQuantize *create_quantize(luci::CircleNode *node)
33 auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
34 // DESIGN NOTE: Why use '_FQ_Quantize' instead of '_Quantize'?
35 // '_Quantize' is used in mixed-precision quantization
36 // We add '_FQ' to distinguish Op from mixed-precision quantization
37 quantize->name(node->name() + "_FQ_Quantize");
38 quantize->dtype(node->dtype());
39 quantize->rank(node->rank());
40 for (uint32_t i = 0; i < node->rank(); i++)
41 quantize->dim(i).set(node->dim(i).value());
43 quantize->shape_status(luci::ShapeStatus::VALID);
45 copy_quantparam(node, quantize);
47 luci::add_origin(quantize, luci::get_origin(node));
52 // Create Dequantize Op whose shape is the same with node
53 luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
55 auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
56 // DESIGN NOTE: Why use '_FQ_Dequantize' instead of '_Dequantize'?
57 // '_Dequantize' is used in mixed-precision quantization
58 // We add '_FQ' to distinguish Op from mixed-precision quantization
59 dequantize->name(node->name() + "_FQ_Dequantize");
60 dequantize->dtype(loco::DataType::FLOAT32);
61 dequantize->rank(node->rank());
62 for (uint32_t i = 0; i < node->rank(); i++)
63 dequantize->dim(i).set(node->dim(i).value());
65 dequantize->shape_status(luci::ShapeStatus::VALID);
67 luci::add_origin(dequantize, luci::get_origin(node));
72 // Return true if node is quantized activation
73 // 1. dtype is u8 or s16
75 bool is_quant_act(const luci::CircleNode *node)
77 if (node->dtype() != loco::DataType::U8 and node->dtype() != loco::DataType::S16)
80 if (not node->quantparam())
86 // Return true if node is quantized const
87 // 1. dtype is not fp32
89 // NOTE Quantized const can have the following types
90 // u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias)
91 bool is_quant_const(const luci::CircleConst *node)
93 if (node->dtype() == loco::DataType::FLOAT32)
96 if (not node->quantparam())
102 // Insert dequantize Op after node
103 void insert_dequantize(loco::Node *lnode)
105 auto node = loco::must_cast<luci::CircleNode *>(lnode);
106 auto dequant = create_dequantize(node);
107 loco::replace(node).with(dequant);
108 dequant->input(node);
111 // Insert quantize Op after node and return the quantize Op
112 luci::CircleQuantize *insert_quantize(loco::Node *lnode)
114 auto node = loco::must_cast<luci::CircleNode *>(lnode);
115 auto quant = create_quantize(node);
116 loco::replace(node).with(quant);
122 void dequantize(luci::CircleNode *node)
124 node->dtype(loco::DataType::FLOAT32);
125 node->quantparam(nullptr);
128 // Do fake quantization on quantized activation
129 // 1. Insert Quantize-Dequantize Ops
130 // 2. Update dtype/quantparam of node
131 void fq_activation(luci::CircleNode *node)
133 if (not is_quant_act(node))
136 auto quant = insert_quantize(node);
137 insert_dequantize(quant);
// Leave the enclosing function when COND does not hold.
// Only usable inside functions returning void.
// Wrapped in do/while(0) so that 'RETURN_UNLESS(x);' is exactly one statement and can be
// used as the body of an if/else without braces.
#define RETURN_UNLESS(COND) \
  do                        \
  {                         \
    if (not(COND))          \
      return;               \
  } while (0)
146 // Visitor to do fake quantization for each Op
147 // For non-const activation, insert Quantize-Dequantize after the ofm
148 // For quantized const, insert Dequantize after the const
149 struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
151 void visit(luci::CircleNode *node)
153 throw std::runtime_error("Unsupported op for fake quantization in " + node->name());
156 void visit(luci::CircleInput *node)
158 RETURN_UNLESS(is_quant_act(node));
160 auto quant = insert_quantize(node);
161 insert_dequantize(quant);
165 // Update graph input
166 const auto inputs = node->graph()->inputs();
167 auto graph_input = inputs->at(node->index());
168 graph_input->dtype(loco::DataType::FLOAT32);
171 void visit(luci::CircleOutput *node)
173 RETURN_UNLESS(is_quant_act(node));
177 // Update graph output
178 const auto outputs = node->graph()->outputs();
179 auto graph_output = outputs->at(node->index());
180 graph_output->dtype(loco::DataType::FLOAT32);
183 // For quantized const, insert Dequantize Op
184 void visit(luci::CircleConst *node)
186 RETURN_UNLESS(is_quant_const(node));
188 insert_dequantize(node);
191 // For non-const activation, insert Quantize-Dequantize Ops
192 // and dequantize the node
193 void visit(luci::CircleAbs *node) { fq_activation(node); }
194 void visit(luci::CircleAdd *node) { fq_activation(node); }
195 void visit(luci::CircleAveragePool2D *node) { fq_activation(node); }
196 void visit(luci::CircleBatchMatMul *node) { fq_activation(node); }
197 void visit(luci::CircleConv2D *node) { fq_activation(node); }
198 void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); }
199 void visit(luci::CircleDiv *node) { fq_activation(node); }
200 void visit(luci::CircleFullyConnected *node) { fq_activation(node); }
201 void visit(luci::CircleGelu *node) { fq_activation(node); }
202 void visit(luci::CircleInstanceNorm *node) { fq_activation(node); }
203 void visit(luci::CircleLeakyRelu *node) { fq_activation(node); }
204 void visit(luci::CircleLogistic *node) { fq_activation(node); }
205 void visit(luci::CircleLogSoftmax *node) { fq_activation(node); }
206 void visit(luci::CircleMaxPool2D *node) { fq_activation(node); }
207 void visit(luci::CircleMul *node) { fq_activation(node); }
208 void visit(luci::CircleNeg *node) { fq_activation(node); }
209 void visit(luci::CirclePad *node) { fq_activation(node); }
210 void visit(luci::CirclePRelu *node) { fq_activation(node); }
211 void visit(luci::CircleMean *node) { fq_activation(node); }
212 void visit(luci::CircleReduceProd *node) { fq_activation(node); }
213 void visit(luci::CircleReduceMax *node) { fq_activation(node); }
214 void visit(luci::CircleRelu *node) { fq_activation(node); }
215 void visit(luci::CircleRelu6 *node) { fq_activation(node); }
216 void visit(luci::CircleResizeBilinear *node) { fq_activation(node); }
217 void visit(luci::CircleResizeNearestNeighbor *node) { fq_activation(node); }
218 void visit(luci::CircleRsqrt *node) { fq_activation(node); }
219 void visit(luci::CircleSoftmax *node) { fq_activation(node); }
220 void visit(luci::CircleSqrt *node) { fq_activation(node); }
221 void visit(luci::CircleSquaredDifference *node) { fq_activation(node); }
222 void visit(luci::CircleSub *node) { fq_activation(node); }
223 void visit(luci::CircleSum *node) { fq_activation(node); }
224 void visit(luci::CircleTanh *node) { fq_activation(node); }
225 void visit(luci::CircleTransposeConv *node) { fq_activation(node); }
227 // For Ops that do not change the value of input, do nothing
228 // (dtype will be automatically updated by type inference)
229 void visit(luci::CircleCast *) {}
230 void visit(luci::CircleConcatenation *) {}
231 void visit(luci::CircleDepthToSpace *) {}
232 void visit(luci::CircleGather *) {}
233 void visit(luci::CircleSlice *) {}
234 void visit(luci::CircleStridedSlice *) {}
235 void visit(luci::CircleReshape *) {}
236 void visit(luci::CircleSpaceToDepth *) {}
237 void visit(luci::CircleSplit *) {}
238 void visit(luci::CircleSplitOut *) {}
239 void visit(luci::CircleSplitV *) {}
240 void visit(luci::CircleSplitVOut *) {}
241 void visit(luci::CircleTranspose *) {}
242 void visit(luci::CirclePack *) {}
243 void visit(luci::CircleUnpack *) {}
244 void visit(luci::CircleUnpackOut *) {}
246 // For Ops that return index, fake quantization is unnecessary
247 void visit(luci::CircleArgMax *) {}
250 void visit(luci::CircleOutputExclude *) {}
252 void visit(luci::CircleQuantize *node)
254 RETURN_UNLESS(is_quant_act(node));
256 insert_dequantize(node);
259 // Dequantize Op does nothing in fp32 model
260 void visit(luci::CircleDequantize *) {}
270 bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g)
273 for (auto node : loco::active_nodes(loco::output_nodes(g)))
275 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
276 INFO(l) << "ConvertToFakeQuantizedModelPass visit node: " << circle_node->name() << std::endl;
279 circle_node->accept(&fq);