2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
7 * http://www.apache.org/licenses/LICENSE-2.0
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
16 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
17 #define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/Pass/QuantizationParameters.h>
23 using Granularity = luci::QuantizationGranularity;
25 // This macro is #undef'ed at the end of the file
26 #define RETURN_FALSE_UNLESS(ARG) \
36 * @brief Verify the granularity of layer-wise quantized node
40 * - node's output (i.e., node itself)
43 struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool>
46 bool is_lwq(const loco::Node *node)
48 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
50 if (circle_node->quantparam() == nullptr)
53 if (circle_node->quantparam()->scale.size() != 1)
56 if (circle_node->quantparam()->zerop.size() != 1)
62 bool is_lwq_const(const loco::Node *node)
64 auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
66 if (circle_node->quantparam() == nullptr)
69 if (circle_node->quantparam()->scale.size() != 1)
72 if (circle_node->quantparam()->zerop.size() != 1)
79 bool visit(const luci::CircleConv2D *node)
81 RETURN_FALSE_UNLESS(is_lwq(node))
82 RETURN_FALSE_UNLESS(is_lwq(node->input()))
83 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
84 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
86 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
90 bool visit(const luci::CircleConcatenation *node)
92 RETURN_FALSE_UNLESS(is_lwq(node))
93 for (uint32_t i = 0; i < node->numValues(); i++)
95 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
100 bool visit(const luci::CircleDepthToSpace *node)
102 RETURN_FALSE_UNLESS(is_lwq(node))
103 RETURN_FALSE_UNLESS(is_lwq(node->input()))
107 bool visit(const luci::CircleDepthwiseConv2D *node)
109 RETURN_FALSE_UNLESS(is_lwq(node))
110 RETURN_FALSE_UNLESS(is_lwq(node->input()))
111 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
112 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
114 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
118 bool visit(const luci::CircleInstanceNorm *node)
120 RETURN_FALSE_UNLESS(is_lwq(node))
121 RETURN_FALSE_UNLESS(is_lwq(node->input()))
122 RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
123 RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
127 bool visit(const luci::CirclePack *node)
129 RETURN_FALSE_UNLESS(is_lwq(node))
130 for (uint32_t i = 0; i < node->values_count(); i++)
132 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
137 bool visit(const luci::CirclePad *node)
139 RETURN_FALSE_UNLESS(is_lwq(node))
140 RETURN_FALSE_UNLESS(is_lwq(node->input()))
144 bool visit(const luci::CirclePadV2 *node)
146 RETURN_FALSE_UNLESS(is_lwq(node))
147 RETURN_FALSE_UNLESS(is_lwq(node->input()))
148 RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
152 bool visit(const luci::CircleMirrorPad *node)
154 RETURN_FALSE_UNLESS(is_lwq(node))
155 RETURN_FALSE_UNLESS(is_lwq(node->input()))
159 bool visit(const luci::CirclePRelu *node)
161 RETURN_FALSE_UNLESS(is_lwq(node))
162 RETURN_FALSE_UNLESS(is_lwq(node->input()))
163 RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
167 bool visit(const luci::CircleTransposeConv *node)
169 RETURN_FALSE_UNLESS(is_lwq(node))
170 RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
171 RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
172 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
174 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
178 bool visit(const luci::CircleFullyConnected *node)
180 RETURN_FALSE_UNLESS(is_lwq(node))
181 RETURN_FALSE_UNLESS(is_lwq(node->input()))
182 RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
183 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
185 RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
189 bool visit(const luci::CircleAdd *node)
191 RETURN_FALSE_UNLESS(is_lwq(node))
192 RETURN_FALSE_UNLESS(is_lwq(node->x()));
193 RETURN_FALSE_UNLESS(is_lwq(node->y()));
197 bool visit(const luci::CircleAveragePool2D *node)
199 RETURN_FALSE_UNLESS(is_lwq(node))
200 RETURN_FALSE_UNLESS(is_lwq(node->value()));
204 bool visit(const luci::CircleLogicalOr *)
206 // Logical OR has bool-type inputs and output
207 // Nothing to be checked
211 bool visit(const luci::CircleMaxPool2D *node)
213 RETURN_FALSE_UNLESS(is_lwq(node))
214 RETURN_FALSE_UNLESS(is_lwq(node->value()));
218 bool visit(const luci::CircleLocalResponseNormalization *node)
220 RETURN_FALSE_UNLESS(is_lwq(node))
221 RETURN_FALSE_UNLESS(is_lwq(node->input()));
225 bool visit(const luci::CircleMean *node)
227 RETURN_FALSE_UNLESS(is_lwq(node))
228 RETURN_FALSE_UNLESS(is_lwq(node->input()));
232 bool visit(const luci::CircleMul *node)
234 RETURN_FALSE_UNLESS(is_lwq(node))
235 RETURN_FALSE_UNLESS(is_lwq(node->x()));
236 RETURN_FALSE_UNLESS(is_lwq(node->y()));
240 bool visit(const luci::CircleNotEqual *node)
242 RETURN_FALSE_UNLESS(is_lwq(node->x()));
243 RETURN_FALSE_UNLESS(is_lwq(node->y()));
247 bool visit(const luci::CircleRelu *node)
249 RETURN_FALSE_UNLESS(is_lwq(node))
250 RETURN_FALSE_UNLESS(is_lwq(node->features()));
254 bool visit(const luci::CircleReshape *node)
256 auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
257 bool input_quantized = input->quantparam() != nullptr;
258 bool node_quantized = node->quantparam() != nullptr;
259 RETURN_FALSE_UNLESS(input_quantized == node_quantized);
260 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
261 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
265 bool visit(const luci::CircleLogistic *node)
267 RETURN_FALSE_UNLESS(is_lwq(node));
268 RETURN_FALSE_UNLESS(is_lwq(node->x()));
272 bool visit(const luci::CircleSoftmax *node)
274 RETURN_FALSE_UNLESS(is_lwq(node));
275 RETURN_FALSE_UNLESS(is_lwq(node->logits()));
279 bool visit(const luci::CircleSpaceToBatchND *node)
281 RETURN_FALSE_UNLESS(is_lwq(node));
282 RETURN_FALSE_UNLESS(is_lwq(node->input()));
286 bool visit(const luci::CircleSpaceToDepth *node)
288 RETURN_FALSE_UNLESS(is_lwq(node));
289 RETURN_FALSE_UNLESS(is_lwq(node->input()));
293 bool visit(const luci::CircleSlice *node)
295 RETURN_FALSE_UNLESS(is_lwq(node));
296 RETURN_FALSE_UNLESS(is_lwq(node->input()));
300 bool visit(const luci::CircleSplit *node)
302 // node's output is the input of CircleSplitOut, thus not quantized
303 RETURN_FALSE_UNLESS(is_lwq(node->input()));
307 bool visit(const luci::CircleSplitOut *node)
309 RETURN_FALSE_UNLESS(is_lwq(node));
313 bool visit(const luci::CircleSplitV *node)
315 // node's output is the input of CircleSplitVOut, thus not quantized
316 RETURN_FALSE_UNLESS(is_lwq(node->input()));
320 bool visit(const luci::CircleSplitVOut *node)
322 RETURN_FALSE_UNLESS(is_lwq(node));
326 bool visit(const luci::CircleStridedSlice *node)
328 RETURN_FALSE_UNLESS(is_lwq(node));
329 RETURN_FALSE_UNLESS(is_lwq(node->input()));
333 bool visit(const luci::CircleArgMax *node)
335 // node's output is index, thus not quantized
336 RETURN_FALSE_UNLESS(is_lwq(node->input()));
340 bool visit(const luci::CircleBatchToSpaceND *node)
342 RETURN_FALSE_UNLESS(is_lwq(node));
343 RETURN_FALSE_UNLESS(is_lwq(node->input()));
347 bool visit(const luci::CircleTanh *node)
349 RETURN_FALSE_UNLESS(is_lwq(node));
350 RETURN_FALSE_UNLESS(is_lwq(node->x()));
354 bool visit(const luci::CircleTranspose *node)
356 RETURN_FALSE_UNLESS(is_lwq(node));
357 RETURN_FALSE_UNLESS(is_lwq(node->a()));
361 bool visit(const luci::CircleFloor *node)
363 RETURN_FALSE_UNLESS(is_lwq(node));
364 RETURN_FALSE_UNLESS(is_lwq(node->x()));
368 bool visit(const luci::CircleGreater *node)
370 RETURN_FALSE_UNLESS(is_lwq(node->x()));
371 RETURN_FALSE_UNLESS(is_lwq(node->y()));
375 bool visit(const luci::CircleGreaterEqual *node)
377 RETURN_FALSE_UNLESS(is_lwq(node->x()));
378 RETURN_FALSE_UNLESS(is_lwq(node->y()));
382 bool visit(const luci::CircleDiv *node)
384 RETURN_FALSE_UNLESS(is_lwq(node));
385 RETURN_FALSE_UNLESS(is_lwq(node->x()));
386 RETURN_FALSE_UNLESS(is_lwq(node->y()));
390 bool visit(const luci::CircleFloorDiv *node)
392 RETURN_FALSE_UNLESS(is_lwq(node));
393 RETURN_FALSE_UNLESS(is_lwq(node->x()));
394 RETURN_FALSE_UNLESS(is_lwq(node->y()));
398 bool visit(const luci::CircleRsqrt *node)
400 RETURN_FALSE_UNLESS(is_lwq(node));
401 RETURN_FALSE_UNLESS(is_lwq(node->x()));
405 bool visit(const luci::CircleSqrt *node)
407 RETURN_FALSE_UNLESS(is_lwq(node));
408 RETURN_FALSE_UNLESS(is_lwq(node->x()));
412 bool visit(const luci::CircleElu *node)
414 RETURN_FALSE_UNLESS(is_lwq(node));
415 RETURN_FALSE_UNLESS(is_lwq(node->features()));
419 bool visit(const luci::CirclePow *node)
421 RETURN_FALSE_UNLESS(is_lwq(node));
422 RETURN_FALSE_UNLESS(is_lwq(node->x()));
423 RETURN_FALSE_UNLESS(is_lwq(node->y()));
427 bool visit(const luci::CircleResizeBilinear *node)
429 RETURN_FALSE_UNLESS(is_lwq(node));
430 RETURN_FALSE_UNLESS(is_lwq(node->input()));
434 bool visit(const luci::CircleResizeNearestNeighbor *node)
436 RETURN_FALSE_UNLESS(is_lwq(node));
437 RETURN_FALSE_UNLESS(is_lwq(node->input()));
441 bool visit(const luci::CircleUnpack *node)
443 // node's output is the input of CircleUnpackOut, thus not quantized
444 RETURN_FALSE_UNLESS(is_lwq(node->value()));
448 bool visit(const luci::CircleUnpackOut *node)
450 RETURN_FALSE_UNLESS(is_lwq(node));
454 bool visit(const luci::CircleCast *node)
456 auto input = loco::must_cast<const luci::CircleNode *>(node->x());
457 bool input_quantized = input->quantparam() != nullptr;
458 bool node_quantized = node->quantparam() != nullptr;
459 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
460 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
464 // TODO: Implement more Ops
466 bool visit(const luci::CircleNode *) { return true; }
471 #undef RETURN_FALSE_UNLESS
473 #endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__