2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
18 #define __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
27 * @brief Verify the data type of quantized node
31 * - node's output (i.e., node itself)
34 class VerifyQuantizedNodeType
37 static std::shared_ptr<VerifyQuantizedNodeType> create(loco::DataType dtype);
40 virtual bool verify(luci::CircleNode *node) = 0;
44 * @brief Verify using quantization type of a node and bias
46 * @tparam Qtype Quantization type for a node (e.g. Q8, Q16, ...)
47 * @tparam Btype Bias quantization type (e.g. For Q8, S32 is used)
49 template <loco::DataType Qtype, loco::DataType Btype>
50 class VerifyQuantizedNodeTypeBase : public luci::CircleNodeVisitor<bool>,
51 public VerifyQuantizedNodeType
54 bool verify(luci::CircleNode *node) { return node->accept(this); }
57 bool has_type(const loco::Node *node, loco::DataType dtype)
59 auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
60 return circle_node->dtype() == dtype;
63 // Check whether a node and all of its inputs have dtype or not
64 bool group_has_type(const loco::Node *node, loco::DataType dtype)
66 if (!has_type(node, dtype))
69 for (uint32_t i = 0; i < node->arity(); ++i)
70 if (!has_type(node->arg(i), dtype))
77 bool visit(const luci::CircleAdd *node);
78 bool visit(const luci::CircleArgMax *node);
79 bool visit(const luci::CircleAveragePool2D *node);
80 bool visit(const luci::CircleBatchToSpaceND *node);
81 bool visit(const luci::CircleCast *node);
82 bool visit(const luci::CircleConv2D *node);
83 bool visit(const luci::CircleConcatenation *node);
84 bool visit(const luci::CircleDepthToSpace *node);
85 bool visit(const luci::CircleDepthwiseConv2D *node);
86 bool visit(const luci::CircleDiv *node);
87 bool visit(const luci::CircleElu *node);
88 bool visit(const luci::CircleFloor *node);
89 bool visit(const luci::CircleFloorDiv *node);
90 bool visit(const luci::CircleFullyConnected *node);
91 bool visit(const luci::CircleGreater *node);
92 bool visit(const luci::CircleGreaterEqual *node);
93 bool visit(const luci::CircleInstanceNorm *node);
94 bool visit(const luci::CircleLocalResponseNormalization *node);
95 bool visit(const luci::CircleLogicalOr *node);
96 bool visit(const luci::CircleMaxPool2D *node);
97 bool visit(const luci::CircleMean *node);
98 bool visit(const luci::CircleMirrorPad *node);
99 bool visit(const luci::CircleMul *node);
100 bool visit(const luci::CircleNotEqual *node);
101 bool visit(const luci::CircleOneHot *node);
102 bool visit(const luci::CirclePack *node);
103 bool visit(const luci::CirclePad *node);
104 bool visit(const luci::CirclePadV2 *node);
105 bool visit(const luci::CirclePRelu *node);
106 bool visit(const luci::CirclePow *node);
107 bool visit(const luci::CircleReduceMax *node);
108 bool visit(const luci::CircleRelu *node);
109 bool visit(const luci::CircleReshape *node);
110 bool visit(const luci::CircleResizeBilinear *node);
111 bool visit(const luci::CircleResizeNearestNeighbor *node);
112 bool visit(const luci::CircleRsqrt *node);
113 bool visit(const luci::CircleSlice *node);
114 bool visit(const luci::CircleSpaceToBatchND *node);
115 bool visit(const luci::CircleSpaceToDepth *node);
116 bool visit(const luci::CircleSplit *node);
117 bool visit(const luci::CircleSplitOut *node);
118 bool visit(const luci::CircleSplitV *node);
119 bool visit(const luci::CircleSplitVOut *node);
120 bool visit(const luci::CircleSqrt *node);
121 bool visit(const luci::CircleStridedSlice *node);
122 bool visit(const luci::CircleTranspose *node);
123 bool visit(const luci::CircleTransposeConv *node);
124 bool visit(const luci::CircleUnpack *node);
125 bool visit(const luci::CircleUnpackOut *node);
127 // NOTE below nodes has differnent implementation for Qtype/Btype and
128 // implementations exist in VerifyQuantizedNodeU8Type, VerifyQuantizedNodeS16Type
129 // bool visit(const luci::CircleLogistic *node);
130 // bool visit(const luci::CircleSoftmax *node);
131 // bool visit(const luci::CircleTanh *node);
133 // TODO: Implement more Ops
135 bool visit(const luci::CircleNode *) { return true; }
138 class VerifyQuantizedNodeU8Type
139 : public VerifyQuantizedNodeTypeBase<loco::DataType::U8, loco::DataType::S32>
142 bool visit(const luci::CircleLogistic *node);
143 bool visit(const luci::CircleSoftmax *node);
144 bool visit(const luci::CircleTanh *node);
147 class VerifyQuantizedNodeS16Type
148 : public VerifyQuantizedNodeTypeBase<loco::DataType::S16, loco::DataType::S64>
151 bool visit(const luci::CircleLogistic *node);
152 bool visit(const luci::CircleSoftmax *node);
153 bool visit(const luci::CircleTanh *node);
158 #endif // __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__