Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / service / include / luci / Service / CircleShapeInference.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_H__
18 #define __LUCI_CIRCLE_SHAPE_INFERENCE_H__
19
20 #include "ShapeDescription.h"
21
22 #include <loco/IR/Nodes.h>
23
24 #include <luci/IR/CircleNodes.h>
25 #include <luci/IR/CircleNodeVisitor.h>
26 #include <luci/Service/CircleShapeInferenceHelper.h>
27
28 namespace luci
29 {
30
31 /**
32  * @brief Get the shape of each node as a node annotation
33  *
34  * HOW TO USE
35  *
36  *   ShapeInference::get(g->nodes()->at(..));
37  */
38 struct ShapeInference
39 {
40   static ShapeDescription get(loco::Node *node);
41 };
42
43 namespace sinf // namespace for Shape Inference
44 {
45
46 struct Rule
47 {
48   bool infer(const luci::CircleNode *, loco::TensorShape &) const;
49 };
50
51 class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
52 {
53 public:
54   // TODO Remove this when all of visit function is implemented
55   loco::TensorShape visit(const luci::CircleNode *node) final { return sinf::circle_shape(node); }
56
57   // loco::TensorShape visit(const luci::CircleAbs *node) final;
58   // loco::TensorShape visit(const luci::CircleAdd *node) final;
59   // loco::TensorShape visit(const luci::CircleAddN *node) final;
60   // loco::TensorShape visit(const luci::CircleArgMax *node) final;
61   // loco::TensorShape visit(const luci::CircleArgMin *node) final;
62   // loco::TensorShape visit(const luci::CircleAveragePool2D *node) final;
63   // loco::TensorShape visit(const luci::CircleBatchMatMul *node) final;
64   // loco::TensorShape visit(const luci::CircleBatchToSpaceND *node) final;
65   // loco::TensorShape visit(const luci::CircleCast *node) final;
66   // loco::TensorShape visit(const luci::CircleCeil *node) final;
67   // loco::TensorShape visit(const luci::CircleConcatenation *node) final;
68   // loco::TensorShape visit(const luci::CircleConst *node) final;
69   // loco::TensorShape visit(const luci::CircleConv2D *node) final;
70   // loco::TensorShape visit(const luci::CircleCos *node) final;
71   // loco::TensorShape visit(const luci::CircleCustom *node) final;
72   // loco::TensorShape visit(const luci::CircleDepthToSpace *node) final;
73   // loco::TensorShape visit(const luci::CircleDepthwiseConv2D *node) final;
74   // loco::TensorShape visit(const luci::CircleDequantize *node) final;
75   // loco::TensorShape visit(const luci::CircleDiv *node) final;
76   // loco::TensorShape visit(const luci::CircleElu *node) final;
77   // loco::TensorShape visit(const luci::CircleEqual *node) final;
78   // loco::TensorShape visit(const luci::CircleExp *node) final;
79   // loco::TensorShape visit(const luci::CircleExpandDims *node) final;
80   // loco::TensorShape visit(const luci::CircleFill *node) final;
81   // loco::TensorShape visit(const luci::CircleFloor *node) final;
82   // loco::TensorShape visit(const luci::CircleFloorDiv *node) final;
83   // loco::TensorShape visit(const luci::CircleFloorMod *node) final;
84   // loco::TensorShape visit(const luci::CircleFullyConnected *node) final;
85   // loco::TensorShape visit(const luci::CircleGather *node) final;
86   // loco::TensorShape visit(const luci::CircleGatherNd *node) final;
87   // loco::TensorShape visit(const luci::CircleGreater *node) final;
88   // loco::TensorShape visit(const luci::CircleGreaterEqual *node) final;
89   // loco::TensorShape visit(const luci::CircleIf *node) final;
90   // loco::TensorShape visit(const luci::CircleL2Normalize *node) final;
91   // loco::TensorShape visit(const luci::CircleL2Pool2D *node) final;
92   // loco::TensorShape visit(const luci::CircleLeakyRelu *node) final;
93   // loco::TensorShape visit(const luci::CircleLess *node) final;
94   // loco::TensorShape visit(const luci::CircleLessEqual *node) final;
95   // loco::TensorShape visit(const luci::CircleLocalResponseNormalization *node) final;
96   // loco::TensorShape visit(const luci::CircleLog *node) final;
97   // loco::TensorShape visit(const luci::CircleLogicalAnd *node) final;
98   // loco::TensorShape visit(const luci::CircleLogicalNot *node) final;
99   // loco::TensorShape visit(const luci::CircleLogicalOr *node) final;
100   // loco::TensorShape visit(const luci::CircleLogistic *node) final;
101   // loco::TensorShape visit(const luci::CircleLogSoftmax *node) final;
102   // loco::TensorShape visit(const luci::CircleMatrixDiag *node) final;
103   // loco::TensorShape visit(const luci::CircleMatrixSetDiag *node) final;
104   // loco::TensorShape visit(const luci::CircleMaximum *node) final;
105   // loco::TensorShape visit(const luci::CircleMaxPool2D *node) final;
106   // loco::TensorShape visit(const luci::CircleMean *node) final;
107   // loco::TensorShape visit(const luci::CircleMinimum *node) final;
108   // loco::TensorShape visit(const luci::CircleMirrorPad *node) final;
109   // loco::TensorShape visit(const luci::CircleNeg *node) final;
110   // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final;
111   // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final;
112   // loco::TensorShape visit(const luci::CircleNotEqual *node) final;
113   // loco::TensorShape visit(const luci::CirclePack *node) final;
114   // loco::TensorShape visit(const luci::CirclePad *node) final;
115   // loco::TensorShape visit(const luci::CirclePadV2 *node) final;
116   // loco::TensorShape visit(const luci::CirclePow *node) final;
117   // loco::TensorShape visit(const luci::CirclePRelu *node) final;
118   // loco::TensorShape visit(const luci::CircleRange *node) final;
119   // loco::TensorShape visit(const luci::CircleRank *node) final;
120   // loco::TensorShape visit(const luci::CircleMul *node) final;
121   // loco::TensorShape visit(const luci::CircleOneHot *node) final;
122   // loco::TensorShape visit(const luci::CircleReduceAny *node) final;
123   // loco::TensorShape visit(const luci::CircleReduceMax *node) final;
124   // loco::TensorShape visit(const luci::CircleReduceMin *node) final;
125   // loco::TensorShape visit(const luci::CircleReduceProd *node) final;
126   // loco::TensorShape visit(const luci::CircleRelu *node) final;
127   // loco::TensorShape visit(const luci::CircleRelu6 *node) final;
128   // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final;
129   // loco::TensorShape visit(const luci::CircleReshape *node) final;
130   // loco::TensorShape visit(const luci::CircleResizeBilinear *node) final;
131   // loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final;
132   // loco::TensorShape visit(const luci::CircleReverseSequence *node) final;
133   // loco::TensorShape visit(const luci::CircleReverseV2 *node) final;
134   // loco::TensorShape visit(const luci::CircleRound *node) final;
135   // loco::TensorShape visit(const luci::CircleRsqrt *node) final;
136   // loco::TensorShape visit(const luci::CircleScatterNd *node) final;
137   // loco::TensorShape visit(const luci::CircleSegmentSum *node) final;
138   // loco::TensorShape visit(const luci::CircleSelect *node) final;
139   // loco::TensorShape visit(const luci::CircleSelectV2 *node) final;
140   // loco::TensorShape visit(const luci::CircleShape *node) final;
141   // loco::TensorShape visit(const luci::CircleSin *node) final;
142   // loco::TensorShape visit(const luci::CircleSlice *node) final;
143   // loco::TensorShape visit(const luci::CircleSoftmax *node) final;
144   // loco::TensorShape visit(const luci::CircleSpaceToBatchND *node) final;
145   // loco::TensorShape visit(const luci::CircleSpaceToDepth *node) final;
146   // loco::TensorShape visit(const luci::CircleSparseToDense *node) final;
147   // loco::TensorShape visit(const luci::CircleSplit *node) final;
148   // loco::TensorShape visit(const luci::CircleSplitV *node) final;
149   // loco::TensorShape visit(const luci::CircleSqrt *node) final;
150   // loco::TensorShape visit(const luci::CircleSquare *node) final;
151   // loco::TensorShape visit(const luci::CircleSquaredDifference *node) final;
152   // loco::TensorShape visit(const luci::CircleSqueeze *node) final;
153   // loco::TensorShape visit(const luci::CircleStridedSlice *node) final;
154   // loco::TensorShape visit(const luci::CircleSub *node) final;
155   // loco::TensorShape visit(const luci::CircleSum *node) final;
156   // loco::TensorShape visit(const luci::CircleTanh *node) final;
157   // loco::TensorShape visit(const luci::CircleTile *node) final;
158   // loco::TensorShape visit(const luci::CircleTopKV2 *node) final;
159   // loco::TensorShape visit(const luci::CircleTranspose *node) final;
160   // loco::TensorShape visit(const luci::CircleTransposeConv *node) final;
161   // loco::TensorShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final;
162   // loco::TensorShape visit(const luci::CircleUnique *node) final;
163   // loco::TensorShape visit(const luci::CircleUnpack *node) final;
164   // loco::TensorShape visit(const luci::CircleWhere *node) final;
165   // loco::TensorShape visit(const luci::CircleWhile *node) final;
166   // loco::TensorShape visit(const luci::CircleZerosLike *node) final;
167
168   // Circle Only
169   // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final;
170   // loco::TensorShape visit(const luci::CircleBCQGather *node) final;
171   // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final;
172
173   // Virtual
174   // loco::TensorShape visit(const luci::CircleInput *node) final;
175   // loco::TensorShape visit(const luci::CircleOutput *node) final;
176   // loco::TensorShape visit(const luci::CircleOutputDummy *node) final;
177   // loco::TensorShape visit(const luci::CircleOutputExclude *node) final;
178   // loco::TensorShape visit(const luci::CircleCustomOut *node) final;
179   // loco::TensorShape visit(const luci::CircleIfOut *node) final;
180   // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
181   // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
182   // loco::TensorShape visit(const luci::CircleSplitOut *node) final;
183   // loco::TensorShape visit(const luci::CircleSplitVOut *node) final;
184   // loco::TensorShape visit(const luci::CircleTopKV2Out *node) final;
185   // loco::TensorShape visit(const luci::CircleUniqueOut *node) final;
186   // loco::TensorShape visit(const luci::CircleUnpackOut *node) final;
187   // loco::TensorShape visit(const luci::CircleWhileOut *node) final;
188 };
189
190 } // namespace sinf
191
192 } // namespace luci
193
194 #endif // __LUCI_CIRCLE_SHAPE_INFERENCE_H__