2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
18 #define __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Pass/QuantizationParameters.h>
// Short alias for luci::QuantizationGranularity; it is the parameter type of
// VerifyQuantizedNodeGranularity::create() declared further down in this file.
26 using Granularity = luci::QuantizationGranularity;
// RETURN_FALSE_UNLESS(ARG): guard macro used inside the visit() methods of this file.
// NOTE(review): the macro's continuation lines are not visible in this view; judging by
// its name it presumably makes the enclosing function return false when ARG does not
// hold - confirm against the full definition.
28 // This macro is #undef'd at the end of the file, so it is not left defined for includers
29 #define RETURN_FALSE_UNLESS(ARG) \
39 * @brief Verify the granularity of a quantized node
43 * - node's output (i.e., node itself)
// Base visitor (luci::CircleNodeVisitor<bool>) that checks the quantization granularity
// of a node and of its inputs. Most ops are handled here with the layer-wise check
// is_lwq(). Conv2D, DepthwiseConv2D, InstanceNorm, PRelu, TransposeConv and
// FullyConnected are pure virtual and are implemented by the channel-wise and
// layer-wise subclasses defined later in this file.
// NOTE(review): several lines of this class (braces, access specifiers, return
// statements) are not visible in this view. The comments below describe only the
// checks that are visible; the value returned after the checks is presumably `true`
// - confirm.
46 class VerifyQuantizedNodeGranularity : public luci::CircleNodeVisitor<bool>
// Factory taking the requested granularity. Only the declaration is visible here.
49 static std::shared_ptr<VerifyQuantizedNodeGranularity> create(Granularity granularity);
// Layer-wise quantization ("lwq") test for an arbitrary node. It inspects the node's
// quantparam and tests three conditions: quantparam is missing, the number of scales
// differs from 1, the number of zero points differs from 1.
// NOTE(review): the statement executed when a condition is met is not visible here -
// presumably `return false`, with `return true` after the last test; confirm.
52 bool is_lwq(const loco::Node *node)
// View the loco node as a luci::CircleNode so that quantparam() can be queried
54 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
56 if (circle_node->quantparam() == nullptr)
59 if (circle_node->quantparam()->scale.size() != 1)
62 if (circle_node->quantparam()->zerop.size() != 1)
// Granularity-specific: overridden in both subclasses of this file
69 virtual bool visit(const luci::CircleConv2D *node) = 0;
// Concatenation: the node itself and every one of its numValues() inputs must pass is_lwq
71 bool visit(const luci::CircleConcatenation *node)
73 // Skip granularity check for concatenation of indices
// NOTE(review): the statement guarded by this S32/S64 test is not visible - presumably
// an early `return true`; confirm.
74 if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
77 RETURN_FALSE_UNLESS(is_lwq(node))
78 for (uint32_t i = 0; i < node->numValues(); i++)
80 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
// DepthToSpace: node and its input are checked
85 bool visit(const luci::CircleDepthToSpace *node)
87 RETURN_FALSE_UNLESS(is_lwq(node))
88 RETURN_FALSE_UNLESS(is_lwq(node->input()))
// Granularity-specific: overridden in both subclasses of this file
92 virtual bool visit(const luci::CircleDepthwiseConv2D *node) = 0;
// Granularity-specific: overridden in both subclasses of this file
94 virtual bool visit(const luci::CircleInstanceNorm *node) = 0;
// Pack: the node itself and each of its values_count() inputs are checked
96 bool visit(const luci::CirclePack *node)
98 RETURN_FALSE_UNLESS(is_lwq(node))
99 for (uint32_t i = 0; i < node->values_count(); i++)
101 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
// Pad: node and input are checked (the paddings operand is not passed to is_lwq)
106 bool visit(const luci::CirclePad *node)
108 RETURN_FALSE_UNLESS(is_lwq(node))
109 RETURN_FALSE_UNLESS(is_lwq(node->input()))
// PadV2: in addition to node and input, constant_values is checked as well
113 bool visit(const luci::CirclePadV2 *node)
115 RETURN_FALSE_UNLESS(is_lwq(node))
116 RETURN_FALSE_UNLESS(is_lwq(node->input()))
117 RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
// MirrorPad: node and input are checked
121 bool visit(const luci::CircleMirrorPad *node)
123 RETURN_FALSE_UNLESS(is_lwq(node))
124 RETURN_FALSE_UNLESS(is_lwq(node->input()))
// Granularity-specific: overridden in both subclasses of this file
128 virtual bool visit(const luci::CirclePRelu *node) = 0;
// Granularity-specific: overridden in both subclasses of this file
130 virtual bool visit(const luci::CircleTransposeConv *node) = 0;
// Granularity-specific: overridden in both subclasses of this file
132 virtual bool visit(const luci::CircleFullyConnected *node) = 0;
// Add: node and both operands (x, y) are checked, unless the node holds indices
134 bool visit(const luci::CircleAdd *node)
136 // Skip granularity check for indices
// NOTE(review): the guarded statement is not visible - presumably `return true`; confirm.
137 if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
140 RETURN_FALSE_UNLESS(is_lwq(node));
141 RETURN_FALSE_UNLESS(is_lwq(node->x()));
142 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// AveragePool2D: node and its `value` input are checked
146 bool visit(const luci::CircleAveragePool2D *node)
148 RETURN_FALSE_UNLESS(is_lwq(node));
149 RETURN_FALSE_UNLESS(is_lwq(node->value()));
// LogicalOr: no check is performed (see the comment below)
153 bool visit(const luci::CircleLogicalOr *)
155 // Logical OR has bool-type inputs and output
156 // Nothing to be checked
// MaxPool2D: node and its `value` input are checked
160 bool visit(const luci::CircleMaxPool2D *node)
162 RETURN_FALSE_UNLESS(is_lwq(node));
163 RETURN_FALSE_UNLESS(is_lwq(node->value()));
// LocalResponseNormalization: node and input are checked
167 bool visit(const luci::CircleLocalResponseNormalization *node)
169 RETURN_FALSE_UNLESS(is_lwq(node))
170 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Mean: node and input are checked
174 bool visit(const luci::CircleMean *node)
176 RETURN_FALSE_UNLESS(is_lwq(node));
177 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Mul: node and both operands (x, y) are checked, unless the node holds indices
181 bool visit(const luci::CircleMul *node)
183 // Skip granularity check for indices
// NOTE(review): the guarded statement is not visible - presumably `return true`; confirm.
184 if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
187 RETURN_FALSE_UNLESS(is_lwq(node));
188 RETURN_FALSE_UNLESS(is_lwq(node->x()));
189 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// NotEqual: only the two operands are checked; the node itself is not passed to is_lwq
193 bool visit(const luci::CircleNotEqual *node)
195 RETURN_FALSE_UNLESS(is_lwq(node->x()));
196 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// OneHot: node, off_value and on_value are checked (no other operand is passed to is_lwq)
200 bool visit(const luci::CircleOneHot *node)
202 RETURN_FALSE_UNLESS(is_lwq(node));
203 RETURN_FALSE_UNLESS(is_lwq(node->off_value()));
204 RETURN_FALSE_UNLESS(is_lwq(node->on_value()));
// ReduceMax: node and input are checked
208 bool visit(const luci::CircleReduceMax *node)
210 RETURN_FALSE_UNLESS(is_lwq(node));
211 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Relu: node and its `features` input are checked
215 bool visit(const luci::CircleRelu *node)
217 RETURN_FALSE_UNLESS(is_lwq(node));
218 RETURN_FALSE_UNLESS(is_lwq(node->features()));
// Reshape: input and output must agree on whether they carry a quantparam;
// whichever side carries one must additionally pass is_lwq
222 bool visit(const luci::CircleReshape *node)
224 auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
225 bool input_quantized = input->quantparam() != nullptr;
226 bool node_quantized = node->quantparam() != nullptr;
227 RETURN_FALSE_UNLESS(input_quantized == node_quantized);
228 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
229 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
// Logistic: node and x are checked
233 bool visit(const luci::CircleLogistic *node)
235 RETURN_FALSE_UNLESS(is_lwq(node));
236 RETURN_FALSE_UNLESS(is_lwq(node->x()));
// Softmax: node and logits are checked
240 bool visit(const luci::CircleSoftmax *node)
242 RETURN_FALSE_UNLESS(is_lwq(node));
243 RETURN_FALSE_UNLESS(is_lwq(node->logits()));
// SpaceToBatchND: node and input are checked
247 bool visit(const luci::CircleSpaceToBatchND *node)
249 RETURN_FALSE_UNLESS(is_lwq(node));
250 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// SpaceToDepth: node and input are checked
254 bool visit(const luci::CircleSpaceToDepth *node)
256 RETURN_FALSE_UNLESS(is_lwq(node));
257 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Slice: node and input are checked
261 bool visit(const luci::CircleSlice *node)
263 RETURN_FALSE_UNLESS(is_lwq(node));
264 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Split: only the input is checked (see the comment below)
268 bool visit(const luci::CircleSplit *node)
270 // node's output is the input of CircleSplitOut, thus not quantized
271 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// SplitOut: only the node itself is checked
275 bool visit(const luci::CircleSplitOut *node)
277 RETURN_FALSE_UNLESS(is_lwq(node));
// SplitV: only the input is checked (see the comment below)
281 bool visit(const luci::CircleSplitV *node)
283 // node's output is the input of CircleSplitVOut, thus not quantized
284 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// SplitVOut: only the node itself is checked
288 bool visit(const luci::CircleSplitVOut *node)
290 RETURN_FALSE_UNLESS(is_lwq(node));
// StridedSlice: node and input are checked
294 bool visit(const luci::CircleStridedSlice *node)
296 RETURN_FALSE_UNLESS(is_lwq(node));
297 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Sum: node and input are checked
301 bool visit(const luci::CircleSum *node)
303 RETURN_FALSE_UNLESS(is_lwq(node));
304 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// ArgMax: only the input is checked (see the comment below)
308 bool visit(const luci::CircleArgMax *node)
310 // node's output is index, thus not quantized
311 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// BatchToSpaceND: node and input are checked
315 bool visit(const luci::CircleBatchToSpaceND *node)
317 RETURN_FALSE_UNLESS(is_lwq(node));
318 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Tanh: node and x are checked
322 bool visit(const luci::CircleTanh *node)
324 RETURN_FALSE_UNLESS(is_lwq(node));
325 RETURN_FALSE_UNLESS(is_lwq(node->x()));
// Transpose: node and its `a` input are checked
329 bool visit(const luci::CircleTranspose *node)
331 RETURN_FALSE_UNLESS(is_lwq(node));
332 RETURN_FALSE_UNLESS(is_lwq(node->a()));
// Floor: node and x are checked
336 bool visit(const luci::CircleFloor *node)
338 RETURN_FALSE_UNLESS(is_lwq(node));
339 RETURN_FALSE_UNLESS(is_lwq(node->x()));
// Gelu: node and features are checked
343 bool visit(const luci::CircleGelu *node)
345 RETURN_FALSE_UNLESS(is_lwq(node));
346 RETURN_FALSE_UNLESS(is_lwq(node->features()));
// Greater: only the two operands are checked; the node itself is not passed to is_lwq
350 bool visit(const luci::CircleGreater *node)
352 RETURN_FALSE_UNLESS(is_lwq(node->x()));
353 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// GreaterEqual: only the two operands are checked; the node itself is not passed to is_lwq
357 bool visit(const luci::CircleGreaterEqual *node)
359 RETURN_FALSE_UNLESS(is_lwq(node->x()));
360 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// Div: node and both operands are checked
364 bool visit(const luci::CircleDiv *node)
366 RETURN_FALSE_UNLESS(is_lwq(node));
367 RETURN_FALSE_UNLESS(is_lwq(node->x()));
368 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// FloorDiv: node and both operands are checked
372 bool visit(const luci::CircleFloorDiv *node)
374 RETURN_FALSE_UNLESS(is_lwq(node));
375 RETURN_FALSE_UNLESS(is_lwq(node->x()));
376 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// Rsqrt: node and x are checked
380 bool visit(const luci::CircleRsqrt *node)
382 RETURN_FALSE_UNLESS(is_lwq(node));
383 RETURN_FALSE_UNLESS(is_lwq(node->x()));
// Sqrt: node and x are checked
387 bool visit(const luci::CircleSqrt *node)
389 RETURN_FALSE_UNLESS(is_lwq(node));
390 RETURN_FALSE_UNLESS(is_lwq(node->x()));
// Elu: node and features are checked
394 bool visit(const luci::CircleElu *node)
396 RETURN_FALSE_UNLESS(is_lwq(node));
397 RETURN_FALSE_UNLESS(is_lwq(node->features()));
// Pow: node and both operands are checked
401 bool visit(const luci::CirclePow *node)
403 RETURN_FALSE_UNLESS(is_lwq(node));
404 RETURN_FALSE_UNLESS(is_lwq(node->x()));
405 RETURN_FALSE_UNLESS(is_lwq(node->y()));
// ResizeBilinear: node and input are checked
409 bool visit(const luci::CircleResizeBilinear *node)
411 RETURN_FALSE_UNLESS(is_lwq(node));
412 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// ResizeNearestNeighbor: node and input are checked
416 bool visit(const luci::CircleResizeNearestNeighbor *node)
418 RETURN_FALSE_UNLESS(is_lwq(node));
419 RETURN_FALSE_UNLESS(is_lwq(node->input()));
// Unpack: only the `value` input is checked (see the comment below)
423 bool visit(const luci::CircleUnpack *node)
425 // node's output is the input of CircleUnpackOut, thus not quantized
426 RETURN_FALSE_UNLESS(is_lwq(node->value()));
// UnpackOut: only the node itself is checked
430 bool visit(const luci::CircleUnpackOut *node)
432 RETURN_FALSE_UNLESS(is_lwq(node));
// Cast: input and output are checked independently - each side is required to pass
// is_lwq only if it carries a quantparam. Unlike Reshape, the two sides are not
// required to agree on being quantized.
436 bool visit(const luci::CircleCast *node)
438 auto input = loco::must_cast<const luci::CircleNode *>(node->x());
439 bool input_quantized = input->quantparam() != nullptr;
440 bool node_quantized = node->quantparam() != nullptr;
441 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
442 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
446 // TODO: Implement more Ops
// Fallback for every node type without a dedicated overload: accepted without any check
448 bool visit(const luci::CircleNode *) { return true; }
// Channel-wise variant of the verifier. Activations (the node and its data input) are
// still checked with the inherited is_lwq(); constant parameters (filter/weights, bias,
// gamma, beta, alpha) are checked with is_cwq_const() against a given channel dimension.
// NOTE(review): braces, access specifiers, return statements and the lines between each
// bias cast and the following bias check are not visible in this view.
451 class VerifyQuantizedNodeChannelWiseGranularity final : public VerifyQuantizedNodeGranularity
// Returns the rank reported by the node once viewed as a luci::CircleNode
454 uint32_t rank(const loco::Node *node)
456 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
457 return circle_node->rank();
// Channel-wise quantization ("cwq") test for a constant node. With channel_size being
// the extent of dimension `channel_dim`, it tests four conditions: quantparam is missing,
// quantized_dimension differs from channel_dim, the number of scales differs from
// channel_size, the number of zero points differs from channel_size.
// NOTE(review): the statement executed when a condition is met is not visible here -
// presumably `return false`, with `return true` after the last test; confirm.
460 bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
462 auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
// Callers are expected to pass a valid dimension index (checked by assert only)
464 assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
465 auto channel_size = circle_node->dim(channel_dim).value();
467 if (circle_node->quantparam() == nullptr)
// quantized_dimension is compared as int32_t, hence the cast of the unsigned index
470 if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
473 if (circle_node->quantparam()->scale.size() != channel_size)
476 if (circle_node->quantparam()->zerop.size() != channel_size)
// Conv2D: filter is checked along dimension 0, bias along its last dimension
483 bool visit(const luci::CircleConv2D *node)
485 RETURN_FALSE_UNLESS(is_lwq(node))
486 RETURN_FALSE_UNLESS(is_lwq(node->input()))
487 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
// NOTE(review): a line between this cast and the bias check is not visible -
// presumably a null guard on `bias` so that a non-const bias is skipped; confirm.
488 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
490 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
// DepthwiseConv2D: filter is checked along dimension 3, bias along its last dimension
494 bool visit(const luci::CircleDepthwiseConv2D *node)
496 RETURN_FALSE_UNLESS(is_lwq(node))
497 RETURN_FALSE_UNLESS(is_lwq(node->input()))
498 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
// NOTE(review): presumably followed by a null guard on `bias` (line not visible); confirm.
499 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
501 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
// InstanceNorm: gamma and beta are each checked along their last dimension
505 bool visit(const luci::CircleInstanceNorm *node)
507 RETURN_FALSE_UNLESS(is_lwq(node))
508 RETURN_FALSE_UNLESS(is_lwq(node->input()))
509 RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
510 RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
// PRelu: alpha is checked along its last dimension
514 bool visit(const luci::CirclePRelu *node)
516 RETURN_FALSE_UNLESS(is_lwq(node))
517 RETURN_FALSE_UNLESS(is_lwq(node->input()))
518 RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
// TransposeConv: the data input is outBackprop(); filter is checked along dimension 0,
// bias along its last dimension
522 bool visit(const luci::CircleTransposeConv *node)
524 RETURN_FALSE_UNLESS(is_lwq(node))
525 RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
526 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
// NOTE(review): presumably followed by a null guard on `bias` (line not visible); confirm.
527 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
529 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
// FullyConnected: weights are checked along dimension 0, bias along its last dimension
534 bool visit(const luci::CircleFullyConnected *node)
536 RETURN_FALSE_UNLESS(is_lwq(node))
537 RETURN_FALSE_UNLESS(is_lwq(node->input()))
538 RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
539 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
540 // Bias is optional (it can be CircleOutputExclude)
// NOTE(review): the guard that uses `bias` is on a line not visible in this view.
542 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
// Layer-wise variant of the verifier. Activations are checked with the inherited
// is_lwq(); constant parameters (filter/weights, bias, gamma, beta, alpha) are checked
// with is_lwq_const(), i.e. against a single scale and a single zero point.
// NOTE(review): braces, access specifiers, return statements and the lines between each
// bias cast and the following bias check are not visible in this view.
547 class VerifyQuantizedNodeLayerWiseGranularity final : public VerifyQuantizedNodeGranularity
// Layer-wise test for a constant node: same three conditions as is_lwq() (missing
// quantparam, scale count != 1, zero-point count != 1), but the node is cast to
// luci::CircleConst instead of luci::CircleNode.
// NOTE(review): the statement executed when a condition is met is not visible here -
// presumably `return false`, with `return true` after the last test; confirm.
550 bool is_lwq_const(const loco::Node *node)
552 auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
554 if (circle_node->quantparam() == nullptr)
557 if (circle_node->quantparam()->scale.size() != 1)
560 if (circle_node->quantparam()->zerop.size() != 1)
// Conv2D: node, input, filter and bias are checked
567 bool visit(const luci::CircleConv2D *node)
569 RETURN_FALSE_UNLESS(is_lwq(node))
570 RETURN_FALSE_UNLESS(is_lwq(node->input()))
571 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
// NOTE(review): a line between this cast and the bias check is not visible -
// presumably a null guard on `bias` so that a non-const bias is skipped; confirm.
572 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
574 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
// DepthwiseConv2D: node, input, filter and bias are checked
578 bool visit(const luci::CircleDepthwiseConv2D *node)
580 RETURN_FALSE_UNLESS(is_lwq(node))
581 RETURN_FALSE_UNLESS(is_lwq(node->input()))
582 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
// NOTE(review): presumably followed by a null guard on `bias` (line not visible); confirm.
583 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
585 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
// InstanceNorm: node, input, gamma and beta are checked
589 bool visit(const luci::CircleInstanceNorm *node)
591 RETURN_FALSE_UNLESS(is_lwq(node))
592 RETURN_FALSE_UNLESS(is_lwq(node->input()))
593 RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
594 RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
// PRelu: node, input and alpha are checked
598 bool visit(const luci::CirclePRelu *node)
600 RETURN_FALSE_UNLESS(is_lwq(node))
601 RETURN_FALSE_UNLESS(is_lwq(node->input()))
602 RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
// TransposeConv: node, outBackprop (the data input), filter and bias are checked
606 bool visit(const luci::CircleTransposeConv *node)
608 RETURN_FALSE_UNLESS(is_lwq(node))
609 RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
610 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
// NOTE(review): presumably followed by a null guard on `bias` (line not visible); confirm.
611 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
613 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
// FullyConnected: node, input, weights and bias are checked
617 bool visit(const luci::CircleFullyConnected *node)
619 RETURN_FALSE_UNLESS(is_lwq(node))
620 RETURN_FALSE_UNLESS(is_lwq(node->input()))
621 RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
// NOTE(review): presumably followed by a null guard on `bias` (line not visible); confirm.
622 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
624 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
631 #undef RETURN_FALSE_UNLESS
633 #endif // __LUCI_VERIFY_QUANTIZED_NODE_GRANULARITY_H__