Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / partition / include / luci / ConnectNode.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_PARTITION_CONNECT_NODE_H__
18 #define __LUCI_PARTITION_CONNECT_NODE_H__
19
20 #include <luci/IR/CircleNode.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22
23 namespace luci
24 {
25
26 /**
27  * @note MapNode2Clone is used as a map from original node to cloned node
28  *       to find input of a cloned node
29  *
30  *   (Original)              (Clone)
31  *
32  *     [A]                  [A']
33  *      |   [B]              |   [B']
34  *      |    |               |    |
35  *       \  /                 \  /
36  *        [C]                 [C']
37  *
38  *  From view of [C'] we need to find [A'] and [B']. We know [C] from [C'],
39  *  then we can get from input of [C] as [A], [B] then [A]->[A'] and [B]->[B']
40  *  from the map.
41  */
42 using MapNode2Clone = std::map<const CircleNode * /* ORG */, CircleNode * /* CLONE */>;
43
44 struct CloneContext
45 {
46   std::pair<MapNode2Clone::iterator, bool> emplace(const CircleNode *org, CircleNode *clone)
47   {
48     return node2clone.emplace(org, clone);
49   }
50   MapNode2Clone::iterator find(const CircleNode *org) { return node2clone.find(org); }
51   MapNode2Clone::iterator end(void) { return node2clone.end(); }
52
53   MapNode2Clone::const_iterator find(const CircleNode *org) const { return node2clone.find(org); }
54   MapNode2Clone::const_iterator end(void) const { return node2clone.end(); }
55
56   MapNode2Clone node2clone;
57 };
58
59 class ConnectNode final : public luci::CircleNodeVisitor<void>
60 {
61 public:
62   ConnectNode(luci::CloneContext &clonecontext) : _clonecontext(clonecontext){};
63
64 public:
65   void visit(const luci::CircleAbs *) final;
66   void visit(const luci::CircleAdd *) final;
67   void visit(const luci::CircleAddN *) final;
68   void visit(const luci::CircleArgMax *) final;
69   void visit(const luci::CircleArgMin *) final;
70   void visit(const luci::CircleAveragePool2D *) final;
71   void visit(const luci::CircleBatchMatMul *) final;
72   void visit(const luci::CircleBatchToSpaceND *) final;
73   void visit(const luci::CircleCast *) final;
74   void visit(const luci::CircleCeil *) final;
75   void visit(const luci::CircleConcatenation *) final;
76   void visit(const luci::CircleConst *) final;
77   void visit(const luci::CircleConv2D *) final;
78   void visit(const luci::CircleCos *) final;
79   void visit(const luci::CircleCustom *) final;
80   void visit(const luci::CircleDensify *) final;
81   void visit(const luci::CircleDepthToSpace *) final;
82   void visit(const luci::CircleDepthwiseConv2D *) final;
83   void visit(const luci::CircleDequantize *) final;
84   void visit(const luci::CircleDiv *) final;
85   void visit(const luci::CircleElu *) final;
86   void visit(const luci::CircleEqual *) final;
87   void visit(const luci::CircleExp *) final;
88   void visit(const luci::CircleExpandDims *) final;
89   void visit(const luci::CircleFakeQuant *) final;
90   void visit(const luci::CircleFill *) final;
91   void visit(const luci::CircleFloor *) final;
92   void visit(const luci::CircleFloorDiv *) final;
93   void visit(const luci::CircleFloorMod *) final;
94   void visit(const luci::CircleFullyConnected *) final;
95   void visit(const luci::CircleGather *) final;
96   void visit(const luci::CircleGatherNd *) final;
97   void visit(const luci::CircleGelu *) final;
98   void visit(const luci::CircleGreater *) final;
99   void visit(const luci::CircleGreaterEqual *) final;
100   void visit(const luci::CircleHardSwish *) final;
101   void visit(const luci::CircleIf *) final;
102   void visit(const luci::CircleL2Normalize *) final;
103   void visit(const luci::CircleL2Pool2D *) final;
104   void visit(const luci::CircleLeakyRelu *) final;
105   void visit(const luci::CircleLess *) final;
106   void visit(const luci::CircleLessEqual *) final;
107   void visit(const luci::CircleLocalResponseNormalization *) final;
108   void visit(const luci::CircleLog *) final;
109   void visit(const luci::CircleLogicalAnd *) final;
110   void visit(const luci::CircleLogicalNot *) final;
111   void visit(const luci::CircleLogicalOr *) final;
112   void visit(const luci::CircleLogistic *) final;
113   void visit(const luci::CircleLogSoftmax *) final;
114   void visit(const luci::CircleMatrixDiag *) final;
115   void visit(const luci::CircleMatrixSetDiag *) final;
116   void visit(const luci::CircleMaximum *) final;
117   void visit(const luci::CircleMaxPool2D *) final;
118   void visit(const luci::CircleMean *) final;
119   void visit(const luci::CircleMinimum *) final;
120   void visit(const luci::CircleMirrorPad *) final;
121   void visit(const luci::CircleMul *) final;
122   void visit(const luci::CircleNeg *) final;
123   void visit(const luci::CircleNonMaxSuppressionV4 *) final;
124   void visit(const luci::CircleNonMaxSuppressionV5 *) final;
125   void visit(const luci::CircleNotEqual *) final;
126   void visit(const luci::CircleOneHot *) final;
127   void visit(const luci::CirclePack *) final;
128   void visit(const luci::CirclePad *) final;
129   void visit(const luci::CirclePadV2 *) final;
130   void visit(const luci::CirclePow *) final;
131   void visit(const luci::CirclePRelu *) final;
132   void visit(const luci::CircleQuantize *) final;
133   void visit(const luci::CircleRange *) final;
134   void visit(const luci::CircleRank *) final;
135   void visit(const luci::CircleReduceAny *) final;
136   void visit(const luci::CircleReduceMax *) final;
137   void visit(const luci::CircleReduceMin *) final;
138   void visit(const luci::CircleReduceProd *) final;
139   void visit(const luci::CircleRelu *) final;
140   void visit(const luci::CircleRelu6 *) final;
141   void visit(const luci::CircleReluN1To1 *) final;
142   void visit(const luci::CircleReshape *) final;
143   void visit(const luci::CircleResizeBilinear *) final;
144   void visit(const luci::CircleResizeNearestNeighbor *) final;
145   void visit(const luci::CircleReverseSequence *) final;
146   void visit(const luci::CircleReverseV2 *) final;
147   void visit(const luci::CircleRound *) final;
148   void visit(const luci::CircleRsqrt *) final;
149   void visit(const luci::CircleScatterNd *) final;
150   void visit(const luci::CircleSegmentSum *) final;
151   void visit(const luci::CircleSelect *) final;
152   void visit(const luci::CircleSelectV2 *) final;
153   void visit(const luci::CircleShape *) final;
154   void visit(const luci::CircleSin *) final;
155   void visit(const luci::CircleSlice *) final;
156   void visit(const luci::CircleSoftmax *) final;
157   void visit(const luci::CircleSpaceToBatchND *) final;
158   void visit(const luci::CircleSpaceToDepth *) final;
159   void visit(const luci::CircleSparseToDense *) final;
160   void visit(const luci::CircleSplit *) final;
161   void visit(const luci::CircleSplitV *) final;
162   void visit(const luci::CircleSqrt *) final;
163   void visit(const luci::CircleSquare *) final;
164   void visit(const luci::CircleSquaredDifference *) final;
165   void visit(const luci::CircleSqueeze *) final;
166   void visit(const luci::CircleStridedSlice *) final;
167   void visit(const luci::CircleSVDF *) final;
168   void visit(const luci::CircleSub *) final;
169   void visit(const luci::CircleSum *) final;
170   void visit(const luci::CircleTanh *) final;
171   void visit(const luci::CircleTile *) final;
172   void visit(const luci::CircleTopKV2 *) final;
173   void visit(const luci::CircleTranspose *) final;
174   void visit(const luci::CircleTransposeConv *) final;
175   void visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
176   void visit(const luci::CircleUnique *) final;
177   void visit(const luci::CircleUnpack *) final;
178   void visit(const luci::CircleWhere *) final;
179   void visit(const luci::CircleWhile *) final;
180   void visit(const luci::CircleZerosLike *) final;
181
182   // Circle Only
183   void visit(const luci::CircleBCQFullyConnected *) final;
184   void visit(const luci::CircleBCQGather *) final;
185   void visit(const luci::CircleInstanceNorm *) final;
186
187   // NOTE CircleInput and CircleOutput are not handled here as these need
188   //      link with graph I/O
189
190   // Virtual
191   void visit(const luci::CircleCustomOut *) final;
192   void visit(const luci::CircleIfOut *) final;
193   // void visit(const luci::CircleInput *) final;
194   void visit(const luci::CircleNonMaxSuppressionV4Out *) final;
195   void visit(const luci::CircleNonMaxSuppressionV5Out *) final;
196   // void visit(const luci::CircleOutput *) final;
197   void visit(const luci::CircleOutputDummy *) final;
198   void visit(const luci::CircleOutputExclude *) final;
199   void visit(const luci::CircleSplitOut *) final;
200   void visit(const luci::CircleSplitVOut *) final;
201   void visit(const luci::CircleTopKV2Out *) final;
202   void visit(const luci::CircleUniqueOut *) final;
203   void visit(const luci::CircleUnpackOut *) final;
204   void visit(const luci::CircleVariable *) final;
205   void visit(const luci::CircleWhileOut *) final;
206
207 public:
208   luci::CircleNode *find_clone(const luci::CircleNode *node);
209
210 protected:
211   luci::CloneContext &_clonecontext;
212 };
213
214 /**
215  * @brief Connect cloned node from input node
216  */
217 void clone_connect(const luci::CircleNode *node, luci::CloneContext &clonecontext);
218
219 } // namespace luci
220
221 #endif // __LUCI_PARTITION_CONNECT_NODE_H__