2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "VerifyQuantizedNodeType.h"
22 // This macro is undef at the end of the file
23 #define RETURN_FALSE_UNLESS(ARG) \
32 std::shared_ptr<VerifyQuantizedNodeType> VerifyQuantizedNodeType::create(loco::DataType dtype)
34 if (dtype == loco::DataType::U8)
35 return std::make_shared<VerifyQuantizedNodeU8Type>();
36 else if (dtype == loco::DataType::S16)
37 return std::make_shared<VerifyQuantizedNodeS16Type>();
39 throw std::domain_error("Not supported Quantized type");
47 template <loco::DataType Qtype, loco::DataType Btype>
48 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node)
50 // Allow add of indices
51 if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
54 return group_has_type(node, Qtype);
57 template <loco::DataType Qtype, loco::DataType Btype>
58 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleArgMax *node)
60 RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
61 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
62 RETURN_FALSE_UNLESS(has_type(node->dimension(), loco::DataType::S32) ||
63 has_type(node->dimension(), loco::DataType::S64))
67 template <loco::DataType Qtype, loco::DataType Btype>
68 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAveragePool2D *node)
70 return group_has_type(node, Qtype);
73 template <loco::DataType Qtype, loco::DataType Btype>
74 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleBatchToSpaceND *node)
76 RETURN_FALSE_UNLESS(has_type(node, Qtype))
77 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
81 template <loco::DataType Qtype, loco::DataType Btype>
82 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleCast *node)
84 auto *input = loco::must_cast<luci::CircleNode *>(node->x());
85 bool input_quantized = input->quantparam() != nullptr;
88 RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
89 RETURN_FALSE_UNLESS(has_type(input, Qtype))
92 bool node_quantized = node->quantparam() != nullptr;
95 RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
96 RETURN_FALSE_UNLESS(has_type(node, Qtype))
101 template <loco::DataType Qtype, loco::DataType Btype>
102 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConv2D *node)
104 RETURN_FALSE_UNLESS(has_type(node, Qtype))
105 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
106 RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
107 RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
111 template <loco::DataType Qtype, loco::DataType Btype>
112 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleConcatenation *node)
114 // Allow concatenation of indices
115 if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
118 return group_has_type(node, Qtype);
121 template <loco::DataType Qtype, loco::DataType Btype>
122 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthToSpace *node)
124 return group_has_type(node, Qtype);
127 template <loco::DataType Qtype, loco::DataType Btype>
128 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDepthwiseConv2D *node)
130 RETURN_FALSE_UNLESS(has_type(node, Qtype))
131 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
132 RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
133 RETURN_FALSE_UNLESS(has_type(node->bias(), Btype))
137 template <loco::DataType Qtype, loco::DataType Btype>
138 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleDiv *node)
140 return group_has_type(node, Qtype);
143 template <loco::DataType Qtype, loco::DataType Btype>
144 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleElu *node)
146 return group_has_type(node, Qtype);
149 template <loco::DataType Qtype, loco::DataType Btype>
150 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloor *node)
152 RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
154 // This checks the value of scale is an integer
155 RETURN_FALSE_UNLESS(node->quantparam());
156 RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
160 template <loco::DataType Qtype, loco::DataType Btype>
161 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFloorDiv *node)
163 RETURN_FALSE_UNLESS(group_has_type(node, Qtype));
165 // This checks the value of scale is an integer
166 RETURN_FALSE_UNLESS(node->quantparam());
167 RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
171 template <loco::DataType Qtype, loco::DataType Btype>
172 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleFullyConnected *node)
174 RETURN_FALSE_UNLESS(has_type(node, Qtype))
175 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
176 RETURN_FALSE_UNLESS(has_type(node->weights(), Qtype))
177 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
179 RETURN_FALSE_UNLESS(has_type(bias, Btype))
183 template <loco::DataType Qtype, loco::DataType Btype>
184 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreater *node)
186 RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
187 RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
188 RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
192 template <loco::DataType Qtype, loco::DataType Btype>
193 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleGreaterEqual *node)
195 RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
196 RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
197 RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
201 template <loco::DataType Qtype, loco::DataType Btype>
202 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleInstanceNorm *node)
204 return group_has_type(node, Qtype);
207 template <loco::DataType Qtype, loco::DataType Btype>
208 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(
209 const luci::CircleLocalResponseNormalization *node)
211 return group_has_type(node, Qtype);
214 template <loco::DataType Qtype, loco::DataType Btype>
215 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleLogicalOr *node)
217 return group_has_type(node, loco::DataType::BOOL);
220 template <loco::DataType Qtype, loco::DataType Btype>
221 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMaxPool2D *node)
223 return group_has_type(node, Qtype);
226 template <loco::DataType Qtype, loco::DataType Btype>
227 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMean *node)
229 RETURN_FALSE_UNLESS(has_type(node, Qtype))
230 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
231 RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
235 template <loco::DataType Qtype, loco::DataType Btype>
236 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPad *node)
238 RETURN_FALSE_UNLESS(has_type(node, Qtype))
239 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
240 RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
244 template <loco::DataType Qtype, loco::DataType Btype>
245 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node)
247 // Allow mul of indices
248 if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
251 return group_has_type(node, Qtype);
254 template <loco::DataType Qtype, loco::DataType Btype>
255 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleNotEqual *node)
257 RETURN_FALSE_UNLESS(has_type(node, loco::DataType::BOOL))
258 RETURN_FALSE_UNLESS(has_type(node->x(), Qtype))
259 RETURN_FALSE_UNLESS(has_type(node->y(), Qtype))
263 template <loco::DataType Qtype, loco::DataType Btype>
264 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleOneHot *node)
266 RETURN_FALSE_UNLESS(has_type(node, Qtype));
267 RETURN_FALSE_UNLESS(has_type(node->indices(), loco::DataType::S32) ||
268 has_type(node->indices(), loco::DataType::S64));
269 RETURN_FALSE_UNLESS(has_type(node->depth(), loco::DataType::S32));
270 RETURN_FALSE_UNLESS(has_type(node->on_value(), Qtype));
271 RETURN_FALSE_UNLESS(has_type(node->off_value(), Qtype));
275 template <loco::DataType Qtype, loco::DataType Btype>
276 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePack *node)
278 return group_has_type(node, Qtype);
281 template <loco::DataType Qtype, loco::DataType Btype>
282 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePad *node)
284 RETURN_FALSE_UNLESS(has_type(node, Qtype))
285 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
286 RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
290 template <loco::DataType Qtype, loco::DataType Btype>
291 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePadV2 *node)
293 RETURN_FALSE_UNLESS(has_type(node, Qtype))
294 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
295 RETURN_FALSE_UNLESS(has_type(node->paddings(), loco::DataType::S32))
296 RETURN_FALSE_UNLESS(has_type(node->constant_values(), Qtype))
300 template <loco::DataType Qtype, loco::DataType Btype>
301 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePRelu *node)
303 return group_has_type(node, Qtype);
306 template <loco::DataType Qtype, loco::DataType Btype>
307 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *node)
309 return group_has_type(node, Qtype);
312 template <loco::DataType Qtype, loco::DataType Btype>
313 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReduceMax *node)
315 RETURN_FALSE_UNLESS(has_type(node, Qtype))
316 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
317 RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
321 template <loco::DataType Qtype, loco::DataType Btype>
322 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node)
324 return group_has_type(node, Qtype);
327 template <loco::DataType Qtype, loco::DataType Btype>
328 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReshape *node)
330 if (node->quantparam())
332 RETURN_FALSE_UNLESS(has_type(node, Qtype))
333 RETURN_FALSE_UNLESS(has_type(node->tensor(), Qtype))
337 RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
339 luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
340 if (shape != nullptr)
341 RETURN_FALSE_UNLESS(has_type(shape, loco::DataType::S32))
345 template <loco::DataType Qtype, loco::DataType Btype>
346 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeBilinear *node)
348 RETURN_FALSE_UNLESS(has_type(node, Qtype))
349 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
353 template <loco::DataType Qtype, loco::DataType Btype>
354 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleResizeNearestNeighbor *node)
356 RETURN_FALSE_UNLESS(has_type(node, Qtype))
357 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
361 template <loco::DataType Qtype, loco::DataType Btype>
362 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRsqrt *node)
364 return group_has_type(node, Qtype);
367 template <loco::DataType Qtype, loco::DataType Btype>
368 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSlice *node)
370 RETURN_FALSE_UNLESS(has_type(node, Qtype))
371 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
372 RETURN_FALSE_UNLESS(has_type(node->begin(), loco::DataType::S32) ||
373 has_type(node->begin(), loco::DataType::S64))
374 RETURN_FALSE_UNLESS(has_type(node->size(), loco::DataType::S32) ||
375 has_type(node->size(), loco::DataType::S64))
379 template <loco::DataType Qtype, loco::DataType Btype>
380 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToBatchND *node)
382 RETURN_FALSE_UNLESS(has_type(node, Qtype))
383 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
387 template <loco::DataType Qtype, loco::DataType Btype>
388 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSpaceToDepth *node)
390 return group_has_type(node, Qtype);
393 template <loco::DataType Qtype, loco::DataType Btype>
394 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplit *node)
396 // node's output is the input of CircleSplitOut, thus not quantized
397 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
401 template <loco::DataType Qtype, loco::DataType Btype>
402 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitOut *node)
404 RETURN_FALSE_UNLESS(has_type(node, Qtype))
406 // SplitOut has the same qparam with the input of Split
407 auto split = loco::must_cast<luci::CircleSplit *>(node->input());
408 auto input = loco::must_cast<luci::CircleNode *>(split->input());
409 RETURN_FALSE_UNLESS(node->quantparam());
410 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
411 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
415 template <loco::DataType Qtype, loco::DataType Btype>
416 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitV *node)
418 // node's output is the input of CircleSplitVOut, thus not quantized
419 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
423 template <loco::DataType Qtype, loco::DataType Btype>
424 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSplitVOut *node)
426 RETURN_FALSE_UNLESS(has_type(node, Qtype))
428 // SplitVOut has the same qparam with the input of SplitV
429 auto splitv = loco::must_cast<luci::CircleSplitV *>(node->input());
430 auto input = loco::must_cast<luci::CircleNode *>(splitv->input());
431 RETURN_FALSE_UNLESS(node->quantparam());
432 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
433 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
437 template <loco::DataType Qtype, loco::DataType Btype>
438 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleSqrt *node)
440 return group_has_type(node, Qtype);
443 template <loco::DataType Qtype, loco::DataType Btype>
444 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleStridedSlice *node)
446 RETURN_FALSE_UNLESS(has_type(node, Qtype))
447 RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
449 auto input = loco::must_cast<luci::CircleNode *>(node->input());
450 RETURN_FALSE_UNLESS(node->quantparam());
451 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
452 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
456 template <loco::DataType Qtype, loco::DataType Btype>
457 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTranspose *node)
459 RETURN_FALSE_UNLESS(has_type(node, Qtype))
460 RETURN_FALSE_UNLESS(has_type(node->a(), Qtype))
461 RETURN_FALSE_UNLESS(has_type(node->perm(), loco::DataType::S32))
465 template <loco::DataType Qtype, loco::DataType Btype>
466 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleTransposeConv *node)
468 RETURN_FALSE_UNLESS(has_type(node, Qtype))
469 RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Qtype))
470 RETURN_FALSE_UNLESS(has_type(node->filter(), Qtype))
471 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
473 RETURN_FALSE_UNLESS(has_type(bias, Btype))
477 template <loco::DataType Qtype, loco::DataType Btype>
478 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpack *node)
480 // node's output is the input of CircleUnpackOut, thus not quantized
481 RETURN_FALSE_UNLESS(has_type(node->value(), Qtype))
485 template <loco::DataType Qtype, loco::DataType Btype>
486 bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleUnpackOut *node)
488 RETURN_FALSE_UNLESS(has_type(node, Qtype))
490 // UnpackOut has the same qparam with the input of Unpack
491 auto Unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
492 auto input = loco::must_cast<luci::CircleNode *>(Unpack->value());
493 RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
494 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
495 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
504 bool VerifyQuantizedNodeU8Type::visit(const luci::CircleTanh *node)
506 RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
508 RETURN_FALSE_UNLESS(node->quantparam());
509 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
510 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
514 bool VerifyQuantizedNodeU8Type::visit(const luci::CircleLogistic *node)
516 RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
518 RETURN_FALSE_UNLESS(node->quantparam());
519 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
520 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
524 bool VerifyQuantizedNodeU8Type::visit(const luci::CircleSoftmax *node)
526 RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::U8));
528 RETURN_FALSE_UNLESS(node->quantparam());
529 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
530 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
539 bool VerifyQuantizedNodeS16Type::visit(const luci::CircleTanh *node)
541 RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
543 RETURN_FALSE_UNLESS(node->quantparam());
544 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
545 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
549 bool VerifyQuantizedNodeS16Type::visit(const luci::CircleLogistic *node)
551 RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
553 RETURN_FALSE_UNLESS(node->quantparam());
554 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
555 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
559 bool VerifyQuantizedNodeS16Type::visit(const luci::CircleSoftmax *node)
561 RETURN_FALSE_UNLESS(group_has_type(node, loco::DataType::S16));
563 RETURN_FALSE_UNLESS(node->quantparam());
564 RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
565 RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
571 #undef RETURN_FALSE_UNLESS