Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ConvertToFakeQuantizedModelPass.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
18 #include "luci/Pass/QuantizationParameters.h"
19
20 #include "QuantizationUtils.h"
21
22 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <luci/IR/CircleNodes.h>
24 #include <luci/IR/CircleNodeVisitor.h>
25 #include <luci/Log.h>
26
27 namespace
28 {
29
30 // Create Quantize Op whose dtype/shape/qparam are the same with node
31 luci::CircleQuantize *create_quantize(luci::CircleNode *node)
32 {
33   auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
34   // DESIGN NOTE: Why use '_FQ_Quantize' instead of '_Quantize'?
35   // '_Quantize' is used in mixed-precision quantization
36   // We add '_FQ' to distinguish Op from mixed-precision quantization
37   quantize->name(node->name() + "_FQ_Quantize");
38   quantize->dtype(node->dtype());
39   quantize->rank(node->rank());
40   for (uint32_t i = 0; i < node->rank(); i++)
41     quantize->dim(i).set(node->dim(i).value());
42
43   quantize->shape_status(luci::ShapeStatus::VALID);
44
45   copy_quantparam(node, quantize);
46
47   luci::add_origin(quantize, luci::get_origin(node));
48
49   return quantize;
50 }
51
52 // Create Dequantize Op whose shape is the same with node
53 luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
54 {
55   auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
56   // DESIGN NOTE: Why use '_FQ_Dequantize' instead of '_Dequantize'?
57   // '_Dequantize' is used in mixed-precision quantization
58   // We add '_FQ' to distinguish Op from mixed-precision quantization
59   dequantize->name(node->name() + "_FQ_Dequantize");
60   dequantize->dtype(loco::DataType::FLOAT32);
61   dequantize->rank(node->rank());
62   for (uint32_t i = 0; i < node->rank(); i++)
63     dequantize->dim(i).set(node->dim(i).value());
64
65   dequantize->shape_status(luci::ShapeStatus::VALID);
66
67   luci::add_origin(dequantize, luci::get_origin(node));
68
69   return dequantize;
70 }
71
72 // Return true if node is quantized activation
73 // 1. dtype is u8 or s16
74 // 2. node has qparam
75 bool is_quant_act(const luci::CircleNode *node)
76 {
77   if (node->dtype() != loco::DataType::U8 and node->dtype() != loco::DataType::S16)
78     return false;
79
80   if (not node->quantparam())
81     return false;
82
83   return true;
84 }
85
86 // Return true if node is quantized const
87 // 1. dtype is not fp32
88 // 2. node has qparam
89 // NOTE Quantized const can have the following types
90 // u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias)
91 bool is_quant_const(const luci::CircleConst *node)
92 {
93   if (node->dtype() == loco::DataType::FLOAT32)
94     return false;
95
96   if (not node->quantparam())
97     return false;
98
99   return true;
100 }
101
102 // Insert dequantize Op after node
103 void insert_dequantize(loco::Node *lnode)
104 {
105   auto node = loco::must_cast<luci::CircleNode *>(lnode);
106   auto dequant = create_dequantize(node);
107   loco::replace(node).with(dequant);
108   dequant->input(node);
109 }
110
111 // Insert quantize Op after node and return the quantize Op
112 luci::CircleQuantize *insert_quantize(loco::Node *lnode)
113 {
114   auto node = loco::must_cast<luci::CircleNode *>(lnode);
115   auto quant = create_quantize(node);
116   loco::replace(node).with(quant);
117   quant->input(node);
118   return quant;
119 }
120
121 // Dequantize node
122 void dequantize(luci::CircleNode *node)
123 {
124   node->dtype(loco::DataType::FLOAT32);
125   node->quantparam(nullptr);
126 }
127
128 // Do fake quantization on quantized activation
129 // 1. Insert Quantize-Dequantize Ops
130 // 2. Update dtype/quantparam of node
131 void fq_activation(luci::CircleNode *node)
132 {
133   if (not is_quant_act(node))
134     return;
135
136   auto quant = insert_quantize(node);
137   insert_dequantize(quant);
138
139   dequantize(node);
140 }
141
142 #define RETURN_UNLESS(COND) \
143   if (not(COND))            \
144     return;
145
146 // Visitor to do fake quantization for each Op
147 // For non-const activation, insert Quantize-Dequantize after the ofm
148 // For quantized const, insert Dequantize after the const
149 struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
150 {
151   void visit(luci::CircleNode *node)
152   {
153     throw std::runtime_error("Unsupported op for fake quantization in " + node->name());
154   }
155
156   void visit(luci::CircleInput *node)
157   {
158     RETURN_UNLESS(is_quant_act(node));
159
160     auto quant = insert_quantize(node);
161     insert_dequantize(quant);
162
163     dequantize(node);
164
165     // Update graph input
166     const auto inputs = node->graph()->inputs();
167     auto graph_input = inputs->at(node->index());
168     graph_input->dtype(loco::DataType::FLOAT32);
169   }
170
171   void visit(luci::CircleOutput *node)
172   {
173     RETURN_UNLESS(is_quant_act(node));
174
175     dequantize(node);
176
177     // Update graph output
178     const auto outputs = node->graph()->outputs();
179     auto graph_output = outputs->at(node->index());
180     graph_output->dtype(loco::DataType::FLOAT32);
181   }
182
183   // For quantized const, insert Dequantize Op
184   void visit(luci::CircleConst *node)
185   {
186     RETURN_UNLESS(is_quant_const(node));
187
188     insert_dequantize(node);
189   }
190
191   // For non-const activation, insert Quantize-Dequantize Ops
192   // and dequantize the node
193   void visit(luci::CircleAbs *node) { fq_activation(node); }
194   void visit(luci::CircleAdd *node) { fq_activation(node); }
195   void visit(luci::CircleAveragePool2D *node) { fq_activation(node); }
196   void visit(luci::CircleBatchMatMul *node) { fq_activation(node); }
197   void visit(luci::CircleConv2D *node) { fq_activation(node); }
198   void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); }
199   void visit(luci::CircleDiv *node) { fq_activation(node); }
200   void visit(luci::CircleFullyConnected *node) { fq_activation(node); }
201   void visit(luci::CircleGelu *node) { fq_activation(node); }
202   void visit(luci::CircleInstanceNorm *node) { fq_activation(node); }
203   void visit(luci::CircleLeakyRelu *node) { fq_activation(node); }
204   void visit(luci::CircleLogistic *node) { fq_activation(node); }
205   void visit(luci::CircleLogSoftmax *node) { fq_activation(node); }
206   void visit(luci::CircleMaxPool2D *node) { fq_activation(node); }
207   void visit(luci::CircleMul *node) { fq_activation(node); }
208   void visit(luci::CircleNeg *node) { fq_activation(node); }
209   void visit(luci::CirclePad *node) { fq_activation(node); }
210   void visit(luci::CirclePRelu *node) { fq_activation(node); }
211   void visit(luci::CircleMean *node) { fq_activation(node); }
212   void visit(luci::CircleReduceProd *node) { fq_activation(node); }
213   void visit(luci::CircleReduceMax *node) { fq_activation(node); }
214   void visit(luci::CircleRelu *node) { fq_activation(node); }
215   void visit(luci::CircleRelu6 *node) { fq_activation(node); }
216   void visit(luci::CircleResizeBilinear *node) { fq_activation(node); }
217   void visit(luci::CircleResizeNearestNeighbor *node) { fq_activation(node); }
218   void visit(luci::CircleRsqrt *node) { fq_activation(node); }
219   void visit(luci::CircleSoftmax *node) { fq_activation(node); }
220   void visit(luci::CircleSqrt *node) { fq_activation(node); }
221   void visit(luci::CircleSquaredDifference *node) { fq_activation(node); }
222   void visit(luci::CircleSub *node) { fq_activation(node); }
223   void visit(luci::CircleSum *node) { fq_activation(node); }
224   void visit(luci::CircleTanh *node) { fq_activation(node); }
225   void visit(luci::CircleTransposeConv *node) { fq_activation(node); }
226
227   // For Ops that do not change the value of input, do nothing
228   // (dtype will be automatically updated by type inference)
229   void visit(luci::CircleCast *) {}
230   void visit(luci::CircleConcatenation *) {}
231   void visit(luci::CircleDepthToSpace *) {}
232   void visit(luci::CircleGather *) {}
233   void visit(luci::CircleSlice *) {}
234   void visit(luci::CircleStridedSlice *) {}
235   void visit(luci::CircleReshape *) {}
236   void visit(luci::CircleSpaceToDepth *) {}
237   void visit(luci::CircleSplit *) {}
238   void visit(luci::CircleSplitOut *) {}
239   void visit(luci::CircleSplitV *) {}
240   void visit(luci::CircleSplitVOut *) {}
241   void visit(luci::CircleTranspose *) {}
242   void visit(luci::CirclePack *) {}
243   void visit(luci::CircleUnpack *) {}
244   void visit(luci::CircleUnpackOut *) {}
245
246   // For Ops that return index, fake quantization is unnecessary
247   void visit(luci::CircleArgMax *) {}
248
249   // Virtual node
250   void visit(luci::CircleOutputExclude *) {}
251
252   void visit(luci::CircleQuantize *node)
253   {
254     RETURN_UNLESS(is_quant_act(node));
255
256     insert_dequantize(node);
257   }
258
259   // Dequantize Op does nothing in fp32 model
260   void visit(luci::CircleDequantize *) {}
261 };
262
263 #undef RETURN_UNLESS
264
265 } // namespace
266
267 namespace luci
268 {
269
270 bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g)
271 {
272   LOGGER(l);
273   for (auto node : loco::active_nodes(loco::output_nodes(g)))
274   {
275     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
276     INFO(l) << "ConvertToFakeQuantizedModelPass visit node: " << circle_node->name() << std::endl;
277
278     FakeQuantize fq;
279     circle_node->accept(&fq);
280   }
281
282   // One time run
283   return false;
284 }
285
286 } // namespace luci