c6c991a7609b3bc3f3eb61952a2d3620557379b6
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / QuantizeActivation.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_QUANTIZATION_ACTIVATION_H__
18 #define __LUCI_QUANTIZATION_ACTIVATION_H__
19
20 #include <luci/IR/CircleNodeVisitor.h>
21
22 namespace luci
23 {
24
25 /**
26  * @brief Quantize non-const activation using recorded min/max values
27  */
28 struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<void>
29 {
30   QuantizeActivation(loco::DataType input, loco::DataType output)
31     : input_type(input), output_type(output)
32   {
33   }
34
35   loco::DataType input_type;
36   loco::DataType output_type;
37
38   // Quantize each node using recorded min/max
39   void visit(luci::CircleNode *node);
40 };
41
42 /**
43  * @brief Quantize non-const activaion using pre-defined scale/zp for special Ops
44  */
45 struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
46 {
47   QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
48     : input_type(input), output_type(output)
49   {
50   }
51
52   loco::DataType input_type;
53   loco::DataType output_type;
54
55   void visit(luci::CircleNode *node);
56   void visit(luci::CircleLogistic *node);
57   void visit(luci::CircleTanh *node);
58   void visit(luci::CircleSoftmax *node);
59   void visit(luci::CircleFloor *node);
60   void visit(luci::CircleFloorDiv *node);
61   void visit(luci::CircleFloorMod *node);
62   void visit(luci::CircleCeil *node);
63 };
64
65 // Quantize constant input activation of a node
66 // The input of a node is quantized if it is
67 // 1. Constant (instance of CircleConst*)
68 // 2. Activation (other inputs e.g., weights, bias, axis, etc should not be quantized here)
69 struct QuantizeConstInputActivation final : public luci::CircleNodeMutableVisitor<void>
70 {
71   QuantizeConstInputActivation(loco::DataType output_type) : _output_type(output_type) {}
72
73 private:
74   loco::DataType _output_type;
75
76 // Skip NODE
77 #define SKIP(NODE) \
78   void visit(NODE *) {}
79
80   // Handled in QuantizeWeights and QuantizeBias
81   SKIP(luci::CircleConv2D)
82   SKIP(luci::CircleDepthwiseConv2D)
83   SKIP(luci::CircleFullyConnected)
84   SKIP(luci::CircleInstanceNorm)
85   SKIP(luci::CirclePRelu)
86   SKIP(luci::CircleTransposeConv)
87
88   // Handled in PropagateQParamBackwardPass
89   SKIP(luci::CircleConcatenation)
90   SKIP(luci::CirclePadV2)
91   SKIP(luci::CirclePack)
92   SKIP(luci::CircleOneHot)
93
94   // Inputs of logical Ops are bool, thus not quantized
95   SKIP(luci::CircleLogicalOr)
96   SKIP(luci::CircleLogicalAnd)
97   SKIP(luci::CircleLogicalNot)
98
99 #undef SKIP
100
101   // Default behavior (NYI)
102   void visit(luci::CircleNode *node);
103
104   // Ops that receive a single activation as an input
105   void visit(luci::CircleAbs *node);
106   void visit(luci::CircleArgMax *node);
107   void visit(luci::CircleArgMin *node);
108   void visit(luci::CircleBatchToSpaceND *node);
109   void visit(luci::CircleDepthToSpace *node);
110   void visit(luci::CircleElu *node);
111   void visit(luci::CircleExp *node);
112   void visit(luci::CircleFloor *node);
113   void visit(luci::CircleGather *node);
114   void visit(luci::CircleLocalResponseNormalization *node);
115   void visit(luci::CircleLogistic *node);
116   void visit(luci::CircleMean *node);
117   void visit(luci::CircleMirrorPad *node);
118   void visit(luci::CirclePad *node);
119   void visit(luci::CircleReduceAny *node);
120   void visit(luci::CircleReduceProd *node);
121   void visit(luci::CircleReduceMax *node);
122   void visit(luci::CircleReduceMin *node);
123   void visit(luci::CircleReshape *node);
124   void visit(luci::CircleResizeBilinear *node);
125   void visit(luci::CircleResizeNearestNeighbor *node);
126   void visit(luci::CircleReverseSequence *node);
127   void visit(luci::CircleRsqrt *node);
128   void visit(luci::CircleSlice *node);
129   void visit(luci::CircleSoftmax *node);
130   void visit(luci::CircleSpaceToBatchND *node);
131   void visit(luci::CircleSpaceToDepth *node);
132   void visit(luci::CircleSplit *node);
133   void visit(luci::CircleSplitV *node);
134   void visit(luci::CircleSqrt *node);
135   void visit(luci::CircleStridedSlice *node);
136   void visit(luci::CircleSum *node);
137   void visit(luci::CircleTanh *node);
138   void visit(luci::CircleTile *node);
139   void visit(luci::CircleTopKV2 *node);
140   void visit(luci::CircleTranspose *node);
141   void visit(luci::CircleUnpack *node);
142
143   // Ops that receive two activations as inputs
144   void visit(luci::CircleAdd *node);
145   void visit(luci::CircleBatchMatMul *node);
146   void visit(luci::CircleDiv *node);
147   void visit(luci::CircleEqual *node);
148   void visit(luci::CircleFloorDiv *node);
149   void visit(luci::CircleGreater *node);
150   void visit(luci::CircleGreaterEqual *node);
151   void visit(luci::CircleLess *node);
152   void visit(luci::CircleLessEqual *node);
153   void visit(luci::CircleMaximum *node);
154   void visit(luci::CircleMinimum *node);
155   void visit(luci::CircleMul *node);
156   void visit(luci::CircleNotEqual *node);
157   void visit(luci::CirclePow *node);
158   void visit(luci::CircleSub *node);
159
160   // AddN has arbitrary number of inputs
161   void visit(luci::CircleAddN *node);
162 };
163
164 } // namespace luci
165
166 #endif // __LUCI_QUANTIZATION_ACTIVATION_H__