Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / VerifyQuantizedNodeGranularity.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
18 #define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Pass/QuantizationParameters.h>
23
24 #include <memory>
25
26 using Granularity = luci::QuantizationGranularity;
27
28 // This macro is undef at the end of the file
29 #define RETURN_FALSE_UNLESS(ARG) \
30   if (not(ARG))                  \
31   {                              \
32     return false;                \
33   }
34
35 namespace luci
36 {
37
38 /**
39  * @brief Verify the granualrity of quantized node
40  * @details
41  *
42  * Targets to verify
43  * - node's output (i.e., node itself)
44  * - node's inputs
45  */
46 class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
47 {
48 public:
49   static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
50
51 protected:
52   bool is_lwq(const loco::Node *node)
53   {
54     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
55
56     if (circle_node->quantparam() == nullptr)
57       return false;
58
59     if (circle_node->quantparam()->scale.size() != 1)
60       return false;
61
62     if (circle_node->quantparam()->zerop.size() != 1)
63       return false;
64
65     return true;
66   }
67
68 private:
69   virtual bool visit(const luci::CircleConv2D *node) = 0;
70
71   bool visit(const luci::CircleConcatenation *node)
72   {
73     // Skip granularity check for concatenation of indices
74     if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
75       return true;
76
77     RETURN_FALSE_UNLESS(is_lwq(node))
78     for (uint32_t i = 0; i < node->numValues(); i++)
79     {
80       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
81     }
82     return true;
83   }
84
85   bool visit(const luci::CircleDepthToSpace *node)
86   {
87     RETURN_FALSE_UNLESS(is_lwq(node))
88     RETURN_FALSE_UNLESS(is_lwq(node->input()))
89     return true;
90   }
91
92   virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
93
94   virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
95
96   bool visit(const luci::CirclePack *node)
97   {
98     RETURN_FALSE_UNLESS(is_lwq(node))
99     for (uint32_t i = 0; i < node->values_count(); i++)
100     {
101       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
102     }
103     return true;
104   }
105
106   bool visit(const luci::CirclePad *node)
107   {
108     RETURN_FALSE_UNLESS(is_lwq(node))
109     RETURN_FALSE_UNLESS(is_lwq(node->input()))
110     return true;
111   }
112
113   bool visit(const luci::CirclePadV2 *node)
114   {
115     RETURN_FALSE_UNLESS(is_lwq(node))
116     RETURN_FALSE_UNLESS(is_lwq(node->input()))
117     RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
118     return true;
119   }
120
121   bool visit(const luci::CircleMirrorPad *node)
122   {
123     RETURN_FALSE_UNLESS(is_lwq(node))
124     RETURN_FALSE_UNLESS(is_lwq(node->input()))
125     return true;
126   }
127
128   virtual bool visit(const luci::CirclePRelu *node) = 0;
129
130   virtual bool visit(const luci::CircleTransposeConv *node) = 0;
131
132   virtual bool visit(const luci::CircleFullyConnected *node) = 0;
133
134   bool visit(const luci::CircleAdd *node)
135   {
136     // Skip granularity check for indices
137     if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
138       return true;
139
140     RETURN_FALSE_UNLESS(is_lwq(node));
141     RETURN_FALSE_UNLESS(is_lwq(node->x()));
142     RETURN_FALSE_UNLESS(is_lwq(node->y()));
143     return true;
144   }
145
146   bool visit(const luci::CircleAveragePool2D *node)
147   {
148     RETURN_FALSE_UNLESS(is_lwq(node));
149     RETURN_FALSE_UNLESS(is_lwq(node->value()));
150     return true;
151   }
152
153   bool visit(const luci::CircleLogicalOr *)
154   {
155     // Logical OR has bool-type inputs and output
156     // Nothing to be checked
157     return true;
158   }
159
160   bool visit(const luci::CircleMaxPool2D *node)
161   {
162     RETURN_FALSE_UNLESS(is_lwq(node));
163     RETURN_FALSE_UNLESS(is_lwq(node->value()));
164     return true;
165   }
166
167   bool visit(const luci::CircleLocalResponseNormalization *node)
168   {
169     RETURN_FALSE_UNLESS(is_lwq(node))
170     RETURN_FALSE_UNLESS(is_lwq(node->input()));
171     return true;
172   }
173
174   bool visit(const luci::CircleMean *node)
175   {
176     RETURN_FALSE_UNLESS(is_lwq(node));
177     RETURN_FALSE_UNLESS(is_lwq(node->input()));
178     return true;
179   }
180
181   bool visit(const luci::CircleMul *node)
182   {
183     // Skip granularity check for indices
184     if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
185       return true;
186
187     RETURN_FALSE_UNLESS(is_lwq(node));
188     RETURN_FALSE_UNLESS(is_lwq(node->x()));
189     RETURN_FALSE_UNLESS(is_lwq(node->y()));
190     return true;
191   }
192
193   bool visit(const luci::CircleNotEqual *node)
194   {
195     RETURN_FALSE_UNLESS(is_lwq(node->x()));
196     RETURN_FALSE_UNLESS(is_lwq(node->y()));
197     return true;
198   }
199
200   bool visit(const luci::CircleOneHot *node)
201   {
202     RETURN_FALSE_UNLESS(is_lwq(node));
203     RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
204     RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
205     return true;
206   }
207
208   bool visit(const luci::CircleReduceMax *node)
209   {
210     RETURN_FALSE_UNLESS(is_lwq(node));
211     RETURN_FALSE_UNLESS(is_lwq(node->input()));
212     return true;
213   }
214
215   bool visit(const luci::CircleRelu *node)
216   {
217     RETURN_FALSE_UNLESS(is_lwq(node));
218     RETURN_FALSE_UNLESS(is_lwq(node->features()));
219     return true;
220   }
221
222   bool visit(const luci::CircleReshape *node)
223   {
224     auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
225     bool input_quantized = input->quantparam() != nullptr;
226     bool node_quantized = node->quantparam() != nullptr;
227     RETURN_FALSE_UNLESS(input_quantized == node_quantized);
228     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
229     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
230     return true;
231   }
232
233   bool visit(const luci::CircleLogistic *node)
234   {
235     RETURN_FALSE_UNLESS(is_lwq(node));
236     RETURN_FALSE_UNLESS(is_lwq(node->x()));
237     return true;
238   }
239
240   bool visit(const luci::CircleSoftmax *node)
241   {
242     RETURN_FALSE_UNLESS(is_lwq(node));
243     RETURN_FALSE_UNLESS(is_lwq(node->logits()));
244     return true;
245   }
246
247   bool visit(const luci::CircleSpaceToBatchND *node)
248   {
249     RETURN_FALSE_UNLESS(is_lwq(node));
250     RETURN_FALSE_UNLESS(is_lwq(node->input()));
251     return true;
252   }
253
254   bool visit(const luci::CircleSpaceToDepth *node)
255   {
256     RETURN_FALSE_UNLESS(is_lwq(node));
257     RETURN_FALSE_UNLESS(is_lwq(node->input()));
258     return true;
259   }
260
261   bool visit(const luci::CircleSlice *node)
262   {
263     RETURN_FALSE_UNLESS(is_lwq(node));
264     RETURN_FALSE_UNLESS(is_lwq(node->input()));
265     return true;
266   }
267
268   bool visit(const luci::CircleSplit *node)
269   {
270     // node's output is the input of CircleSplitOut, thus not quantized
271     RETURN_FALSE_UNLESS(is_lwq(node->input()));
272     return true;
273   }
274
275   bool visit(const luci::CircleSplitOut *node)
276   {
277     RETURN_FALSE_UNLESS(is_lwq(node));
278     return true;
279   }
280
281   bool visit(const luci::CircleSplitV *node)
282   {
283     // node's output is the input of CircleSplitVOut, thus not quantized
284     RETURN_FALSE_UNLESS(is_lwq(node->input()));
285     return true;
286   }
287
288   bool visit(const luci::CircleSplitVOut *node)
289   {
290     RETURN_FALSE_UNLESS(is_lwq(node));
291     return true;
292   }
293
294   bool visit(const luci::CircleStridedSlice *node)
295   {
296     RETURN_FALSE_UNLESS(is_lwq(node));
297     RETURN_FALSE_UNLESS(is_lwq(node->input()));
298     return true;
299   }
300
301   bool visit(const luci::CircleSum *node)
302   {
303     RETURN_FALSE_UNLESS(is_lwq(node));
304     RETURN_FALSE_UNLESS(is_lwq(node->input()));
305     return true;
306   }
307
308   bool visit(const luci::CircleArgMax *node)
309   {
310     // node's output is index, thus not quantized
311     RETURN_FALSE_UNLESS(is_lwq(node->input()));
312     return true;
313   }
314
315   bool visit(const luci::CircleBatchToSpaceND *node)
316   {
317     RETURN_FALSE_UNLESS(is_lwq(node));
318     RETURN_FALSE_UNLESS(is_lwq(node->input()));
319     return true;
320   }
321
322   bool visit(const luci::CircleTanh *node)
323   {
324     RETURN_FALSE_UNLESS(is_lwq(node));
325     RETURN_FALSE_UNLESS(is_lwq(node->x()));
326     return true;
327   }
328
329   bool visit(const luci::CircleTranspose *node)
330   {
331     RETURN_FALSE_UNLESS(is_lwq(node));
332     RETURN_FALSE_UNLESS(is_lwq(node->a()));
333     return true;
334   }
335
336   bool visit(const luci::CircleFloor *node)
337   {
338     RETURN_FALSE_UNLESS(is_lwq(node));
339     RETURN_FALSE_UNLESS(is_lwq(node->x()));
340     return true;
341   }
342
343   bool visit(const luci::CircleGelu *node)
344   {
345     RETURN_FALSE_UNLESS(is_lwq(node));
346     RETURN_FALSE_UNLESS(is_lwq(node->features()));
347     return true;
348   }
349
350   bool visit(const luci::CircleGreater *node)
351   {
352     RETURN_FALSE_UNLESS(is_lwq(node->x()));
353     RETURN_FALSE_UNLESS(is_lwq(node->y()));
354     return true;
355   }
356
357   bool visit(const luci::CircleGreaterEqual *node)
358   {
359     RETURN_FALSE_UNLESS(is_lwq(node->x()));
360     RETURN_FALSE_UNLESS(is_lwq(node->y()));
361     return true;
362   }
363
364   bool visit(const luci::CircleDiv *node)
365   {
366     RETURN_FALSE_UNLESS(is_lwq(node));
367     RETURN_FALSE_UNLESS(is_lwq(node->x()));
368     RETURN_FALSE_UNLESS(is_lwq(node->y()));
369     return true;
370   }
371
372   bool visit(const luci::CircleFloorDiv *node)
373   {
374     RETURN_FALSE_UNLESS(is_lwq(node));
375     RETURN_FALSE_UNLESS(is_lwq(node->x()));
376     RETURN_FALSE_UNLESS(is_lwq(node->y()));
377     return true;
378   }
379
380   bool visit(const luci::CircleRsqrt *node)
381   {
382     RETURN_FALSE_UNLESS(is_lwq(node));
383     RETURN_FALSE_UNLESS(is_lwq(node->x()));
384     return true;
385   }
386
387   bool visit(const luci::CircleSqrt *node)
388   {
389     RETURN_FALSE_UNLESS(is_lwq(node));
390     RETURN_FALSE_UNLESS(is_lwq(node->x()));
391     return true;
392   }
393
394   bool visit(const luci::CircleElu *node)
395   {
396     RETURN_FALSE_UNLESS(is_lwq(node));
397     RETURN_FALSE_UNLESS(is_lwq(node->features()));
398     return true;
399   }
400
401   bool visit(const luci::CirclePow *node)
402   {
403     RETURN_FALSE_UNLESS(is_lwq(node));
404     RETURN_FALSE_UNLESS(is_lwq(node->x()));
405     RETURN_FALSE_UNLESS(is_lwq(node->y()));
406     return true;
407   }
408
409   bool visit(const luci::CircleResizeBilinear *node)
410   {
411     RETURN_FALSE_UNLESS(is_lwq(node));
412     RETURN_FALSE_UNLESS(is_lwq(node->input()));
413     return true;
414   }
415
416   bool visit(const luci::CircleResizeNearestNeighbor *node)
417   {
418     RETURN_FALSE_UNLESS(is_lwq(node));
419     RETURN_FALSE_UNLESS(is_lwq(node->input()));
420     return true;
421   }
422
423   bool visit(const luci::CircleUnpack *node)
424   {
425     // node's output is the input of CircleUnpackOut, thus not quantized
426     RETURN_FALSE_UNLESS(is_lwq(node->value()));
427     return true;
428   }
429
430   bool visit(const luci::CircleUnpackOut *node)
431   {
432     RETURN_FALSE_UNLESS(is_lwq(node));
433     return true;
434   }
435
436   bool visit(const luci::CircleCast *node)
437   {
438     auto input = loco::must_cast<const luci::CircleNode *>(node->x());
439     bool input_quantized = input->quantparam() != nullptr;
440     bool node_quantized = node->quantparam() != nullptr;
441     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
442     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
443     return true;
444   }
445
446   // TODO: Implement more Ops
447
448   bool visit(const luci::CircleNode *) { return true; }
449 };
450
451 class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
452 {
453 private:
454   uint32_t rank(const loco::Node *node)
455   {
456     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
457     return circle_node->rank();
458   }
459
460   bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
461   {
462     auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
463
464     assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
465     auto channel_size = circle_node->dim(channel_dim).value();
466
467     if (circle_node->quantparam() == nullptr)
468       return false;
469
470     if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
471       return false;
472
473     if (circle_node->quantparam()->scale.size() != channel_size)
474       return false;
475
476     if (circle_node->quantparam()->zerop.size() != channel_size)
477       return false;
478
479     return true;
480   }
481
482 private:
483   bool visit(const luci::CircleConv2D *node)
484   {
485     RETURN_FALSE_UNLESS(is_lwq(node))
486     RETURN_FALSE_UNLESS(is_lwq(node->input()))
487     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
488     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
489     if (bias != nullptr)
490       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
491     return true;
492   }
493
494   bool visit(const luci::CircleDepthwiseConv2D *node)
495   {
496     RETURN_FALSE_UNLESS(is_lwq(node))
497     RETURN_FALSE_UNLESS(is_lwq(node->input()))
498     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
499     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
500     if (bias != nullptr)
501       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
502     return true;
503   }
504
505   bool visit(const luci::CircleInstanceNorm *node)
506   {
507     RETURN_FALSE_UNLESS(is_lwq(node))
508     RETURN_FALSE_UNLESS(is_lwq(node->input()))
509     RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
510     RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
511     return true;
512   }
513
514   bool visit(const luci::CirclePRelu *node)
515   {
516     RETURN_FALSE_UNLESS(is_lwq(node))
517     RETURN_FALSE_UNLESS(is_lwq(node->input()))
518     RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
519     return true;
520   }
521
522   bool visit(const luci::CircleTransposeConv *node)
523   {
524     RETURN_FALSE_UNLESS(is_lwq(node))
525     RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
526     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
527     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
528     if (bias != nullptr)
529       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
530
531     return true;
532   }
533
534   bool visit(const luci::CircleFullyConnected *node)
535   {
536     RETURN_FALSE_UNLESS(is_lwq(node))
537     RETURN_FALSE_UNLESS(is_lwq(node->input()))
538     RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
539     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
540     // Bias is optional (it can be CircleOutputExclude)
541     if (bias != nullptr)
542       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
543     return true;
544   }
545 };
546
547 class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
548 {
549 private:
550   bool is_lwq_const(const loco::Node *node)
551   {
552     auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
553
554     if (circle_node->quantparam() == nullptr)
555       return false;
556
557     if (circle_node->quantparam()->scale.size() != 1)
558       return false;
559
560     if (circle_node->quantparam()->zerop.size() != 1)
561       return false;
562
563     return true;
564   }
565
566 private:
567   bool visit(const luci::CircleConv2D *node)
568   {
569     RETURN_FALSE_UNLESS(is_lwq(node))
570     RETURN_FALSE_UNLESS(is_lwq(node->input()))
571     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
572     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
573     if (bias != nullptr)
574       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
575     return true;
576   }
577
578   bool visit(const luci::CircleDepthwiseConv2D *node)
579   {
580     RETURN_FALSE_UNLESS(is_lwq(node))
581     RETURN_FALSE_UNLESS(is_lwq(node->input()))
582     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
583     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
584     if (bias != nullptr)
585       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
586     return true;
587   }
588
589   bool visit(const luci::CircleInstanceNorm *node)
590   {
591     RETURN_FALSE_UNLESS(is_lwq(node))
592     RETURN_FALSE_UNLESS(is_lwq(node->input()))
593     RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
594     RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
595     return true;
596   }
597
598   bool visit(const luci::CirclePRelu *node)
599   {
600     RETURN_FALSE_UNLESS(is_lwq(node))
601     RETURN_FALSE_UNLESS(is_lwq(node->input()))
602     RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
603     return true;
604   }
605
606   bool visit(const luci::CircleTransposeConv *node)
607   {
608     RETURN_FALSE_UNLESS(is_lwq(node))
609     RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
610     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
611     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
612     if (bias != nullptr)
613       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
614     return true;
615   }
616
617   bool visit(const luci::CircleFullyConnected *node)
618   {
619     RETURN_FALSE_UNLESS(is_lwq(node))
620     RETURN_FALSE_UNLESS(is_lwq(node->input()))
621     RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
622     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
623     if (bias != nullptr)
624       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
625     return true;
626   }
627 };
628
629 } // namespace luci
630
631 #undef RETURN_FALSE_UNLESS
632
633 #endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__