Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / VerifyQuantizedNodeChannelWiseGranularity.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
17 #define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/Pass/QuantizationParameters.h>
22
23 using Granularity = luci::QuantizationGranularity;
24
25 // This macro is undef at the end of the file
26 #define RETURN_FALSE_UNLESS(ARG) \
27   if (not(ARG))                  \
28   {                              \
29     return false;                \
30   }
31
32 namespace luci
33 {
34
35 /**
36  * @brief Verify the granualrity of channel-wise quantized node
37  * @details
38  *
39  * Targets to verify
40  * - node's output (i.e., node itself)
41  * - node's inputs
42  */
43 struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool>
44 {
45 private:
46   bool is_lwq(const loco::Node *node)
47   {
48     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
49
50     if (circle_node->quantparam() == nullptr)
51       return false;
52
53     if (circle_node->quantparam()->scale.size() != 1)
54       return false;
55
56     if (circle_node->quantparam()->zerop.size() != 1)
57       return false;
58
59     return true;
60   }
61
62   uint32_t rank(const loco::Node *node)
63   {
64     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
65     return circle_node->rank();
66   }
67
68   bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
69   {
70     auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
71
72     assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
73     auto channel_size = circle_node->dim(channel_dim).value();
74
75     if (circle_node->quantparam() == nullptr)
76       return false;
77
78     if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
79       return false;
80
81     if (circle_node->quantparam()->scale.size() != channel_size)
82       return false;
83
84     if (circle_node->quantparam()->zerop.size() != channel_size)
85       return false;
86
87     return true;
88   }
89
90 private:
91   bool visit(const luci::CircleConv2D *node)
92   {
93     RETURN_FALSE_UNLESS(is_lwq(node))
94     RETURN_FALSE_UNLESS(is_lwq(node->input()))
95     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
96     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
97     if (bias != nullptr)
98       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
99     return true;
100   }
101
102   bool visit(const luci::CircleConcatenation *node)
103   {
104     RETURN_FALSE_UNLESS(is_lwq(node))
105     for (uint32_t i = 0; i < node->numValues(); i++)
106     {
107       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
108     }
109     return true;
110   }
111
112   bool visit(const luci::CircleDepthToSpace *node)
113   {
114     RETURN_FALSE_UNLESS(is_lwq(node))
115     RETURN_FALSE_UNLESS(is_lwq(node->input()))
116     return true;
117   }
118
119   bool visit(const luci::CircleDepthwiseConv2D *node)
120   {
121     RETURN_FALSE_UNLESS(is_lwq(node))
122     RETURN_FALSE_UNLESS(is_lwq(node->input()))
123     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
124     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
125     if (bias != nullptr)
126       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
127     return true;
128   }
129
130   bool visit(const luci::CircleInstanceNorm *node)
131   {
132     RETURN_FALSE_UNLESS(is_lwq(node))
133     RETURN_FALSE_UNLESS(is_lwq(node->input()))
134     RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
135     RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
136     return true;
137   }
138
139   bool visit(const luci::CirclePack *node)
140   {
141     RETURN_FALSE_UNLESS(is_lwq(node))
142     for (uint32_t i = 0; i < node->values_count(); i++)
143     {
144       RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
145     }
146     return true;
147   }
148
149   bool visit(const luci::CirclePad *node)
150   {
151     RETURN_FALSE_UNLESS(is_lwq(node))
152     RETURN_FALSE_UNLESS(is_lwq(node->input()))
153     return true;
154   }
155
156   bool visit(const luci::CirclePadV2 *node)
157   {
158     RETURN_FALSE_UNLESS(is_lwq(node))
159     RETURN_FALSE_UNLESS(is_lwq(node->input()))
160     RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
161     return true;
162   }
163
164   bool visit(const luci::CircleMirrorPad *node)
165   {
166     RETURN_FALSE_UNLESS(is_lwq(node))
167     RETURN_FALSE_UNLESS(is_lwq(node->input()))
168     return true;
169   }
170
171   bool visit(const luci::CirclePRelu *node)
172   {
173     RETURN_FALSE_UNLESS(is_lwq(node))
174     RETURN_FALSE_UNLESS(is_lwq(node->input()))
175     RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
176     return true;
177   }
178
179   bool visit(const luci::CircleTransposeConv *node)
180   {
181     RETURN_FALSE_UNLESS(is_lwq(node))
182     RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
183     RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
184     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
185     if (bias != nullptr)
186       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
187
188     return true;
189   }
190
191   bool visit(const luci::CircleFullyConnected *node)
192   {
193     RETURN_FALSE_UNLESS(is_lwq(node))
194     RETURN_FALSE_UNLESS(is_lwq(node->input()))
195     RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
196     luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
197     // Bias is optional (it can be CircleOutputExclude)
198     if (bias != nullptr)
199       RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
200     return true;
201   }
202
203   bool visit(const luci::CircleAdd *node)
204   {
205     RETURN_FALSE_UNLESS(is_lwq(node));
206     RETURN_FALSE_UNLESS(is_lwq(node->x()));
207     RETURN_FALSE_UNLESS(is_lwq(node->y()));
208     return true;
209   }
210
211   bool visit(const luci::CircleAveragePool2D *node)
212   {
213     RETURN_FALSE_UNLESS(is_lwq(node));
214     RETURN_FALSE_UNLESS(is_lwq(node->value()));
215     return true;
216   }
217
218   bool visit(const luci::CircleLogicalOr *)
219   {
220     // Logical OR has bool-type inputs and output
221     // Nothing to be checked
222     return true;
223   }
224
225   bool visit(const luci::CircleMaxPool2D *node)
226   {
227     RETURN_FALSE_UNLESS(is_lwq(node));
228     RETURN_FALSE_UNLESS(is_lwq(node->value()));
229     return true;
230   }
231
232   bool visit(const luci::CircleLocalResponseNormalization *node)
233   {
234     RETURN_FALSE_UNLESS(is_lwq(node))
235     RETURN_FALSE_UNLESS(is_lwq(node->input()));
236     return true;
237   }
238
239   bool visit(const luci::CircleMean *node)
240   {
241     RETURN_FALSE_UNLESS(is_lwq(node));
242     RETURN_FALSE_UNLESS(is_lwq(node->input()));
243     return true;
244   }
245
246   bool visit(const luci::CircleMul *node)
247   {
248     RETURN_FALSE_UNLESS(is_lwq(node));
249     RETURN_FALSE_UNLESS(is_lwq(node->x()));
250     RETURN_FALSE_UNLESS(is_lwq(node->y()));
251     return true;
252   }
253
254   bool visit(const luci::CircleNotEqual *node)
255   {
256     RETURN_FALSE_UNLESS(is_lwq(node->x()));
257     RETURN_FALSE_UNLESS(is_lwq(node->y()));
258     return true;
259   }
260
261   bool visit(const luci::CircleRelu *node)
262   {
263     RETURN_FALSE_UNLESS(is_lwq(node));
264     RETURN_FALSE_UNLESS(is_lwq(node->features()));
265     return true;
266   }
267
268   bool visit(const luci::CircleReshape *node)
269   {
270     auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
271     bool input_quantized = input->quantparam() != nullptr;
272     bool node_quantized = node->quantparam() != nullptr;
273     RETURN_FALSE_UNLESS(input_quantized == node_quantized);
274     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
275     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
276     return true;
277   }
278
279   bool visit(const luci::CircleLogistic *node)
280   {
281     RETURN_FALSE_UNLESS(is_lwq(node));
282     RETURN_FALSE_UNLESS(is_lwq(node->x()));
283     return true;
284   }
285
286   bool visit(const luci::CircleSoftmax *node)
287   {
288     RETURN_FALSE_UNLESS(is_lwq(node));
289     RETURN_FALSE_UNLESS(is_lwq(node->logits()));
290     return true;
291   }
292
293   bool visit(const luci::CircleSpaceToBatchND *node)
294   {
295     RETURN_FALSE_UNLESS(is_lwq(node));
296     RETURN_FALSE_UNLESS(is_lwq(node->input()));
297     return true;
298   }
299
300   bool visit(const luci::CircleSpaceToDepth *node)
301   {
302     RETURN_FALSE_UNLESS(is_lwq(node));
303     RETURN_FALSE_UNLESS(is_lwq(node->input()));
304     return true;
305   }
306
307   bool visit(const luci::CircleSlice *node)
308   {
309     RETURN_FALSE_UNLESS(is_lwq(node));
310     RETURN_FALSE_UNLESS(is_lwq(node->input()));
311     return true;
312   }
313
314   bool visit(const luci::CircleSplit *node)
315   {
316     // node's output is the input of CircleSplitOut, thus not quantized
317     RETURN_FALSE_UNLESS(is_lwq(node->input()));
318     return true;
319   }
320
321   bool visit(const luci::CircleSplitOut *node)
322   {
323     RETURN_FALSE_UNLESS(is_lwq(node));
324     return true;
325   }
326
327   bool visit(const luci::CircleSplitV *node)
328   {
329     // node's output is the input of CircleSplitVOut, thus not quantized
330     RETURN_FALSE_UNLESS(is_lwq(node->input()));
331     return true;
332   }
333
334   bool visit(const luci::CircleSplitVOut *node)
335   {
336     RETURN_FALSE_UNLESS(is_lwq(node));
337     return true;
338   }
339
340   bool visit(const luci::CircleStridedSlice *node)
341   {
342     RETURN_FALSE_UNLESS(is_lwq(node));
343     RETURN_FALSE_UNLESS(is_lwq(node->input()));
344     return true;
345   }
346
347   bool visit(const luci::CircleArgMax *node)
348   {
349     // node's output is index, thus not quantized
350     RETURN_FALSE_UNLESS(is_lwq(node->input()));
351     return true;
352   }
353
354   bool visit(const luci::CircleBatchToSpaceND *node)
355   {
356     RETURN_FALSE_UNLESS(is_lwq(node));
357     RETURN_FALSE_UNLESS(is_lwq(node->input()));
358     return true;
359   }
360
361   bool visit(const luci::CircleTanh *node)
362   {
363     RETURN_FALSE_UNLESS(is_lwq(node));
364     RETURN_FALSE_UNLESS(is_lwq(node->x()));
365     return true;
366   }
367
368   bool visit(const luci::CircleTranspose *node)
369   {
370     RETURN_FALSE_UNLESS(is_lwq(node));
371     RETURN_FALSE_UNLESS(is_lwq(node->a()));
372     return true;
373   }
374
375   bool visit(const luci::CircleFloor *node)
376   {
377     RETURN_FALSE_UNLESS(is_lwq(node));
378     RETURN_FALSE_UNLESS(is_lwq(node->x()));
379     return true;
380   }
381
382   bool visit(const luci::CircleGreater *node)
383   {
384     RETURN_FALSE_UNLESS(is_lwq(node->x()));
385     RETURN_FALSE_UNLESS(is_lwq(node->y()));
386     return true;
387   }
388
389   bool visit(const luci::CircleGreaterEqual *node)
390   {
391     RETURN_FALSE_UNLESS(is_lwq(node->x()));
392     RETURN_FALSE_UNLESS(is_lwq(node->y()));
393     return true;
394   }
395
396   bool visit(const luci::CircleDiv *node)
397   {
398     RETURN_FALSE_UNLESS(is_lwq(node));
399     RETURN_FALSE_UNLESS(is_lwq(node->x()));
400     RETURN_FALSE_UNLESS(is_lwq(node->y()));
401     return true;
402   }
403
404   bool visit(const luci::CircleFloorDiv *node)
405   {
406     RETURN_FALSE_UNLESS(is_lwq(node));
407     RETURN_FALSE_UNLESS(is_lwq(node->x()));
408     RETURN_FALSE_UNLESS(is_lwq(node->y()));
409     return true;
410   }
411
412   bool visit(const luci::CircleRsqrt *node)
413   {
414     RETURN_FALSE_UNLESS(is_lwq(node));
415     RETURN_FALSE_UNLESS(is_lwq(node->x()));
416     return true;
417   }
418
419   bool visit(const luci::CircleSqrt *node)
420   {
421     RETURN_FALSE_UNLESS(is_lwq(node));
422     RETURN_FALSE_UNLESS(is_lwq(node->x()));
423     return true;
424   }
425
426   bool visit(const luci::CircleElu *node)
427   {
428     RETURN_FALSE_UNLESS(is_lwq(node));
429     RETURN_FALSE_UNLESS(is_lwq(node->features()));
430     return true;
431   }
432
433   bool visit(const luci::CirclePow *node)
434   {
435     RETURN_FALSE_UNLESS(is_lwq(node));
436     RETURN_FALSE_UNLESS(is_lwq(node->x()));
437     RETURN_FALSE_UNLESS(is_lwq(node->y()));
438     return true;
439   }
440
441   bool visit(const luci::CircleResizeBilinear *node)
442   {
443     RETURN_FALSE_UNLESS(is_lwq(node));
444     RETURN_FALSE_UNLESS(is_lwq(node->input()));
445     return true;
446   }
447
448   bool visit(const luci::CircleResizeNearestNeighbor *node)
449   {
450     RETURN_FALSE_UNLESS(is_lwq(node));
451     RETURN_FALSE_UNLESS(is_lwq(node->input()));
452     return true;
453   }
454
455   bool visit(const luci::CircleUnpack *node)
456   {
457     // node's output is the input of CircleUnpackOut, thus not quantized
458     RETURN_FALSE_UNLESS(is_lwq(node->value()));
459     return true;
460   }
461
462   bool visit(const luci::CircleUnpackOut *node)
463   {
464     RETURN_FALSE_UNLESS(is_lwq(node));
465     return true;
466   }
467
468   bool visit(const luci::CircleCast *node)
469   {
470     auto input = loco::must_cast<const luci::CircleNode *>(node->x());
471     bool input_quantized = input->quantparam() != nullptr;
472     bool node_quantized = node->quantparam() != nullptr;
473     RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
474     RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
475     return true;
476   }
477
478   // TODO: Implement more Ops
479
480   bool visit(const luci::CircleNode *) { return true; }
481 };
482
483 } // namespace luci
484
485 #undef RETURN_FALSE_UNLESS
486
487 #endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__