2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
7 * http://www.apache.org/licenses/LICENSE-2.0
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
16 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
17 #define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/Pass/QuantizationParameters.h>
23 using Granularity = luci::QuantizationGranularity;
// RETURN_FALSE_UNLESS(ARG): make the enclosing function return false when ARG is false.
// It expands to a plain braced `if`, so call sites may write it with or without a
// trailing semicolon.
// This macro is #undef'd at the end of the file so it does not leak into includers.
#define RETURN_FALSE_UNLESS(ARG) \
  if (not(ARG))                  \
  {                              \
    return false;                \
  }
36 * @brief Verify the granualrity of channel-wise quantized node
40 * - node's output (i.e., node itself)
43 struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool>
46 bool is_lwq(const loco::Node *node)
48 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
50 if (circle_node->quantparam() == nullptr)
53 if (circle_node->quantparam()->scale.size() != 1)
56 if (circle_node->quantparam()->zerop.size() != 1)
62 uint32_t rank(const loco::Node *node)
64 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
65 return circle_node->rank();
68 bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
70 auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
72 assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
73 auto channel_size = circle_node->dim(channel_dim).value();
75 if (circle_node->quantparam() == nullptr)
78 if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
81 if (circle_node->quantparam()->scale.size() != channel_size)
84 if (circle_node->quantparam()->zerop.size() != channel_size)
91 bool visit(const luci::CircleConv2D *node)
93 RETURN_FALSE_UNLESS(is_lwq(node))
94 RETURN_FALSE_UNLESS(is_lwq(node->input()))
95 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
96 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
98 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
102 bool visit(const luci::CircleConcatenation *node)
104 RETURN_FALSE_UNLESS(is_lwq(node))
105 for (uint32_t i = 0; i < node->numValues(); i++)
107 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
112 bool visit(const luci::CircleDepthToSpace *node)
114 RETURN_FALSE_UNLESS(is_lwq(node))
115 RETURN_FALSE_UNLESS(is_lwq(node->input()))
119 bool visit(const luci::CircleDepthwiseConv2D *node)
121 RETURN_FALSE_UNLESS(is_lwq(node))
122 RETURN_FALSE_UNLESS(is_lwq(node->input()))
123 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
124 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
126 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
130 bool visit(const luci::CircleInstanceNorm *node)
132 RETURN_FALSE_UNLESS(is_lwq(node))
133 RETURN_FALSE_UNLESS(is_lwq(node->input()))
134 RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
135 RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
139 bool visit(const luci::CirclePack *node)
141 RETURN_FALSE_UNLESS(is_lwq(node))
142 for (uint32_t i = 0; i < node->values_count(); i++)
144 RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
149 bool visit(const luci::CirclePad *node)
151 RETURN_FALSE_UNLESS(is_lwq(node))
152 RETURN_FALSE_UNLESS(is_lwq(node->input()))
156 bool visit(const luci::CirclePadV2 *node)
158 RETURN_FALSE_UNLESS(is_lwq(node))
159 RETURN_FALSE_UNLESS(is_lwq(node->input()))
160 RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
164 bool visit(const luci::CircleMirrorPad *node)
166 RETURN_FALSE_UNLESS(is_lwq(node))
167 RETURN_FALSE_UNLESS(is_lwq(node->input()))
171 bool visit(const luci::CirclePRelu *node)
173 RETURN_FALSE_UNLESS(is_lwq(node))
174 RETURN_FALSE_UNLESS(is_lwq(node->input()))
175 RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
179 bool visit(const luci::CircleTransposeConv *node)
181 RETURN_FALSE_UNLESS(is_lwq(node))
182 RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
183 RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
184 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
186 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
191 bool visit(const luci::CircleFullyConnected *node)
193 RETURN_FALSE_UNLESS(is_lwq(node))
194 RETURN_FALSE_UNLESS(is_lwq(node->input()))
195 RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
196 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
197 // Bias is optional (it can be CircleOutputExclude)
199 RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
203 bool visit(const luci::CircleAdd *node)
205 RETURN_FALSE_UNLESS(is_lwq(node));
206 RETURN_FALSE_UNLESS(is_lwq(node->x()));
207 RETURN_FALSE_UNLESS(is_lwq(node->y()));
211 bool visit(const luci::CircleAveragePool2D *node)
213 RETURN_FALSE_UNLESS(is_lwq(node));
214 RETURN_FALSE_UNLESS(is_lwq(node->value()));
218 bool visit(const luci::CircleLogicalOr *)
220 // Logical OR has bool-type inputs and output
221 // Nothing to be checked
225 bool visit(const luci::CircleMaxPool2D *node)
227 RETURN_FALSE_UNLESS(is_lwq(node));
228 RETURN_FALSE_UNLESS(is_lwq(node->value()));
232 bool visit(const luci::CircleLocalResponseNormalization *node)
234 RETURN_FALSE_UNLESS(is_lwq(node))
235 RETURN_FALSE_UNLESS(is_lwq(node->input()));
239 bool visit(const luci::CircleMean *node)
241 RETURN_FALSE_UNLESS(is_lwq(node));
242 RETURN_FALSE_UNLESS(is_lwq(node->input()));
246 bool visit(const luci::CircleMul *node)
248 RETURN_FALSE_UNLESS(is_lwq(node));
249 RETURN_FALSE_UNLESS(is_lwq(node->x()));
250 RETURN_FALSE_UNLESS(is_lwq(node->y()));
254 bool visit(const luci::CircleNotEqual *node)
256 RETURN_FALSE_UNLESS(is_lwq(node->x()));
257 RETURN_FALSE_UNLESS(is_lwq(node->y()));
261 bool visit(const luci::CircleRelu *node)
263 RETURN_FALSE_UNLESS(is_lwq(node));
264 RETURN_FALSE_UNLESS(is_lwq(node->features()));
268 bool visit(const luci::CircleReshape *node)
270 auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
271 bool input_quantized = input->quantparam() != nullptr;
272 bool node_quantized = node->quantparam() != nullptr;
273 RETURN_FALSE_UNLESS(input_quantized == node_quantized);
274 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
275 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
279 bool visit(const luci::CircleLogistic *node)
281 RETURN_FALSE_UNLESS(is_lwq(node));
282 RETURN_FALSE_UNLESS(is_lwq(node->x()));
286 bool visit(const luci::CircleSoftmax *node)
288 RETURN_FALSE_UNLESS(is_lwq(node));
289 RETURN_FALSE_UNLESS(is_lwq(node->logits()));
293 bool visit(const luci::CircleSpaceToBatchND *node)
295 RETURN_FALSE_UNLESS(is_lwq(node));
296 RETURN_FALSE_UNLESS(is_lwq(node->input()));
300 bool visit(const luci::CircleSpaceToDepth *node)
302 RETURN_FALSE_UNLESS(is_lwq(node));
303 RETURN_FALSE_UNLESS(is_lwq(node->input()));
307 bool visit(const luci::CircleSlice *node)
309 RETURN_FALSE_UNLESS(is_lwq(node));
310 RETURN_FALSE_UNLESS(is_lwq(node->input()));
314 bool visit(const luci::CircleSplit *node)
316 // node's output is the input of CircleSplitOut, thus not quantized
317 RETURN_FALSE_UNLESS(is_lwq(node->input()));
321 bool visit(const luci::CircleSplitOut *node)
323 RETURN_FALSE_UNLESS(is_lwq(node));
327 bool visit(const luci::CircleSplitV *node)
329 // node's output is the input of CircleSplitVOut, thus not quantized
330 RETURN_FALSE_UNLESS(is_lwq(node->input()));
334 bool visit(const luci::CircleSplitVOut *node)
336 RETURN_FALSE_UNLESS(is_lwq(node));
340 bool visit(const luci::CircleStridedSlice *node)
342 RETURN_FALSE_UNLESS(is_lwq(node));
343 RETURN_FALSE_UNLESS(is_lwq(node->input()));
347 bool visit(const luci::CircleArgMax *node)
349 // node's output is index, thus not quantized
350 RETURN_FALSE_UNLESS(is_lwq(node->input()));
354 bool visit(const luci::CircleBatchToSpaceND *node)
356 RETURN_FALSE_UNLESS(is_lwq(node));
357 RETURN_FALSE_UNLESS(is_lwq(node->input()));
361 bool visit(const luci::CircleTanh *node)
363 RETURN_FALSE_UNLESS(is_lwq(node));
364 RETURN_FALSE_UNLESS(is_lwq(node->x()));
368 bool visit(const luci::CircleTranspose *node)
370 RETURN_FALSE_UNLESS(is_lwq(node));
371 RETURN_FALSE_UNLESS(is_lwq(node->a()));
375 bool visit(const luci::CircleFloor *node)
377 RETURN_FALSE_UNLESS(is_lwq(node));
378 RETURN_FALSE_UNLESS(is_lwq(node->x()));
382 bool visit(const luci::CircleGreater *node)
384 RETURN_FALSE_UNLESS(is_lwq(node->x()));
385 RETURN_FALSE_UNLESS(is_lwq(node->y()));
389 bool visit(const luci::CircleGreaterEqual *node)
391 RETURN_FALSE_UNLESS(is_lwq(node->x()));
392 RETURN_FALSE_UNLESS(is_lwq(node->y()));
396 bool visit(const luci::CircleDiv *node)
398 RETURN_FALSE_UNLESS(is_lwq(node));
399 RETURN_FALSE_UNLESS(is_lwq(node->x()));
400 RETURN_FALSE_UNLESS(is_lwq(node->y()));
404 bool visit(const luci::CircleFloorDiv *node)
406 RETURN_FALSE_UNLESS(is_lwq(node));
407 RETURN_FALSE_UNLESS(is_lwq(node->x()));
408 RETURN_FALSE_UNLESS(is_lwq(node->y()));
412 bool visit(const luci::CircleRsqrt *node)
414 RETURN_FALSE_UNLESS(is_lwq(node));
415 RETURN_FALSE_UNLESS(is_lwq(node->x()));
419 bool visit(const luci::CircleSqrt *node)
421 RETURN_FALSE_UNLESS(is_lwq(node));
422 RETURN_FALSE_UNLESS(is_lwq(node->x()));
426 bool visit(const luci::CircleElu *node)
428 RETURN_FALSE_UNLESS(is_lwq(node));
429 RETURN_FALSE_UNLESS(is_lwq(node->features()));
433 bool visit(const luci::CirclePow *node)
435 RETURN_FALSE_UNLESS(is_lwq(node));
436 RETURN_FALSE_UNLESS(is_lwq(node->x()));
437 RETURN_FALSE_UNLESS(is_lwq(node->y()));
441 bool visit(const luci::CircleResizeBilinear *node)
443 RETURN_FALSE_UNLESS(is_lwq(node));
444 RETURN_FALSE_UNLESS(is_lwq(node->input()));
448 bool visit(const luci::CircleResizeNearestNeighbor *node)
450 RETURN_FALSE_UNLESS(is_lwq(node));
451 RETURN_FALSE_UNLESS(is_lwq(node->input()));
455 bool visit(const luci::CircleUnpack *node)
457 // node's output is the input of CircleUnpackOut, thus not quantized
458 RETURN_FALSE_UNLESS(is_lwq(node->value()));
462 bool visit(const luci::CircleUnpackOut *node)
464 RETURN_FALSE_UNLESS(is_lwq(node));
468 bool visit(const luci::CircleCast *node)
470 auto input = loco::must_cast<const luci::CircleNode *>(node->x());
471 bool input_quantized = input->quantparam() != nullptr;
472 bool node_quantized = node->quantparam() != nullptr;
473 RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
474 RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
478 // TODO: Implement more Ops
480 bool visit(const luci::CircleNode *) { return true; }
485 #undef RETURN_FALSE_UNLESS
487 #endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__