Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleCloneNode.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __CIRCLE_CLONE_NODE_H__
18 #define __CIRCLE_CLONE_NODE_H__
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include <luci/IR/CircleNodeVisitor.h>
23
24 namespace luci
25 {
26
27 // CloneNode-let type
28 enum class CN
29 {
30   ABC,
31   DEF,
32   GHIJ,
33   KLMN,
34   OPQR,
35   STUV,
36   WXYZ,
37 };
38
39 template <CN ct> class CloneNodeLet;
40
41 template <> class CloneNodeLet<CN::ABC> final : public luci::CircleNodeVisitor<luci::CircleNode *>
42 {
43 public:
44   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
45
46 public:
47   luci::CircleNode *visit(const luci::CircleAbs *) final;
48   luci::CircleNode *visit(const luci::CircleAdd *) final;
49   luci::CircleNode *visit(const luci::CircleAddN *) final;
50   luci::CircleNode *visit(const luci::CircleArgMax *) final;
51   luci::CircleNode *visit(const luci::CircleArgMin *) final;
52   luci::CircleNode *visit(const luci::CircleAveragePool2D *) final;
53   luci::CircleNode *visit(const luci::CircleBatchMatMul *) final;
54   luci::CircleNode *visit(const luci::CircleBatchToSpaceND *) final;
55   luci::CircleNode *visit(const luci::CircleCast *) final;
56   luci::CircleNode *visit(const luci::CircleCeil *) final;
57   luci::CircleNode *visit(const luci::CircleConcatenation *) final;
58   luci::CircleNode *visit(const luci::CircleConst *) final;
59   luci::CircleNode *visit(const luci::CircleConv2D *) final;
60   luci::CircleNode *visit(const luci::CircleCos *) final;
61   luci::CircleNode *visit(const luci::CircleCustom *) final;
62
63   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
64
65 protected:
66   loco::Graph *_graph = nullptr;
67 };
68
69 template <> class CloneNodeLet<CN::DEF> final : public luci::CircleNodeVisitor<luci::CircleNode *>
70 {
71 public:
72   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
73
74 public:
75   luci::CircleNode *visit(const luci::CircleDensify *) final;
76   luci::CircleNode *visit(const luci::CircleDepthToSpace *) final;
77   luci::CircleNode *visit(const luci::CircleDepthwiseConv2D *) final;
78   luci::CircleNode *visit(const luci::CircleDequantize *) final;
79   luci::CircleNode *visit(const luci::CircleDiv *) final;
80   luci::CircleNode *visit(const luci::CircleElu *) final;
81   luci::CircleNode *visit(const luci::CircleEqual *) final;
82   luci::CircleNode *visit(const luci::CircleExp *) final;
83   luci::CircleNode *visit(const luci::CircleExpandDims *) final;
84   luci::CircleNode *visit(const luci::CircleFakeQuant *) final;
85   luci::CircleNode *visit(const luci::CircleFill *) final;
86   luci::CircleNode *visit(const luci::CircleFloor *) final;
87   luci::CircleNode *visit(const luci::CircleFloorDiv *) final;
88   luci::CircleNode *visit(const luci::CircleFloorMod *) final;
89   luci::CircleNode *visit(const luci::CircleFullyConnected *) final;
90
91   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
92
93 protected:
94   loco::Graph *_graph = nullptr;
95 };
96
97 template <> class CloneNodeLet<CN::GHIJ> final : public luci::CircleNodeVisitor<luci::CircleNode *>
98 {
99 public:
100   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
101
102 public:
103   luci::CircleNode *visit(const luci::CircleGather *) final;
104   luci::CircleNode *visit(const luci::CircleGatherNd *) final;
105   luci::CircleNode *visit(const luci::CircleGelu *) final;
106   luci::CircleNode *visit(const luci::CircleGreater *) final;
107   luci::CircleNode *visit(const luci::CircleGreaterEqual *) final;
108   luci::CircleNode *visit(const luci::CircleHardSwish *) final;
109   luci::CircleNode *visit(const luci::CircleIf *) final;
110
111   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
112
113 protected:
114   loco::Graph *_graph = nullptr;
115 };
116
117 template <> class CloneNodeLet<CN::KLMN> final : public luci::CircleNodeVisitor<luci::CircleNode *>
118 {
119 public:
120   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
121
122 public:
123   luci::CircleNode *visit(const luci::CircleL2Normalize *) final;
124   luci::CircleNode *visit(const luci::CircleL2Pool2D *) final;
125   luci::CircleNode *visit(const luci::CircleLeakyRelu *) final;
126   luci::CircleNode *visit(const luci::CircleLess *) final;
127   luci::CircleNode *visit(const luci::CircleLessEqual *) final;
128   luci::CircleNode *visit(const luci::CircleLocalResponseNormalization *) final;
129   luci::CircleNode *visit(const luci::CircleLog *) final;
130   luci::CircleNode *visit(const luci::CircleLogicalAnd *) final;
131   luci::CircleNode *visit(const luci::CircleLogicalNot *) final;
132   luci::CircleNode *visit(const luci::CircleLogicalOr *) final;
133   luci::CircleNode *visit(const luci::CircleLogistic *) final;
134   luci::CircleNode *visit(const luci::CircleLogSoftmax *) final;
135   luci::CircleNode *visit(const luci::CircleMatrixDiag *) final;
136   luci::CircleNode *visit(const luci::CircleMatrixSetDiag *) final;
137   luci::CircleNode *visit(const luci::CircleMaximum *) final;
138   luci::CircleNode *visit(const luci::CircleMaxPool2D *) final;
139   luci::CircleNode *visit(const luci::CircleMean *) final;
140   luci::CircleNode *visit(const luci::CircleMinimum *) final;
141   luci::CircleNode *visit(const luci::CircleMirrorPad *) final;
142   luci::CircleNode *visit(const luci::CircleMul *) final;
143   luci::CircleNode *visit(const luci::CircleNeg *) final;
144   luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4 *) final;
145   luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5 *) final;
146   luci::CircleNode *visit(const luci::CircleNotEqual *) final;
147
148   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
149
150 protected:
151   loco::Graph *_graph = nullptr;
152 };
153
154 template <> class CloneNodeLet<CN::OPQR> final : public luci::CircleNodeVisitor<luci::CircleNode *>
155 {
156 public:
157   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
158
159 public:
160   luci::CircleNode *visit(const luci::CircleOneHot *) final;
161   luci::CircleNode *visit(const luci::CirclePack *) final;
162   luci::CircleNode *visit(const luci::CirclePad *) final;
163   luci::CircleNode *visit(const luci::CirclePadV2 *) final;
164   luci::CircleNode *visit(const luci::CirclePow *) final;
165   luci::CircleNode *visit(const luci::CirclePRelu *) final;
166   luci::CircleNode *visit(const luci::CircleQuantize *) final;
167   luci::CircleNode *visit(const luci::CircleRange *) final;
168   luci::CircleNode *visit(const luci::CircleRank *) final;
169   luci::CircleNode *visit(const luci::CircleReduceAny *) final;
170   luci::CircleNode *visit(const luci::CircleReduceMax *) final;
171   luci::CircleNode *visit(const luci::CircleReduceMin *) final;
172   luci::CircleNode *visit(const luci::CircleReduceProd *) final;
173   luci::CircleNode *visit(const luci::CircleRelu *) final;
174   luci::CircleNode *visit(const luci::CircleRelu6 *) final;
175   luci::CircleNode *visit(const luci::CircleReluN1To1 *) final;
176   luci::CircleNode *visit(const luci::CircleReshape *) final;
177   luci::CircleNode *visit(const luci::CircleResizeBilinear *) final;
178   luci::CircleNode *visit(const luci::CircleResizeNearestNeighbor *) final;
179   luci::CircleNode *visit(const luci::CircleReverseSequence *) final;
180   luci::CircleNode *visit(const luci::CircleReverseV2 *) final;
181   luci::CircleNode *visit(const luci::CircleRound *) final;
182   luci::CircleNode *visit(const luci::CircleRsqrt *) final;
183
184   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
185
186 protected:
187   loco::Graph *_graph = nullptr;
188 };
189
190 template <> class CloneNodeLet<CN::STUV> final : public luci::CircleNodeVisitor<luci::CircleNode *>
191 {
192 public:
193   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
194
195 public:
196   luci::CircleNode *visit(const luci::CircleScatterNd *) final;
197   luci::CircleNode *visit(const luci::CircleSegmentSum *) final;
198   luci::CircleNode *visit(const luci::CircleSelect *) final;
199   luci::CircleNode *visit(const luci::CircleSelectV2 *) final;
200   luci::CircleNode *visit(const luci::CircleShape *) final;
201   luci::CircleNode *visit(const luci::CircleSin *) final;
202   luci::CircleNode *visit(const luci::CircleSlice *) final;
203   luci::CircleNode *visit(const luci::CircleSoftmax *) final;
204   luci::CircleNode *visit(const luci::CircleSpaceToBatchND *) final;
205   luci::CircleNode *visit(const luci::CircleSpaceToDepth *) final;
206   luci::CircleNode *visit(const luci::CircleSparseToDense *) final;
207   luci::CircleNode *visit(const luci::CircleSplit *) final;
208   luci::CircleNode *visit(const luci::CircleSplitV *) final;
209   luci::CircleNode *visit(const luci::CircleSqrt *) final;
210   luci::CircleNode *visit(const luci::CircleSquare *) final;
211   luci::CircleNode *visit(const luci::CircleSquaredDifference *) final;
212   luci::CircleNode *visit(const luci::CircleSqueeze *) final;
213   luci::CircleNode *visit(const luci::CircleStridedSlice *) final;
214   luci::CircleNode *visit(const luci::CircleSVDF *) final;
215   luci::CircleNode *visit(const luci::CircleSub *) final;
216   luci::CircleNode *visit(const luci::CircleSum *) final;
217   luci::CircleNode *visit(const luci::CircleTanh *) final;
218   luci::CircleNode *visit(const luci::CircleTile *) final;
219   luci::CircleNode *visit(const luci::CircleTopKV2 *) final;
220   luci::CircleNode *visit(const luci::CircleTranspose *) final;
221   luci::CircleNode *visit(const luci::CircleTransposeConv *) final;
222   luci::CircleNode *visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
223   luci::CircleNode *visit(const luci::CircleUnique *) final;
224   luci::CircleNode *visit(const luci::CircleUnpack *) final;
225
226   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
227
228 protected:
229   loco::Graph *_graph = nullptr;
230 };
231
232 template <> class CloneNodeLet<CN::WXYZ> final : public luci::CircleNodeVisitor<luci::CircleNode *>
233 {
234 public:
235   CloneNodeLet(loco::Graph *graph) : _graph(graph){};
236
237 public:
238   luci::CircleNode *visit(const luci::CircleWhere *) final;
239   luci::CircleNode *visit(const luci::CircleWhile *) final;
240   luci::CircleNode *visit(const luci::CircleZerosLike *) final;
241
242   luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
243
244 protected:
245   loco::Graph *_graph = nullptr;
246 };
247
248 class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
249 {
250 public:
251   CloneNode(loco::Graph *graph) : _graph(graph){};
252
253 public:
254   // Circle Only
255   luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final;
256   luci::CircleNode *visit(const luci::CircleBCQGather *) final;
257   luci::CircleNode *visit(const luci::CircleInstanceNorm *) final;
258
259   // NOTE CircleInput and CircleOutput are not handled here as these need
260   //      link with graph I/O
261
262   // Virtual
263   luci::CircleNode *visit(const luci::CircleCustomOut *) final;
264   luci::CircleNode *visit(const luci::CircleIfOut *) final;
265   // luci::CircleNode *visit(const luci::CircleInput *) final;
266   luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4Out *) final;
267   luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5Out *) final;
268   // luci::CircleNode *visit(const luci::CircleOutput *) final;
269   luci::CircleNode *visit(const luci::CircleOutputDummy *) final;
270   luci::CircleNode *visit(const luci::CircleOutputExclude *) final;
271   luci::CircleNode *visit(const luci::CircleSplitOut *) final;
272   luci::CircleNode *visit(const luci::CircleSplitVOut *) final;
273   luci::CircleNode *visit(const luci::CircleTopKV2Out *) final;
274   luci::CircleNode *visit(const luci::CircleUniqueOut *) final;
275   luci::CircleNode *visit(const luci::CircleUnpackOut *) final;
276   luci::CircleNode *visit(const luci::CircleVariable *) final;
277   luci::CircleNode *visit(const luci::CircleWhileOut *) final;
278
279   // Handle in CircleNode
280   luci::CircleNode *visit(const luci::CircleNode *) final;
281
282   // NOTE CircleNodeVisitor will throw if not supported here
283
284 protected:
285   loco::Graph *_graph = nullptr;
286 };
287
288 } // namespace luci
289
290 #endif // __CIRCLE_CLONE_NODE_H__