Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / VerifyQuantizedNodeType.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
18 #define __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__
19
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22
23 namespace luci
24 {
25
26 /**
27  * @brief Verify the data type of quantized node
28  * @details
29  *
30  * Targets to verify
31  * - node's output (i.e., node itself)
32  * - node's inputs
33  */
34 class VerifyQuantizedNodeType
35 {
36 public:
37   static std::shared_ptr<VerifyQuantizedNodeType> create(loco::DataType dtype);
38
39 public:
40   virtual bool verify(luci::CircleNode *node) = 0;
41 };
42
43 /**
44  * @brief Verify using quantization type of a node and bias
45  *
46  * @tparam Qtype Quantization type for a node (e.g. Q8, Q16, ...)
47  * @tparam Btype Bias quantization type (e.g. For Q8, S32 is used)
48  */
49 template <loco::DataType Qtype, loco::DataType Btype>
50 class VerifyQuantizedNodeTypeBase : public luci::CircleNodeVisitor<bool>,
51                                     public VerifyQuantizedNodeType
52 {
53 public:
54   bool verify(luci::CircleNode *node) { return node->accept(this); }
55
56 protected:
57   bool has_type(const loco::Node *node, loco::DataType dtype)
58   {
59     auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
60     return circle_node->dtype() == dtype;
61   }
62
63   // Check whether a node and all of its inputs have dtype or not
64   bool group_has_type(const loco::Node *node, loco::DataType dtype)
65   {
66     if (!has_type(node, dtype))
67       return false;
68
69     for (uint32_t i = 0; i < node->arity(); ++i)
70       if (!has_type(node->arg(i), dtype))
71         return false;
72
73     return true;
74   }
75
76 private:
77   bool visit(const luci::CircleAdd *node);
78   bool visit(const luci::CircleArgMax *node);
79   bool visit(const luci::CircleAveragePool2D *node);
80   bool visit(const luci::CircleBatchToSpaceND *node);
81   bool visit(const luci::CircleCast *node);
82   bool visit(const luci::CircleConv2D *node);
83   bool visit(const luci::CircleConcatenation *node);
84   bool visit(const luci::CircleDepthToSpace *node);
85   bool visit(const luci::CircleDepthwiseConv2D *node);
86   bool visit(const luci::CircleDiv *node);
87   bool visit(const luci::CircleElu *node);
88   bool visit(const luci::CircleFloor *node);
89   bool visit(const luci::CircleFloorDiv *node);
90   bool visit(const luci::CircleFullyConnected *node);
91   bool visit(const luci::CircleGelu *node);
92   bool visit(const luci::CircleGreater *node);
93   bool visit(const luci::CircleGreaterEqual *node);
94   bool visit(const luci::CircleInstanceNorm *node);
95   bool visit(const luci::CircleLocalResponseNormalization *node);
96   bool visit(const luci::CircleLogicalOr *node);
97   bool visit(const luci::CircleMaxPool2D *node);
98   bool visit(const luci::CircleMean *node);
99   bool visit(const luci::CircleMirrorPad *node);
100   bool visit(const luci::CircleMul *node);
101   bool visit(const luci::CircleNotEqual *node);
102   bool visit(const luci::CircleOneHot *node);
103   bool visit(const luci::CirclePack *node);
104   bool visit(const luci::CirclePad *node);
105   bool visit(const luci::CirclePadV2 *node);
106   bool visit(const luci::CirclePRelu *node);
107   bool visit(const luci::CirclePow *node);
108   bool visit(const luci::CircleReduceMax *node);
109   bool visit(const luci::CircleRelu *node);
110   bool visit(const luci::CircleReshape *node);
111   bool visit(const luci::CircleResizeBilinear *node);
112   bool visit(const luci::CircleResizeNearestNeighbor *node);
113   bool visit(const luci::CircleRsqrt *node);
114   bool visit(const luci::CircleSlice *node);
115   bool visit(const luci::CircleSpaceToBatchND *node);
116   bool visit(const luci::CircleSpaceToDepth *node);
117   bool visit(const luci::CircleSplit *node);
118   bool visit(const luci::CircleSplitOut *node);
119   bool visit(const luci::CircleSplitV *node);
120   bool visit(const luci::CircleSplitVOut *node);
121   bool visit(const luci::CircleSqrt *node);
122   bool visit(const luci::CircleStridedSlice *node);
123   bool visit(const luci::CircleSum *node);
124   bool visit(const luci::CircleTranspose *node);
125   bool visit(const luci::CircleTransposeConv *node);
126   bool visit(const luci::CircleUnpack *node);
127   bool visit(const luci::CircleUnpackOut *node);
128
129   // NOTE below nodes has differnent implementation for Qtype/Btype and
130   //      implementations exist in VerifyQuantizedNodeU8Type, VerifyQuantizedNodeS16Type
131   // bool visit(const luci::CircleLogistic *node);
132   // bool visit(const luci::CircleSoftmax *node);
133   // bool visit(const luci::CircleTanh *node);
134
135   // TODO: Implement more Ops
136
137   bool visit(const luci::CircleNode *) { return true; }
138 };
139
140 class VerifyQuantizedNodeU8Type
141   : public VerifyQuantizedNodeTypeBase<loco::DataType::U8, loco::DataType::S32>
142 {
143 private:
144   bool visit(const luci::CircleLogistic *node);
145   bool visit(const luci::CircleSoftmax *node);
146   bool visit(const luci::CircleTanh *node);
147 };
148
149 class VerifyQuantizedNodeS16Type
150   : public VerifyQuantizedNodeTypeBase<loco::DataType::S16, loco::DataType::S64>
151 {
152 private:
153   bool visit(const luci::CircleLogistic *node);
154   bool visit(const luci::CircleSoftmax *node);
155   bool visit(const luci::CircleTanh *node);
156 };
157
158 } // namespace luci
159
160 #endif // __LUCI_VERIFY_QUANTIZED_NODE_TYPE_H__