/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
#define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Pass/QuantizationParameters.h>

#include <cassert>
#include <cstdint>
#include <memory>
26 using Granularity = luci::QuantizationGranularity;
28 // This macro is undef at the end of the file
29 #define RETURN_FALSE_UNLESS(ARG) \
39 * @brief Verify the granualrity of quantized node
43 * - node's output (i.e., node itself)
46 class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
49 static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
52 bool is_lwq(const loco::Node *node)
54 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
56 if (circle_node->quantparam() == nullptr)
59 if (circle_node->quantparam()->scale.size() != 1)
62 if (circle_node->quantparam()->zerop.size() != 1)
69 virtual bool visit(const luci::CircleConv2D *node) = 0;
71 bool visit(const luci::CircleConcatenation *node)
73 // Skip granularity check for concatenation of indices
74 if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
77 RETURN_FALSE_UNLESS(is_lwq(node))
78 for (uint32_t i = 0; i < node->numValues(); i++)
80 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
85 bool visit(const luci::CircleDepthToSpace *node)
87 RETURN_FALSE_UNLESS(is_lwq(node))
88 RETURN_FALSE_UNLESS(is_lwq(node->input()))
92 virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
94 virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
96 bool visit(const luci::CirclePack *node)
98 RETURN_FALSE_UNLESS(is_lwq(node))
99 for (uint32_t i = 0; i < node->values_count(); i++)
101 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
106 bool visit(const luci::CirclePad *node)
108 RETURN_FALSE_UNLESS(is_lwq(node))
109 RETURN_FALSE_UNLESS(is_lwq(node->input()))
113 bool visit(const luci::CirclePadV2 *node)
115 RETURN_FALSE_UNLESS(is_lwq(node))
116 RETURN_FALSE_UNLESS(is_lwq(node->input()))
117 RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
121 bool visit(const luci::CircleMirrorPad *node)
123 RETURN_FALSE_UNLESS(is_lwq(node))
124 RETURN_FALSE_UNLESS(is_lwq(node->input()))
128 virtual bool visit(const luci::CirclePRelu *node) = 0;
130 virtual bool visit(const luci::CircleTransposeConv *node) = 0;
132 virtual bool visit(const luci::CircleFullyConnected *node) = 0;
134 bool visit(const luci::CircleAdd *node)
136 // Skip granularity check for indices
137 if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
140 RETURN_FALSE_UNLESS(is_lwq(node));
141 RETURN_FALSE_UNLESS(is_lwq(node->x()));
142 RETURN_FALSE_UNLESS(is_lwq(node->y()));
146 bool visit(const luci::CircleAveragePool2D *node)
148 RETURN_FALSE_UNLESS(is_lwq(node));
149 RETURN_FALSE_UNLESS(is_lwq(node->value()));
153 bool visit(const luci::CircleLogicalOr *)
155 // Logical OR has bool-type inputs and output
156 // Nothing to be checked
160 bool visit(const luci::CircleMaxPool2D *node)
162 RETURN_FALSE_UNLESS(is_lwq(node));
163 RETURN_FALSE_UNLESS(is_lwq(node->value()));
167 bool visit(const luci::CircleLocalResponseNormalization *node)
169 RETURN_FALSE_UNLESS(is_lwq(node))
170 RETURN_FALSE_UNLESS(is_lwq(node->input()));
174 bool visit(const luci::CircleMean *node)
176 RETURN_FALSE_UNLESS(is_lwq(node));
177 RETURN_FALSE_UNLESS(is_lwq(node->input()));
181 bool visit(const luci::CircleMul *node)
183 // Skip granularity check for indices
184 if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
187 RETURN_FALSE_UNLESS(is_lwq(node));
188 RETURN_FALSE_UNLESS(is_lwq(node->x()));
189 RETURN_FALSE_UNLESS(is_lwq(node->y()));
193 bool visit(const luci::CircleNotEqual *node)
195 RETURN_FALSE_UNLESS(is_lwq(node->x()));
196 RETURN_FALSE_UNLESS(is_lwq(node->y()));
200 bool visit(const luci::CircleOneHot *node)
202 RETURN_FALSE_UNLESS(is_lwq(node));
203 RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
204 RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
208 bool visit(const luci::CircleReduceMax *node)
210 RETURN_FALSE_UNLESS(is_lwq(node));
211 RETURN_FALSE_UNLESS(is_lwq(node->input()));
215 bool visit(const luci::CircleRelu *node)
217 RETURN_FALSE_UNLESS(is_lwq(node));
218 RETURN_FALSE_UNLESS(is_lwq(node->features()));
222 bool visit(const luci::CircleReshape *node)
224 auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
225 bool input_quantized = input->quantparam() != nullptr;
226 bool node_quantized = node->quantparam() != nullptr;
227 RETURN_FALSE_UNLESS(input_quantized == node_quantized);
228 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
229 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
233 bool visit(const luci::CircleLogistic *node)
235 RETURN_FALSE_UNLESS(is_lwq(node));
236 RETURN_FALSE_UNLESS(is_lwq(node->x()));
240 bool visit(const luci::CircleSoftmax *node)
242 RETURN_FALSE_UNLESS(is_lwq(node));
243 RETURN_FALSE_UNLESS(is_lwq(node->logits()));
247 bool visit(const luci::CircleSpaceToBatchND *node)
249 RETURN_FALSE_UNLESS(is_lwq(node));
250 RETURN_FALSE_UNLESS(is_lwq(node->input()));
254 bool visit(const luci::CircleSpaceToDepth *node)
256 RETURN_FALSE_UNLESS(is_lwq(node));
257 RETURN_FALSE_UNLESS(is_lwq(node->input()));
261 bool visit(const luci::CircleSlice *node)
263 RETURN_FALSE_UNLESS(is_lwq(node));
264 RETURN_FALSE_UNLESS(is_lwq(node->input()));
268 bool visit(const luci::CircleSplit *node)
270 // node's output is the input of CircleSplitOut, thus not quantized
271 RETURN_FALSE_UNLESS(is_lwq(node->input()));
275 bool visit(const luci::CircleSplitOut *node)
277 RETURN_FALSE_UNLESS(is_lwq(node));
281 bool visit(const luci::CircleSplitV *node)
283 // node's output is the input of CircleSplitVOut, thus not quantized
284 RETURN_FALSE_UNLESS(is_lwq(node->input()));
288 bool visit(const luci::CircleSplitVOut *node)
290 RETURN_FALSE_UNLESS(is_lwq(node));
294 bool visit(const luci::CircleStridedSlice *node)
296 RETURN_FALSE_UNLESS(is_lwq(node));
297 RETURN_FALSE_UNLESS(is_lwq(node->input()));
301 bool visit(const luci::CircleArgMax *node)
303 // node's output is index, thus not quantized
304 RETURN_FALSE_UNLESS(is_lwq(node->input()));
308 bool visit(const luci::CircleBatchToSpaceND *node)
310 RETURN_FALSE_UNLESS(is_lwq(node));
311 RETURN_FALSE_UNLESS(is_lwq(node->input()));
315 bool visit(const luci::CircleTanh *node)
317 RETURN_FALSE_UNLESS(is_lwq(node));
318 RETURN_FALSE_UNLESS(is_lwq(node->x()));
322 bool visit(const luci::CircleTranspose *node)
324 RETURN_FALSE_UNLESS(is_lwq(node));
325 RETURN_FALSE_UNLESS(is_lwq(node->a()));
329 bool visit(const luci::CircleFloor *node)
331 RETURN_FALSE_UNLESS(is_lwq(node));
332 RETURN_FALSE_UNLESS(is_lwq(node->x()));
336 bool visit(const luci::CircleGreater *node)
338 RETURN_FALSE_UNLESS(is_lwq(node->x()));
339 RETURN_FALSE_UNLESS(is_lwq(node->y()));
343 bool visit(const luci::CircleGreaterEqual *node)
345 RETURN_FALSE_UNLESS(is_lwq(node->x()));
346 RETURN_FALSE_UNLESS(is_lwq(node->y()));
350 bool visit(const luci::CircleDiv *node)
352 RETURN_FALSE_UNLESS(is_lwq(node));
353 RETURN_FALSE_UNLESS(is_lwq(node->x()));
354 RETURN_FALSE_UNLESS(is_lwq(node->y()));
358 bool visit(const luci::CircleFloorDiv *node)
360 RETURN_FALSE_UNLESS(is_lwq(node));
361 RETURN_FALSE_UNLESS(is_lwq(node->x()));
362 RETURN_FALSE_UNLESS(is_lwq(node->y()));
366 bool visit(const luci::CircleRsqrt *node)
368 RETURN_FALSE_UNLESS(is_lwq(node));
369 RETURN_FALSE_UNLESS(is_lwq(node->x()));
373 bool visit(const luci::CircleSqrt *node)
375 RETURN_FALSE_UNLESS(is_lwq(node));
376 RETURN_FALSE_UNLESS(is_lwq(node->x()));
380 bool visit(const luci::CircleElu *node)
382 RETURN_FALSE_UNLESS(is_lwq(node));
383 RETURN_FALSE_UNLESS(is_lwq(node->features()));
387 bool visit(const luci::CirclePow *node)
389 RETURN_FALSE_UNLESS(is_lwq(node));
390 RETURN_FALSE_UNLESS(is_lwq(node->x()));
391 RETURN_FALSE_UNLESS(is_lwq(node->y()));
395 bool visit(const luci::CircleResizeBilinear *node)
397 RETURN_FALSE_UNLESS(is_lwq(node));
398 RETURN_FALSE_UNLESS(is_lwq(node->input()));
402 bool visit(const luci::CircleResizeNearestNeighbor *node)
404 RETURN_FALSE_UNLESS(is_lwq(node));
405 RETURN_FALSE_UNLESS(is_lwq(node->input()));
409 bool visit(const luci::CircleUnpack *node)
411 // node's output is the input of CircleUnpackOut, thus not quantized
412 RETURN_FALSE_UNLESS(is_lwq(node->value()));
416 bool visit(const luci::CircleUnpackOut *node)
418 RETURN_FALSE_UNLESS(is_lwq(node));
422 bool visit(const luci::CircleCast *node)
424 auto input = loco::must_cast<const luci::CircleNode *>(node->x());
425 bool input_quantized = input->quantparam() != nullptr;
426 bool node_quantized = node->quantparam() != nullptr;
427 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
428 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
432 // TODO: Implement more Ops
434 bool visit(const luci::CircleNode *) { return true; }
437 class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
440 uint32_t rank(const loco::Node *node)
442 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
443 return circle_node->rank();
446 bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
448 auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
450 assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
451 auto channel_size = circle_node->dim(channel_dim).value();
453 if (circle_node->quantparam() == nullptr)
456 if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
459 if (circle_node->quantparam()->scale.size() != channel_size)
462 if (circle_node->quantparam()->zerop.size() != channel_size)
469 bool visit(const luci::CircleConv2D *node)
471 RETURN_FALSE_UNLESS(is_lwq(node))
472 RETURN_FALSE_UNLESS(is_lwq(node->input()))
473 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
474 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
476 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
480 bool visit(const luci::CircleDepthwiseConv2D *node)
482 RETURN_FALSE_UNLESS(is_lwq(node))
483 RETURN_FALSE_UNLESS(is_lwq(node->input()))
484 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
485 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
487 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
491 bool visit(const luci::CircleInstanceNorm *node)
493 RETURN_FALSE_UNLESS(is_lwq(node))
494 RETURN_FALSE_UNLESS(is_lwq(node->input()))
495 RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
496 RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
500 bool visit(const luci::CirclePRelu *node)
502 RETURN_FALSE_UNLESS(is_lwq(node))
503 RETURN_FALSE_UNLESS(is_lwq(node->input()))
504 RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
508 bool visit(const luci::CircleTransposeConv *node)
510 RETURN_FALSE_UNLESS(is_lwq(node))
511 RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
512 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
513 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
515 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
520 bool visit(const luci::CircleFullyConnected *node)
522 RETURN_FALSE_UNLESS(is_lwq(node))
523 RETURN_FALSE_UNLESS(is_lwq(node->input()))
524 RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
525 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
526 // Bias is optional (it can be CircleOutputExclude)
528 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
533 class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
536 bool is_lwq_const(const loco::Node *node)
538 auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
540 if (circle_node->quantparam() == nullptr)
543 if (circle_node->quantparam()->scale.size() != 1)
546 if (circle_node->quantparam()->zerop.size() != 1)
553 bool visit(const luci::CircleConv2D *node)
555 RETURN_FALSE_UNLESS(is_lwq(node))
556 RETURN_FALSE_UNLESS(is_lwq(node->input()))
557 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
558 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
560 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
564 bool visit(const luci::CircleDepthwiseConv2D *node)
566 RETURN_FALSE_UNLESS(is_lwq(node))
567 RETURN_FALSE_UNLESS(is_lwq(node->input()))
568 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
569 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
571 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
575 bool visit(const luci::CircleInstanceNorm *node)
577 RETURN_FALSE_UNLESS(is_lwq(node))
578 RETURN_FALSE_UNLESS(is_lwq(node->input()))
579 RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
580 RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
584 bool visit(const luci::CirclePRelu *node)
586 RETURN_FALSE_UNLESS(is_lwq(node))
587 RETURN_FALSE_UNLESS(is_lwq(node->input()))
588 RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
592 bool visit(const luci::CircleTransposeConv *node)
594 RETURN_FALSE_UNLESS(is_lwq(node))
595 RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
596 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
597 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
599 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
603 bool visit(const luci::CircleFullyConnected *node)
605 RETURN_FALSE_UNLESS(is_lwq(node))
606 RETURN_FALSE_UNLESS(is_lwq(node->input()))
607 RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
608 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
610 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
617 #undef RETURN_FALSE_UNLESS
619 #endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__