2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __CIRCLE_CLONE_NODE_H__
18 #define __CIRCLE_CLONE_NODE_H__
20 #include <luci/IR/CircleNodes.h>
22 #include <luci/IR/CircleNodeVisitor.h>
39 template <CN ct> class CloneNodeLet;
41 template <> class CloneNodeLet<CN::ABC> final : public luci::CircleNodeVisitor<luci::CircleNode *>
44 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
47 luci::CircleNode *visit(const luci::CircleAbs *) final;
48 luci::CircleNode *visit(const luci::CircleAdd *) final;
49 luci::CircleNode *visit(const luci::CircleAddN *) final;
50 luci::CircleNode *visit(const luci::CircleArgMax *) final;
51 luci::CircleNode *visit(const luci::CircleArgMin *) final;
52 luci::CircleNode *visit(const luci::CircleAveragePool2D *) final;
53 luci::CircleNode *visit(const luci::CircleBatchMatMul *) final;
54 luci::CircleNode *visit(const luci::CircleBatchToSpaceND *) final;
55 luci::CircleNode *visit(const luci::CircleCast *) final;
56 luci::CircleNode *visit(const luci::CircleCeil *) final;
57 luci::CircleNode *visit(const luci::CircleConcatenation *) final;
58 luci::CircleNode *visit(const luci::CircleConst *) final;
59 luci::CircleNode *visit(const luci::CircleConv2D *) final;
60 luci::CircleNode *visit(const luci::CircleCos *) final;
61 luci::CircleNode *visit(const luci::CircleCustom *) final;
63 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
66 loco::Graph *_graph = nullptr;
69 template <> class CloneNodeLet<CN::DEF> final : public luci::CircleNodeVisitor<luci::CircleNode *>
72 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
75 luci::CircleNode *visit(const luci::CircleDensify *) final;
76 luci::CircleNode *visit(const luci::CircleDepthToSpace *) final;
77 luci::CircleNode *visit(const luci::CircleDepthwiseConv2D *) final;
78 luci::CircleNode *visit(const luci::CircleDequantize *) final;
79 luci::CircleNode *visit(const luci::CircleDiv *) final;
80 luci::CircleNode *visit(const luci::CircleElu *) final;
81 luci::CircleNode *visit(const luci::CircleEqual *) final;
82 luci::CircleNode *visit(const luci::CircleExp *) final;
83 luci::CircleNode *visit(const luci::CircleExpandDims *) final;
84 luci::CircleNode *visit(const luci::CircleFakeQuant *) final;
85 luci::CircleNode *visit(const luci::CircleFill *) final;
86 luci::CircleNode *visit(const luci::CircleFloor *) final;
87 luci::CircleNode *visit(const luci::CircleFloorDiv *) final;
88 luci::CircleNode *visit(const luci::CircleFloorMod *) final;
89 luci::CircleNode *visit(const luci::CircleFullyConnected *) final;
91 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
94 loco::Graph *_graph = nullptr;
97 template <> class CloneNodeLet<CN::GHIJ> final : public luci::CircleNodeVisitor<luci::CircleNode *>
100 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
103 luci::CircleNode *visit(const luci::CircleGather *) final;
104 luci::CircleNode *visit(const luci::CircleGatherNd *) final;
105 luci::CircleNode *visit(const luci::CircleGelu *) final;
106 luci::CircleNode *visit(const luci::CircleGreater *) final;
107 luci::CircleNode *visit(const luci::CircleGreaterEqual *) final;
108 luci::CircleNode *visit(const luci::CircleHardSwish *) final;
109 luci::CircleNode *visit(const luci::CircleIf *) final;
111 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
114 loco::Graph *_graph = nullptr;
117 template <> class CloneNodeLet<CN::KLMN> final : public luci::CircleNodeVisitor<luci::CircleNode *>
120 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
123 luci::CircleNode *visit(const luci::CircleL2Normalize *) final;
124 luci::CircleNode *visit(const luci::CircleL2Pool2D *) final;
125 luci::CircleNode *visit(const luci::CircleLeakyRelu *) final;
126 luci::CircleNode *visit(const luci::CircleLess *) final;
127 luci::CircleNode *visit(const luci::CircleLessEqual *) final;
128 luci::CircleNode *visit(const luci::CircleLocalResponseNormalization *) final;
129 luci::CircleNode *visit(const luci::CircleLog *) final;
130 luci::CircleNode *visit(const luci::CircleLogicalAnd *) final;
131 luci::CircleNode *visit(const luci::CircleLogicalNot *) final;
132 luci::CircleNode *visit(const luci::CircleLogicalOr *) final;
133 luci::CircleNode *visit(const luci::CircleLogistic *) final;
134 luci::CircleNode *visit(const luci::CircleLogSoftmax *) final;
135 luci::CircleNode *visit(const luci::CircleMatrixDiag *) final;
136 luci::CircleNode *visit(const luci::CircleMatrixSetDiag *) final;
137 luci::CircleNode *visit(const luci::CircleMaximum *) final;
138 luci::CircleNode *visit(const luci::CircleMaxPool2D *) final;
139 luci::CircleNode *visit(const luci::CircleMean *) final;
140 luci::CircleNode *visit(const luci::CircleMinimum *) final;
141 luci::CircleNode *visit(const luci::CircleMirrorPad *) final;
142 luci::CircleNode *visit(const luci::CircleMul *) final;
143 luci::CircleNode *visit(const luci::CircleNeg *) final;
144 luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4 *) final;
145 luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5 *) final;
146 luci::CircleNode *visit(const luci::CircleNotEqual *) final;
148 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
151 loco::Graph *_graph = nullptr;
154 template <> class CloneNodeLet<CN::OPQR> final : public luci::CircleNodeVisitor<luci::CircleNode *>
157 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
160 luci::CircleNode *visit(const luci::CircleOneHot *) final;
161 luci::CircleNode *visit(const luci::CirclePack *) final;
162 luci::CircleNode *visit(const luci::CirclePad *) final;
163 luci::CircleNode *visit(const luci::CirclePadV2 *) final;
164 luci::CircleNode *visit(const luci::CirclePow *) final;
165 luci::CircleNode *visit(const luci::CirclePRelu *) final;
166 luci::CircleNode *visit(const luci::CircleQuantize *) final;
167 luci::CircleNode *visit(const luci::CircleRange *) final;
168 luci::CircleNode *visit(const luci::CircleRank *) final;
169 luci::CircleNode *visit(const luci::CircleReduceAny *) final;
170 luci::CircleNode *visit(const luci::CircleReduceMax *) final;
171 luci::CircleNode *visit(const luci::CircleReduceMin *) final;
172 luci::CircleNode *visit(const luci::CircleReduceProd *) final;
173 luci::CircleNode *visit(const luci::CircleRelu *) final;
174 luci::CircleNode *visit(const luci::CircleRelu6 *) final;
175 luci::CircleNode *visit(const luci::CircleReluN1To1 *) final;
176 luci::CircleNode *visit(const luci::CircleReshape *) final;
177 luci::CircleNode *visit(const luci::CircleResizeBilinear *) final;
178 luci::CircleNode *visit(const luci::CircleResizeNearestNeighbor *) final;
179 luci::CircleNode *visit(const luci::CircleReverseSequence *) final;
180 luci::CircleNode *visit(const luci::CircleReverseV2 *) final;
181 luci::CircleNode *visit(const luci::CircleRound *) final;
182 luci::CircleNode *visit(const luci::CircleRsqrt *) final;
184 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
187 loco::Graph *_graph = nullptr;
190 template <> class CloneNodeLet<CN::STUV> final : public luci::CircleNodeVisitor<luci::CircleNode *>
193 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
196 luci::CircleNode *visit(const luci::CircleScatterNd *) final;
197 luci::CircleNode *visit(const luci::CircleSegmentSum *) final;
198 luci::CircleNode *visit(const luci::CircleSelect *) final;
199 luci::CircleNode *visit(const luci::CircleSelectV2 *) final;
200 luci::CircleNode *visit(const luci::CircleShape *) final;
201 luci::CircleNode *visit(const luci::CircleSin *) final;
202 luci::CircleNode *visit(const luci::CircleSlice *) final;
203 luci::CircleNode *visit(const luci::CircleSoftmax *) final;
204 luci::CircleNode *visit(const luci::CircleSpaceToBatchND *) final;
205 luci::CircleNode *visit(const luci::CircleSpaceToDepth *) final;
206 luci::CircleNode *visit(const luci::CircleSparseToDense *) final;
207 luci::CircleNode *visit(const luci::CircleSplit *) final;
208 luci::CircleNode *visit(const luci::CircleSplitV *) final;
209 luci::CircleNode *visit(const luci::CircleSqrt *) final;
210 luci::CircleNode *visit(const luci::CircleSquare *) final;
211 luci::CircleNode *visit(const luci::CircleSquaredDifference *) final;
212 luci::CircleNode *visit(const luci::CircleSqueeze *) final;
213 luci::CircleNode *visit(const luci::CircleStridedSlice *) final;
214 luci::CircleNode *visit(const luci::CircleSVDF *) final;
215 luci::CircleNode *visit(const luci::CircleSub *) final;
216 luci::CircleNode *visit(const luci::CircleSum *) final;
217 luci::CircleNode *visit(const luci::CircleTanh *) final;
218 luci::CircleNode *visit(const luci::CircleTile *) final;
219 luci::CircleNode *visit(const luci::CircleTopKV2 *) final;
220 luci::CircleNode *visit(const luci::CircleTranspose *) final;
221 luci::CircleNode *visit(const luci::CircleTransposeConv *) final;
222 luci::CircleNode *visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
223 luci::CircleNode *visit(const luci::CircleUnique *) final;
224 luci::CircleNode *visit(const luci::CircleUnpack *) final;
226 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
229 loco::Graph *_graph = nullptr;
232 template <> class CloneNodeLet<CN::WXYZ> final : public luci::CircleNodeVisitor<luci::CircleNode *>
235 CloneNodeLet(loco::Graph *graph) : _graph(graph){};
238 luci::CircleNode *visit(const luci::CircleWhere *) final;
239 luci::CircleNode *visit(const luci::CircleWhile *) final;
240 luci::CircleNode *visit(const luci::CircleZerosLike *) final;
242 luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
245 loco::Graph *_graph = nullptr;
248 class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
251 CloneNode(loco::Graph *graph) : _graph(graph){};
255 luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final;
256 luci::CircleNode *visit(const luci::CircleBCQGather *) final;
257 luci::CircleNode *visit(const luci::CircleInstanceNorm *) final;
259 // NOTE CircleInput and CircleOutput are not handled here as these need
260 // link with graph I/O
263 luci::CircleNode *visit(const luci::CircleCustomOut *) final;
264 luci::CircleNode *visit(const luci::CircleIfOut *) final;
265 // luci::CircleNode *visit(const luci::CircleInput *) final;
266 luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4Out *) final;
267 luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5Out *) final;
268 // luci::CircleNode *visit(const luci::CircleOutput *) final;
269 luci::CircleNode *visit(const luci::CircleOutputDummy *) final;
270 luci::CircleNode *visit(const luci::CircleOutputExclude *) final;
271 luci::CircleNode *visit(const luci::CircleSplitOut *) final;
272 luci::CircleNode *visit(const luci::CircleSplitVOut *) final;
273 luci::CircleNode *visit(const luci::CircleTopKV2Out *) final;
274 luci::CircleNode *visit(const luci::CircleUniqueOut *) final;
275 luci::CircleNode *visit(const luci::CircleUnpackOut *) final;
276 luci::CircleNode *visit(const luci::CircleVariable *) final;
277 luci::CircleNode *visit(const luci::CircleWhileOut *) final;
279 // Handle in CircleNode
280 luci::CircleNode *visit(const luci::CircleNode *) final;
282 // NOTE CircleNodeVisitor will throw if not supported here
285 loco::Graph *_graph = nullptr;
290 #endif // __CIRCLE_CLONE_NODE_H__