6bf7ff6981e720774a8d42cde7590c997426254f
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / VerifyQuantizedNodeGranularity.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
18 #define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Pass/QuantizationParameters.h>
23
24 #include <memory>
25
26 using Granularity = luci::QuantizationGranularity;
27
28 // This macro is undef at the end of the file
29 #define RETURN_FALSE_UNLESS(ARG) \
30   if (not(ARG))                  \
31   {                              \
32     return false;                \
33   }
34
35 namespace luci
36 {
37
38 /**
39  * @brief Verify the granualrity of quantized node
40  * @details
41  *
42  * Targets to verify
43  * - node's output (i.e., node itself)
44  * - node's inputs
45  */
46 class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
47 {
48 public:
49   static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
50
51 protected:
52   bool is_lwq(const loco::Node *node)
53   {
54     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
55
56     if (circle_node->quantparam() == nullptr)
57       return false;
58
59     if (circle_node->quantparam()->scale.size() != 1)
60       return false;
61
62     if (circle_node->quantparam()->zerop.size() != 1)
63       return false;
64
65     return true;
66   }
67
68 private:
69   virtual bool visit(const luci::CircleConv2D *node) = 0;
70
71   bool visit(const luci::CircleConcatenation *node)
72   {
73     // Skip granularity check for concatenation of indices
74     if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
75       return true;
76
77     RETURN_FALSE_UNLESS(is_lwq(node))
78     for (uint32_t i = 0; i < node->numValues(); i++)
79     {
80       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
81     }
82     return true;
83   }
84
85   bool visit(const luci::CircleDepthToSpace *node)
86   {
87     RETURN_FALSE_UNLESS(is_lwq(node))
88     RETURN_FALSE_UNLESS(is_lwq(node->input()))
89     return true;
90   }
91
92   virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
93
94   virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
95
96   bool visit(const luci::CirclePack *node)
97   {
98     RETURN_FALSE_UNLESS(is_lwq(node))
99     for (uint32_t i = 0; i < node->values_count(); i++)
100     {
101       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
102     }
103     return true;
104   }
105
106   bool visit(const luci::CirclePad *node)
107   {
108     RETURN_FALSE_UNLESS(is_lwq(node))
109     RETURN_FALSE_UNLESS(is_lwq(node->input()))
110     return true;
111   }
112
113   bool visit(const luci::CirclePadV2 *node)
114   {
115     RETURN_FALSE_UNLESS(is_lwq(node))
116     RETURN_FALSE_UNLESS(is_lwq(node->input()))
117     RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
118     return true;
119   }
120
121   bool visit(const luci::CircleMirrorPad *node)
122   {
123     RETURN_FALSE_UNLESS(is_lwq(node))
124     RETURN_FALSE_UNLESS(is_lwq(node->input()))
125     return true;
126   }
127
128   virtual bool visit(const luci::CirclePRelu *node) = 0;
129
130   virtual bool visit(const luci::CircleTransposeConv *node) = 0;
131
132   virtual bool visit(const luci::CircleFullyConnected *node) = 0;
133
134   bool visit(const luci::CircleAdd *node)
135   {
136     // Skip granularity check for indices
137     if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
138       return true;
139
140     RETURN_FALSE_UNLESS(is_lwq(node));
141     RETURN_FALSE_UNLESS(is_lwq(node->x()));
142     RETURN_FALSE_UNLESS(is_lwq(node->y()));
143     return true;
144   }
145
146   bool visit(const luci::CircleAveragePool2D *node)
147   {
148     RETURN_FALSE_UNLESS(is_lwq(node));
149     RETURN_FALSE_UNLESS(is_lwq(node->value()));
150     return true;
151   }
152
153   bool visit(const luci::CircleLogicalOr *)
154   {
155     // Logical OR has bool-type inputs and output
156     // Nothing to be checked
157     return true;
158   }
159
160   bool visit(const luci::CircleMaxPool2D *node)
161   {
162     RETURN_FALSE_UNLESS(is_lwq(node));
163     RETURN_FALSE_UNLESS(is_lwq(node->value()));
164     return true;
165   }
166
167   bool visit(const luci::CircleLocalResponseNormalization *node)
168   {
169     RETURN_FALSE_UNLESS(is_lwq(node))
170     RETURN_FALSE_UNLESS(is_lwq(node->input()));
171     return true;
172   }
173
174   bool visit(const luci::CircleMean *node)
175   {
176     RETURN_FALSE_UNLESS(is_lwq(node));
177     RETURN_FALSE_UNLESS(is_lwq(node->input()));
178     return true;
179   }
180
181   bool visit(const luci::CircleMul *node)
182   {
183     // Skip granularity check for indices
184     if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
185       return true;
186
187     RETURN_FALSE_UNLESS(is_lwq(node));
188     RETURN_FALSE_UNLESS(is_lwq(node->x()));
189     RETURN_FALSE_UNLESS(is_lwq(node->y()));
190     return true;
191   }
192
193   bool visit(const luci::CircleNotEqual *node)
194   {
195     RETURN_FALSE_UNLESS(is_lwq(node->x()));
196     RETURN_FALSE_UNLESS(is_lwq(node->y()));
197     return true;
198   }
199
200   bool visit(const luci::CircleOneHot *node)
201   {
202     RETURN_FALSE_UNLESS(is_lwq(node));
203     RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
204     RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
205     return true;
206   }
207
208   bool visit(const luci::CircleReduceMax *node)
209   {
210     RETURN_FALSE_UNLESS(is_lwq(node));
211     RETURN_FALSE_UNLESS(is_lwq(node->input()));
212     return true;
213   }
214
215   bool visit(const luci::CircleRelu *node)
216   {
217     RETURN_FALSE_UNLESS(is_lwq(node));
218     RETURN_FALSE_UNLESS(is_lwq(node->features()));
219     return true;
220   }
221
222   bool visit(const luci::CircleReshape *node)
223   {
224     auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
225     bool input_quantized = input->quantparam() != nullptr;
226     bool node_quantized = node->quantparam() != nullptr;
227     RETURN_FALSE_UNLESS(input_quantized == node_quantized);
228     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
229     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
230     return true;
231   }
232
233   bool visit(const luci::CircleLogistic *node)
234   {
235     RETURN_FALSE_UNLESS(is_lwq(node));
236     RETURN_FALSE_UNLESS(is_lwq(node->x()));
237     return true;
238   }
239
240   bool visit(const luci::CircleSoftmax *node)
241   {
242     RETURN_FALSE_UNLESS(is_lwq(node));
243     RETURN_FALSE_UNLESS(is_lwq(node->logits()));
244     return true;
245   }
246
247   bool visit(const luci::CircleSpaceToBatchND *node)
248   {
249     RETURN_FALSE_UNLESS(is_lwq(node));
250     RETURN_FALSE_UNLESS(is_lwq(node->input()));
251     return true;
252   }
253
254   bool visit(const luci::CircleSpaceToDepth *node)
255   {
256     RETURN_FALSE_UNLESS(is_lwq(node));
257     RETURN_FALSE_UNLESS(is_lwq(node->input()));
258     return true;
259   }
260
261   bool visit(const luci::CircleSlice *node)
262   {
263     RETURN_FALSE_UNLESS(is_lwq(node));
264     RETURN_FALSE_UNLESS(is_lwq(node->input()));
265     return true;
266   }
267
268   bool visit(const luci::CircleSplit *node)
269   {
270     // node's output is the input of CircleSplitOut, thus not quantized
271     RETURN_FALSE_UNLESS(is_lwq(node->input()));
272     return true;
273   }
274
275   bool visit(const luci::CircleSplitOut *node)
276   {
277     RETURN_FALSE_UNLESS(is_lwq(node));
278     return true;
279   }
280
281   bool visit(const luci::CircleSplitV *node)
282   {
283     // node's output is the input of CircleSplitVOut, thus not quantized
284     RETURN_FALSE_UNLESS(is_lwq(node->input()));
285     return true;
286   }
287
288   bool visit(const luci::CircleSplitVOut *node)
289   {
290     RETURN_FALSE_UNLESS(is_lwq(node));
291     return true;
292   }
293
294   bool visit(const luci::CircleStridedSlice *node)
295   {
296     RETURN_FALSE_UNLESS(is_lwq(node));
297     RETURN_FALSE_UNLESS(is_lwq(node->input()));
298     return true;
299   }
300
301   bool visit(const luci::CircleArgMax *node)
302   {
303     // node's output is index, thus not quantized
304     RETURN_FALSE_UNLESS(is_lwq(node->input()));
305     return true;
306   }
307
308   bool visit(const luci::CircleBatchToSpaceND *node)
309   {
310     RETURN_FALSE_UNLESS(is_lwq(node));
311     RETURN_FALSE_UNLESS(is_lwq(node->input()));
312     return true;
313   }
314
315   bool visit(const luci::CircleTanh *node)
316   {
317     RETURN_FALSE_UNLESS(is_lwq(node));
318     RETURN_FALSE_UNLESS(is_lwq(node->x()));
319     return true;
320   }
321
322   bool visit(const luci::CircleTranspose *node)
323   {
324     RETURN_FALSE_UNLESS(is_lwq(node));
325     RETURN_FALSE_UNLESS(is_lwq(node->a()));
326     return true;
327   }
328
329   bool visit(const luci::CircleFloor *node)
330   {
331     RETURN_FALSE_UNLESS(is_lwq(node));
332     RETURN_FALSE_UNLESS(is_lwq(node->x()));
333     return true;
334   }
335
336   bool visit(const luci::CircleGreater *node)
337   {
338     RETURN_FALSE_UNLESS(is_lwq(node->x()));
339     RETURN_FALSE_UNLESS(is_lwq(node->y()));
340     return true;
341   }
342
343   bool visit(const luci::CircleGreaterEqual *node)
344   {
345     RETURN_FALSE_UNLESS(is_lwq(node->x()));
346     RETURN_FALSE_UNLESS(is_lwq(node->y()));
347     return true;
348   }
349
350   bool visit(const luci::CircleDiv *node)
351   {
352     RETURN_FALSE_UNLESS(is_lwq(node));
353     RETURN_FALSE_UNLESS(is_lwq(node->x()));
354     RETURN_FALSE_UNLESS(is_lwq(node->y()));
355     return true;
356   }
357
358   bool visit(const luci::CircleFloorDiv *node)
359   {
360     RETURN_FALSE_UNLESS(is_lwq(node));
361     RETURN_FALSE_UNLESS(is_lwq(node->x()));
362     RETURN_FALSE_UNLESS(is_lwq(node->y()));
363     return true;
364   }
365
366   bool visit(const luci::CircleRsqrt *node)
367   {
368     RETURN_FALSE_UNLESS(is_lwq(node));
369     RETURN_FALSE_UNLESS(is_lwq(node->x()));
370     return true;
371   }
372
373   bool visit(const luci::CircleSqrt *node)
374   {
375     RETURN_FALSE_UNLESS(is_lwq(node));
376     RETURN_FALSE_UNLESS(is_lwq(node->x()));
377     return true;
378   }
379
380   bool visit(const luci::CircleElu *node)
381   {
382     RETURN_FALSE_UNLESS(is_lwq(node));
383     RETURN_FALSE_UNLESS(is_lwq(node->features()));
384     return true;
385   }
386
387   bool visit(const luci::CirclePow *node)
388   {
389     RETURN_FALSE_UNLESS(is_lwq(node));
390     RETURN_FALSE_UNLESS(is_lwq(node->x()));
391     RETURN_FALSE_UNLESS(is_lwq(node->y()));
392     return true;
393   }
394
395   bool visit(const luci::CircleResizeBilinear *node)
396   {
397     RETURN_FALSE_UNLESS(is_lwq(node));
398     RETURN_FALSE_UNLESS(is_lwq(node->input()));
399     return true;
400   }
401
402   bool visit(const luci::CircleResizeNearestNeighbor *node)
403   {
404     RETURN_FALSE_UNLESS(is_lwq(node));
405     RETURN_FALSE_UNLESS(is_lwq(node->input()));
406     return true;
407   }
408
409   bool visit(const luci::CircleUnpack *node)
410   {
411     // node's output is the input of CircleUnpackOut, thus not quantized
412     RETURN_FALSE_UNLESS(is_lwq(node->value()));
413     return true;
414   }
415
416   bool visit(const luci::CircleUnpackOut *node)
417   {
418     RETURN_FALSE_UNLESS(is_lwq(node));
419     return true;
420   }
421
422   bool visit(const luci::CircleCast *node)
423   {
424     auto input = loco::must_cast<const luci::CircleNode *>(node->x());
425     bool input_quantized = input->quantparam() != nullptr;
426     bool node_quantized = node->quantparam() != nullptr;
427     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
428     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
429     return true;
430   }
431
432   // TODO: Implement more Ops
433
434   bool visit(const luci::CircleNode *) { return true; }
435 };
436
437 class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
438 {
439 private:
440   uint32_t rank(const loco::Node *node)
441   {
442     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
443     return circle_node->rank();
444   }
445
446   bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
447   {
448     auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
449
450     assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
451     auto channel_size = circle_node->dim(channel_dim).value();
452
453     if (circle_node->quantparam() == nullptr)
454       return false;
455
456     if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
457       return false;
458
459     if (circle_node->quantparam()->scale.size() != channel_size)
460       return false;
461
462     if (circle_node->quantparam()->zerop.size() != channel_size)
463       return false;
464
465     return true;
466   }
467
468 private:
469   bool visit(const luci::CircleConv2D *node)
470   {
471     RETURN_FALSE_UNLESS(is_lwq(node))
472     RETURN_FALSE_UNLESS(is_lwq(node->input()))
473     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
474     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
475     if (bias != nullptr)
476       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
477     return true;
478   }
479
480   bool visit(const luci::CircleDepthwiseConv2D *node)
481   {
482     RETURN_FALSE_UNLESS(is_lwq(node))
483     RETURN_FALSE_UNLESS(is_lwq(node->input()))
484     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
485     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
486     if (bias != nullptr)
487       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
488     return true;
489   }
490
491   bool visit(const luci::CircleInstanceNorm *node)
492   {
493     RETURN_FALSE_UNLESS(is_lwq(node))
494     RETURN_FALSE_UNLESS(is_lwq(node->input()))
495     RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
496     RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
497     return true;
498   }
499
500   bool visit(const luci::CirclePRelu *node)
501   {
502     RETURN_FALSE_UNLESS(is_lwq(node))
503     RETURN_FALSE_UNLESS(is_lwq(node->input()))
504     RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
505     return true;
506   }
507
508   bool visit(const luci::CircleTransposeConv *node)
509   {
510     RETURN_FALSE_UNLESS(is_lwq(node))
511     RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
512     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
513     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
514     if (bias != nullptr)
515       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
516
517     return true;
518   }
519
520   bool visit(const luci::CircleFullyConnected *node)
521   {
522     RETURN_FALSE_UNLESS(is_lwq(node))
523     RETURN_FALSE_UNLESS(is_lwq(node->input()))
524     RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
525     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
526     // Bias is optional (it can be CircleOutputExclude)
527     if (bias != nullptr)
528       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
529     return true;
530   }
531 };
532
533 class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
534 {
535 private:
536   bool is_lwq_const(const loco::Node *node)
537   {
538     auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
539
540     if (circle_node->quantparam() == nullptr)
541       return false;
542
543     if (circle_node->quantparam()->scale.size() != 1)
544       return false;
545
546     if (circle_node->quantparam()->zerop.size() != 1)
547       return false;
548
549     return true;
550   }
551
552 private:
553   bool visit(const luci::CircleConv2D *node)
554   {
555     RETURN_FALSE_UNLESS(is_lwq(node))
556     RETURN_FALSE_UNLESS(is_lwq(node->input()))
557     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
558     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
559     if (bias != nullptr)
560       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
561     return true;
562   }
563
564   bool visit(const luci::CircleDepthwiseConv2D *node)
565   {
566     RETURN_FALSE_UNLESS(is_lwq(node))
567     RETURN_FALSE_UNLESS(is_lwq(node->input()))
568     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
569     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
570     if (bias != nullptr)
571       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
572     return true;
573   }
574
575   bool visit(const luci::CircleInstanceNorm *node)
576   {
577     RETURN_FALSE_UNLESS(is_lwq(node))
578     RETURN_FALSE_UNLESS(is_lwq(node->input()))
579     RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
580     RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
581     return true;
582   }
583
584   bool visit(const luci::CirclePRelu *node)
585   {
586     RETURN_FALSE_UNLESS(is_lwq(node))
587     RETURN_FALSE_UNLESS(is_lwq(node->input()))
588     RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
589     return true;
590   }
591
592   bool visit(const luci::CircleTransposeConv *node)
593   {
594     RETURN_FALSE_UNLESS(is_lwq(node))
595     RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
596     RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
597     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
598     if (bias != nullptr)
599       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
600     return true;
601   }
602
603   bool visit(const luci::CircleFullyConnected *node)
604   {
605     RETURN_FALSE_UNLESS(is_lwq(node))
606     RETURN_FALSE_UNLESS(is_lwq(node->input()))
607     RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
608     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
609     if (bias != nullptr)
610       RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
611     return true;
612   }
613 };
614
615 } // namespace luci
616
617 #undef RETURN_FALSE_UNLESS
618
619 #endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__