2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __LUCI_PARTITION_CONNECT_NODE_H__
18 #define __LUCI_PARTITION_CONNECT_NODE_H__
20 #include <luci/IR/CircleNode.h>
21 #include <luci/IR/CircleNodeVisitor.h>
27 * @note MapNode2Clone is used as a map from original node to cloned node
28 * to find input of a cloned node
38 * From view of [C'] we need to find [A'] and [B']. We know [C] from [C'],
39 * then we can get from input of [C] as [A], [B] then [A]->[A'] and [B]->[B']
42 using MapNode2Clone = std::map<const CircleNode * /* ORG */, CircleNode * /* CLONE */>;
46 std::pair<MapNode2Clone::iterator, bool> emplace(const CircleNode *org, CircleNode *clone)
48 return node2clone.emplace(org, clone);
50 MapNode2Clone::iterator find(const CircleNode *org) { return node2clone.find(org); }
51 MapNode2Clone::iterator end(void) { return node2clone.end(); }
53 MapNode2Clone::const_iterator find(const CircleNode *org) const { return node2clone.find(org); }
54 MapNode2Clone::const_iterator end(void) const { return node2clone.end(); }
56 MapNode2Clone node2clone;
59 class ConnectNode final : public luci::CircleNodeVisitor<void>
62 ConnectNode(luci::CloneContext &clonecontext) : _clonecontext(clonecontext){};
65 void visit(const luci::CircleAbs *) final;
66 void visit(const luci::CircleAdd *) final;
67 void visit(const luci::CircleAddN *) final;
68 void visit(const luci::CircleArgMax *) final;
69 void visit(const luci::CircleArgMin *) final;
70 void visit(const luci::CircleAveragePool2D *) final;
71 void visit(const luci::CircleBatchMatMul *) final;
72 void visit(const luci::CircleBatchToSpaceND *) final;
73 void visit(const luci::CircleCast *) final;
74 void visit(const luci::CircleCeil *) final;
75 void visit(const luci::CircleConcatenation *) final;
76 void visit(const luci::CircleConst *) final;
77 void visit(const luci::CircleConv2D *) final;
78 void visit(const luci::CircleCos *) final;
79 void visit(const luci::CircleCustom *) final;
80 void visit(const luci::CircleDensify *) final;
81 void visit(const luci::CircleDepthToSpace *) final;
82 void visit(const luci::CircleDepthwiseConv2D *) final;
83 void visit(const luci::CircleDequantize *) final;
84 void visit(const luci::CircleDiv *) final;
85 void visit(const luci::CircleElu *) final;
86 void visit(const luci::CircleEqual *) final;
87 void visit(const luci::CircleExp *) final;
88 void visit(const luci::CircleExpandDims *) final;
89 void visit(const luci::CircleFakeQuant *) final;
90 void visit(const luci::CircleFill *) final;
91 void visit(const luci::CircleFloor *) final;
92 void visit(const luci::CircleFloorDiv *) final;
93 void visit(const luci::CircleFloorMod *) final;
94 void visit(const luci::CircleFullyConnected *) final;
95 void visit(const luci::CircleGather *) final;
96 void visit(const luci::CircleGatherNd *) final;
97 void visit(const luci::CircleGelu *) final;
98 void visit(const luci::CircleGreater *) final;
99 void visit(const luci::CircleGreaterEqual *) final;
100 void visit(const luci::CircleHardSwish *) final;
101 void visit(const luci::CircleIf *) final;
102 void visit(const luci::CircleL2Normalize *) final;
103 void visit(const luci::CircleL2Pool2D *) final;
104 void visit(const luci::CircleLeakyRelu *) final;
105 void visit(const luci::CircleLess *) final;
106 void visit(const luci::CircleLessEqual *) final;
107 void visit(const luci::CircleLocalResponseNormalization *) final;
108 void visit(const luci::CircleLog *) final;
109 void visit(const luci::CircleLogicalAnd *) final;
110 void visit(const luci::CircleLogicalNot *) final;
111 void visit(const luci::CircleLogicalOr *) final;
112 void visit(const luci::CircleLogistic *) final;
113 void visit(const luci::CircleLogSoftmax *) final;
114 void visit(const luci::CircleMatrixDiag *) final;
115 void visit(const luci::CircleMatrixSetDiag *) final;
116 void visit(const luci::CircleMaximum *) final;
117 void visit(const luci::CircleMaxPool2D *) final;
118 void visit(const luci::CircleMean *) final;
119 void visit(const luci::CircleMinimum *) final;
120 void visit(const luci::CircleMirrorPad *) final;
121 void visit(const luci::CircleMul *) final;
122 void visit(const luci::CircleNeg *) final;
123 void visit(const luci::CircleNonMaxSuppressionV4 *) final;
124 void visit(const luci::CircleNonMaxSuppressionV5 *) final;
125 void visit(const luci::CircleNotEqual *) final;
126 void visit(const luci::CircleOneHot *) final;
127 void visit(const luci::CirclePack *) final;
128 void visit(const luci::CirclePad *) final;
129 void visit(const luci::CirclePadV2 *) final;
130 void visit(const luci::CirclePow *) final;
131 void visit(const luci::CirclePRelu *) final;
132 void visit(const luci::CircleQuantize *) final;
133 void visit(const luci::CircleRange *) final;
134 void visit(const luci::CircleRank *) final;
135 void visit(const luci::CircleReduceAny *) final;
136 void visit(const luci::CircleReduceMax *) final;
137 void visit(const luci::CircleReduceMin *) final;
138 void visit(const luci::CircleReduceProd *) final;
139 void visit(const luci::CircleRelu *) final;
140 void visit(const luci::CircleRelu6 *) final;
141 void visit(const luci::CircleReluN1To1 *) final;
142 void visit(const luci::CircleReshape *) final;
143 void visit(const luci::CircleResizeBilinear *) final;
144 void visit(const luci::CircleResizeNearestNeighbor *) final;
145 void visit(const luci::CircleReverseSequence *) final;
146 void visit(const luci::CircleReverseV2 *) final;
147 void visit(const luci::CircleRound *) final;
148 void visit(const luci::CircleRsqrt *) final;
149 void visit(const luci::CircleScatterNd *) final;
150 void visit(const luci::CircleSegmentSum *) final;
151 void visit(const luci::CircleSelect *) final;
152 void visit(const luci::CircleSelectV2 *) final;
153 void visit(const luci::CircleShape *) final;
154 void visit(const luci::CircleSin *) final;
155 void visit(const luci::CircleSlice *) final;
156 void visit(const luci::CircleSoftmax *) final;
157 void visit(const luci::CircleSpaceToBatchND *) final;
158 void visit(const luci::CircleSpaceToDepth *) final;
159 void visit(const luci::CircleSparseToDense *) final;
160 void visit(const luci::CircleSplit *) final;
161 void visit(const luci::CircleSplitV *) final;
162 void visit(const luci::CircleSqrt *) final;
163 void visit(const luci::CircleSquare *) final;
164 void visit(const luci::CircleSquaredDifference *) final;
165 void visit(const luci::CircleSqueeze *) final;
166 void visit(const luci::CircleStridedSlice *) final;
167 void visit(const luci::CircleSVDF *) final;
168 void visit(const luci::CircleSub *) final;
169 void visit(const luci::CircleSum *) final;
170 void visit(const luci::CircleTanh *) final;
171 void visit(const luci::CircleTile *) final;
172 void visit(const luci::CircleTopKV2 *) final;
173 void visit(const luci::CircleTranspose *) final;
174 void visit(const luci::CircleTransposeConv *) final;
175 void visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
176 void visit(const luci::CircleUnique *) final;
177 void visit(const luci::CircleUnpack *) final;
178 void visit(const luci::CircleWhere *) final;
179 void visit(const luci::CircleWhile *) final;
180 void visit(const luci::CircleZerosLike *) final;
183 void visit(const luci::CircleBCQFullyConnected *) final;
184 void visit(const luci::CircleBCQGather *) final;
185 void visit(const luci::CircleInstanceNorm *) final;
187 // NOTE CircleInput and CircleOutput are not handled here as these need
188 // link with graph I/O
191 void visit(const luci::CircleCustomOut *) final;
192 void visit(const luci::CircleIfOut *) final;
193 // void visit(const luci::CircleInput *) final;
194 void visit(const luci::CircleNonMaxSuppressionV4Out *) final;
195 void visit(const luci::CircleNonMaxSuppressionV5Out *) final;
196 // void visit(const luci::CircleOutput *) final;
197 void visit(const luci::CircleOutputDummy *) final;
198 void visit(const luci::CircleOutputExclude *) final;
199 void visit(const luci::CircleSplitOut *) final;
200 void visit(const luci::CircleSplitVOut *) final;
201 void visit(const luci::CircleTopKV2Out *) final;
202 void visit(const luci::CircleUniqueOut *) final;
203 void visit(const luci::CircleUnpackOut *) final;
204 void visit(const luci::CircleVariable *) final;
205 void visit(const luci::CircleWhileOut *) final;
208 luci::CircleNode *find_clone(const luci::CircleNode *node);
211 luci::CloneContext &_clonecontext;
215 * @brief Connect cloned node from input node
217 void clone_connect(const luci::CircleNode *node, luci::CloneContext &clonecontext);
221 #endif // __LUCI_PARTITION_CONNECT_NODE_H__