Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / VerifyQuantizedNodeLayerWiseGranularity.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
17 #define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/Pass/QuantizationParameters.h>
22
23 using Granularity = luci::QuantizationGranularity;
24
25 // This macro is undef at the end of the file
26 #define RETURN_FALSE_UNLESS(ARG) \
27   if (not(ARG))                  \
28   {                              \
29     return false;                \
30   }
31
32 namespace luci
33 {
34
35 /**
36  * @brief Verify the granualrity of layer-wise quantized node
37  * @details
38  *
39  * Targets to verify
40  * - node's output (i.e., node itself)
41  * - node's inputs
42  */
43 struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool>
44 {
45 private:
46   bool is_lwq(const loco::Node *node)
47   {
48     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
49
50     if (circle_node->quantparam() == nullptr)
51       return false;
52
53     if (circle_node->quantparam()->scale.size() != 1)
54       return false;
55
56     if (circle_node->quantparam()->zerop.size() != 1)
57       return false;
58
59     return true;
60   }
61
62   bool is_lwq_const(const loco::Node *node)
63   {
64     auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
65
66     if (circle_node->quantparam() == nullptr)
67       return false;
68
69     if (circle_node->quantparam()->scale.size() != 1)
70       return false;
71
72     if (circle_node->quantparam()->zerop.size() != 1)
73       return false;
74
75     return true;
76   }
77
78 private:
79   bool visit(const luci::CircleConv2D *node)
80   {
81     RETURN_FALSE_UNLESS(is_lwq(node))
82     RETURN_FALSE_UNLESS(is_lwq(node->input()))
83     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
84     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
85     if (bias != nullptr)
86       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
87     return true;
88   }
89
90   bool visit(const luci::CircleConcatenation *node)
91   {
92     RETURN_FALSE_UNLESS(is_lwq(node))
93     for (uint32_t i = 0; i < node->numValues(); i++)
94     {
95       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
96     }
97     return true;
98   }
99
100   bool visit(const luci::CircleDepthToSpace *node)
101   {
102     RETURN_FALSE_UNLESS(is_lwq(node))
103     RETURN_FALSE_UNLESS(is_lwq(node->input()))
104     return true;
105   }
106
107   bool visit(const luci::CircleDepthwiseConv2D *node)
108   {
109     RETURN_FALSE_UNLESS(is_lwq(node))
110     RETURN_FALSE_UNLESS(is_lwq(node->input()))
111     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
112     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
113     if (bias != nullptr)
114       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
115     return true;
116   }
117
118   bool visit(const luci::CircleInstanceNorm *node)
119   {
120     RETURN_FALSE_UNLESS(is_lwq(node))
121     RETURN_FALSE_UNLESS(is_lwq(node->input()))
122     RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
123     RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
124     return true;
125   }
126
127   bool visit(const luci::CirclePack *node)
128   {
129     RETURN_FALSE_UNLESS(is_lwq(node))
130     for (uint32_t i = 0; i < node->values_count(); i++)
131     {
132       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
133     }
134     return true;
135   }
136
137   bool visit(const luci::CirclePad *node)
138   {
139     RETURN_FALSE_UNLESS(is_lwq(node))
140     RETURN_FALSE_UNLESS(is_lwq(node->input()))
141     return true;
142   }
143
144   bool visit(const luci::CirclePadV2 *node)
145   {
146     RETURN_FALSE_UNLESS(is_lwq(node))
147     RETURN_FALSE_UNLESS(is_lwq(node->input()))
148     RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
149     return true;
150   }
151
152   bool visit(const luci::CircleMirrorPad *node)
153   {
154     RETURN_FALSE_UNLESS(is_lwq(node))
155     RETURN_FALSE_UNLESS(is_lwq(node->input()))
156     return true;
157   }
158
159   bool visit(const luci::CirclePRelu *node)
160   {
161     RETURN_FALSE_UNLESS(is_lwq(node))
162     RETURN_FALSE_UNLESS(is_lwq(node->input()))
163     RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
164     return true;
165   }
166
167   bool visit(const luci::CircleTransposeConv *node)
168   {
169     RETURN_FALSE_UNLESS(is_lwq(node))
170     RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
171     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
172     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
173     if (bias != nullptr)
174       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
175     return true;
176   }
177
178   bool visit(const luci::CircleFullyConnected *node)
179   {
180     RETURN_FALSE_UNLESS(is_lwq(node))
181     RETURN_FALSE_UNLESS(is_lwq(node->input()))
182     RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
183     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
184     if (bias != nullptr)
185       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
186     return true;
187   }
188
189   bool visit(const luci::CircleAdd *node)
190   {
191     RETURN_FALSE_UNLESS(is_lwq(node))
192     RETURN_FALSE_UNLESS(is_lwq(node->x()));
193     RETURN_FALSE_UNLESS(is_lwq(node->y()));
194     return true;
195   }
196
197   bool visit(const luci::CircleAveragePool2D *node)
198   {
199     RETURN_FALSE_UNLESS(is_lwq(node))
200     RETURN_FALSE_UNLESS(is_lwq(node->value()));
201     return true;
202   }
203
204   bool visit(const luci::CircleLogicalOr *)
205   {
206     // Logical OR has bool-type inputs and output
207     // Nothing to be checked
208     return true;
209   }
210
211   bool visit(const luci::CircleMaxPool2D *node)
212   {
213     RETURN_FALSE_UNLESS(is_lwq(node))
214     RETURN_FALSE_UNLESS(is_lwq(node->value()));
215     return true;
216   }
217
218   bool visit(const luci::CircleLocalResponseNormalization *node)
219   {
220     RETURN_FALSE_UNLESS(is_lwq(node))
221     RETURN_FALSE_UNLESS(is_lwq(node->input()));
222     return true;
223   }
224
225   bool visit(const luci::CircleMean *node)
226   {
227     RETURN_FALSE_UNLESS(is_lwq(node))
228     RETURN_FALSE_UNLESS(is_lwq(node->input()));
229     return true;
230   }
231
232   bool visit(const luci::CircleMul *node)
233   {
234     RETURN_FALSE_UNLESS(is_lwq(node))
235     RETURN_FALSE_UNLESS(is_lwq(node->x()));
236     RETURN_FALSE_UNLESS(is_lwq(node->y()));
237     return true;
238   }
239
240   bool visit(const luci::CircleNotEqual *node)
241   {
242     RETURN_FALSE_UNLESS(is_lwq(node->x()));
243     RETURN_FALSE_UNLESS(is_lwq(node->y()));
244     return true;
245   }
246
247   bool visit(const luci::CircleRelu *node)
248   {
249     RETURN_FALSE_UNLESS(is_lwq(node))
250     RETURN_FALSE_UNLESS(is_lwq(node->features()));
251     return true;
252   }
253
254   bool visit(const luci::CircleReshape *node)
255   {
256     auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
257     bool input_quantized = input->quantparam() != nullptr;
258     bool node_quantized = node->quantparam() != nullptr;
259     RETURN_FALSE_UNLESS(input_quantized == node_quantized);
260     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
261     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
262     return true;
263   }
264
265   bool visit(const luci::CircleLogistic *node)
266   {
267     RETURN_FALSE_UNLESS(is_lwq(node));
268     RETURN_FALSE_UNLESS(is_lwq(node->x()));
269     return true;
270   }
271
272   bool visit(const luci::CircleSoftmax *node)
273   {
274     RETURN_FALSE_UNLESS(is_lwq(node));
275     RETURN_FALSE_UNLESS(is_lwq(node->logits()));
276     return true;
277   }
278
279   bool visit(const luci::CircleSpaceToBatchND *node)
280   {
281     RETURN_FALSE_UNLESS(is_lwq(node));
282     RETURN_FALSE_UNLESS(is_lwq(node->input()));
283     return true;
284   }
285
286   bool visit(const luci::CircleSpaceToDepth *node)
287   {
288     RETURN_FALSE_UNLESS(is_lwq(node));
289     RETURN_FALSE_UNLESS(is_lwq(node->input()));
290     return true;
291   }
292
293   bool visit(const luci::CircleSlice *node)
294   {
295     RETURN_FALSE_UNLESS(is_lwq(node));
296     RETURN_FALSE_UNLESS(is_lwq(node->input()));
297     return true;
298   }
299
300   bool visit(const luci::CircleSplit *node)
301   {
302     // node's output is the input of CircleSplitOut, thus not quantized
303     RETURN_FALSE_UNLESS(is_lwq(node->input()));
304     return true;
305   }
306
307   bool visit(const luci::CircleSplitOut *node)
308   {
309     RETURN_FALSE_UNLESS(is_lwq(node));
310     return true;
311   }
312
313   bool visit(const luci::CircleSplitV *node)
314   {
315     // node's output is the input of CircleSplitVOut, thus not quantized
316     RETURN_FALSE_UNLESS(is_lwq(node->input()));
317     return true;
318   }
319
320   bool visit(const luci::CircleSplitVOut *node)
321   {
322     RETURN_FALSE_UNLESS(is_lwq(node));
323     return true;
324   }
325
326   bool visit(const luci::CircleStridedSlice *node)
327   {
328     RETURN_FALSE_UNLESS(is_lwq(node));
329     RETURN_FALSE_UNLESS(is_lwq(node->input()));
330     return true;
331   }
332
333   bool visit(const luci::CircleArgMax *node)
334   {
335     // node's output is index, thus not quantized
336     RETURN_FALSE_UNLESS(is_lwq(node->input()));
337     return true;
338   }
339
340   bool visit(const luci::CircleBatchToSpaceND *node)
341   {
342     RETURN_FALSE_UNLESS(is_lwq(node));
343     RETURN_FALSE_UNLESS(is_lwq(node->input()));
344     return true;
345   }
346
347   bool visit(const luci::CircleTanh *node)
348   {
349     RETURN_FALSE_UNLESS(is_lwq(node));
350     RETURN_FALSE_UNLESS(is_lwq(node->x()));
351     return true;
352   }
353
354   bool visit(const luci::CircleTranspose *node)
355   {
356     RETURN_FALSE_UNLESS(is_lwq(node));
357     RETURN_FALSE_UNLESS(is_lwq(node->a()));
358     return true;
359   }
360
361   bool visit(const luci::CircleFloor *node)
362   {
363     RETURN_FALSE_UNLESS(is_lwq(node));
364     RETURN_FALSE_UNLESS(is_lwq(node->x()));
365     return true;
366   }
367
368   bool visit(const luci::CircleGreater *node)
369   {
370     RETURN_FALSE_UNLESS(is_lwq(node->x()));
371     RETURN_FALSE_UNLESS(is_lwq(node->y()));
372     return true;
373   }
374
375   bool visit(const luci::CircleGreaterEqual *node)
376   {
377     RETURN_FALSE_UNLESS(is_lwq(node->x()));
378     RETURN_FALSE_UNLESS(is_lwq(node->y()));
379     return true;
380   }
381
382   bool visit(const luci::CircleDiv *node)
383   {
384     RETURN_FALSE_UNLESS(is_lwq(node));
385     RETURN_FALSE_UNLESS(is_lwq(node->x()));
386     RETURN_FALSE_UNLESS(is_lwq(node->y()));
387     return true;
388   }
389
390   bool visit(const luci::CircleFloorDiv *node)
391   {
392     RETURN_FALSE_UNLESS(is_lwq(node));
393     RETURN_FALSE_UNLESS(is_lwq(node->x()));
394     RETURN_FALSE_UNLESS(is_lwq(node->y()));
395     return true;
396   }
397
398   bool visit(const luci::CircleRsqrt *node)
399   {
400     RETURN_FALSE_UNLESS(is_lwq(node));
401     RETURN_FALSE_UNLESS(is_lwq(node->x()));
402     return true;
403   }
404
405   bool visit(const luci::CircleSqrt *node)
406   {
407     RETURN_FALSE_UNLESS(is_lwq(node));
408     RETURN_FALSE_UNLESS(is_lwq(node->x()));
409     return true;
410   }
411
412   bool visit(const luci::CircleElu *node)
413   {
414     RETURN_FALSE_UNLESS(is_lwq(node));
415     RETURN_FALSE_UNLESS(is_lwq(node->features()));
416     return true;
417   }
418
419   bool visit(const luci::CirclePow *node)
420   {
421     RETURN_FALSE_UNLESS(is_lwq(node));
422     RETURN_FALSE_UNLESS(is_lwq(node->x()));
423     RETURN_FALSE_UNLESS(is_lwq(node->y()));
424     return true;
425   }
426
427   bool visit(const luci::CircleResizeBilinear *node)
428   {
429     RETURN_FALSE_UNLESS(is_lwq(node));
430     RETURN_FALSE_UNLESS(is_lwq(node->input()));
431     return true;
432   }
433
434   bool visit(const luci::CircleResizeNearestNeighbor *node)
435   {
436     RETURN_FALSE_UNLESS(is_lwq(node));
437     RETURN_FALSE_UNLESS(is_lwq(node->input()));
438     return true;
439   }
440
441   bool visit(const luci::CircleUnpack *node)
442   {
443     // node's output is the input of CircleUnpackOut, thus not quantized
444     RETURN_FALSE_UNLESS(is_lwq(node->value()));
445     return true;
446   }
447
448   bool visit(const luci::CircleUnpackOut *node)
449   {
450     RETURN_FALSE_UNLESS(is_lwq(node));
451     return true;
452   }
453
454   bool visit(const luci::CircleCast *node)
455   {
456     auto input = loco::must_cast<const luci::CircleNode *>(node->x());
457     bool input_quantized = input->quantparam() != nullptr;
458     bool node_quantized = node->quantparam() != nullptr;
459     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
460     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
461     return true;
462   }
463
464   // TODO: Implement more Ops
465
466   bool visit(const luci::CircleNode *) { return true; }
467 };
468
469 } // namespace luci
470
471 #undef RETURN_FALSE_UNLESS
472
473 #endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__