1420020364dd6cb8c6b78e64a890dabf2869ac8b
[platform/core/ml/nnfw.git] / onert-micro / helpers / GenerateKernelsListHelper.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci_interpreter/core/reader/CircleMicroReader.h"
18
19 #include <circle-generated/circle/schema_generated.h>
20
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
24
25 std::string get_register_kernel_str(const circle::BuiltinOperator builtin_operator)
26 {
27   switch (builtin_operator)
28   {
29     case circle::BuiltinOperator_ADD:
30       return "REGISTER_KERNEL(ADD, Add)";
31     case circle::BuiltinOperator_ARG_MAX:
32       return "REGISTER_KERNEL(ARG_MAX, ArgMax)";
33     case circle::BuiltinOperator_AVERAGE_POOL_2D:
34       return "REGISTER_KERNEL(AVERAGE_POOL_2D, AveragePool2D)";
35     case circle::BuiltinOperator_BATCH_TO_SPACE_ND:
36       return "REGISTER_KERNEL(BATCH_TO_SPACE_ND, BatchToSpaceND)";
37     case circle::BuiltinOperator_CAST:
38       return "REGISTER_KERNEL(CAST, Cast)";
39     case circle::BuiltinOperator_CONCATENATION:
40       return "REGISTER_KERNEL(CONCATENATION, Concatenation)";
41     case circle::BuiltinOperator_CONV_2D:
42       return "REGISTER_KERNEL(CONV_2D, Conv2D)";
43     case circle::BuiltinOperator_DEPTH_TO_SPACE:
44       return "REGISTER_KERNEL(DEPTH_TO_SPACE, DepthToSpace)";
45     case circle::BuiltinOperator_DEPTHWISE_CONV_2D:
46       return "REGISTER_KERNEL(DEPTHWISE_CONV_2D, DepthwiseConv2D)";
47     case circle::BuiltinOperator_DEQUANTIZE:
48       return "REGISTER_KERNEL(DEQUANTIZE, Dequantize)";
49     case circle::BuiltinOperator_DIV:
50       return "REGISTER_KERNEL(DIV, Div)";
51     case circle::BuiltinOperator_ELU:
52       return "REGISTER_KERNEL(ELU, Elu)";
53     case circle::BuiltinOperator_EXP:
54       return "REGISTER_KERNEL(EXP, Exp)";
55     case circle::BuiltinOperator_EXPAND_DIMS:
56       return "REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)";
57     case circle::BuiltinOperator_FILL:
58       return "REGISTER_KERNEL(FILL, Fill)";
59     case circle::BuiltinOperator_FLOOR:
60       return "REGISTER_KERNEL(FLOOR, Floor)";
61     case circle::BuiltinOperator_FLOOR_DIV:
62       return "REGISTER_KERNEL(FLOOR_DIV, FloorDiv)";
63     case circle::BuiltinOperator_EQUAL:
64       return "REGISTER_KERNEL(EQUAL, Equal)";
65     case circle::BuiltinOperator_FULLY_CONNECTED:
66       return "REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)";
67     case circle::BuiltinOperator_GREATER:
68       return "REGISTER_KERNEL(GREATER, Greater)";
69     case circle::BuiltinOperator_GREATER_EQUAL:
70       return "REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)";
71     case circle::BuiltinOperator_INSTANCE_NORM:
72       return "REGISTER_KERNEL(INSTANCE_NORM, InstanceNorm)";
73     case circle::BuiltinOperator_L2_NORMALIZATION:
74       return "REGISTER_KERNEL(L2_NORMALIZATION, L2Normalize)";
75     case circle::BuiltinOperator_L2_POOL_2D:
76       return "REGISTER_KERNEL(L2_POOL_2D, L2Pool2D)";
77     case circle::BuiltinOperator_LEAKY_RELU:
78       return "REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)";
79     case circle::BuiltinOperator_LESS:
80       return "REGISTER_KERNEL(LESS, Less)";
81     case circle::BuiltinOperator_LESS_EQUAL:
82       return "REGISTER_KERNEL(LESS_EQUAL, LessEqual)";
83     case circle::BuiltinOperator_LOGICAL_AND:
84       return "REGISTER_KERNEL(LOGICAL_AND, LogicalAnd)";
85     case circle::BuiltinOperator_LOGICAL_NOT:
86       return "REGISTER_KERNEL(LOGICAL_NOT, LogicalNot)";
87     case circle::BuiltinOperator_LOGICAL_OR:
88       return "REGISTER_KERNEL(LOGICAL_OR, LogicalOr)";
89     case circle::BuiltinOperator_LOGISTIC:
90       return "REGISTER_KERNEL(LOGISTIC, Logistic)";
91     case circle::BuiltinOperator_MAXIMUM:
92       return "REGISTER_KERNEL(MAXIMUM, Maximum)";
93     case circle::BuiltinOperator_MAX_POOL_2D:
94       return "REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)";
95     case circle::BuiltinOperator_MINIMUM:
96       return "REGISTER_KERNEL(MINIMUM, Minimum)";
97     case circle::BuiltinOperator_MIRROR_PAD:
98       return "REGISTER_KERNEL(MIRROR_PAD, MirrorPad)";
99     case circle::BuiltinOperator_MUL:
100       return "REGISTER_KERNEL(MUL, Mul)";
101     case circle::BuiltinOperator_NEG:
102       return "REGISTER_KERNEL(NEG, Neg)";
103     case circle::BuiltinOperator_NOT_EQUAL:
104       return "REGISTER_KERNEL(NOT_EQUAL, NotEqual)";
105     case circle::BuiltinOperator_PAD:
106       return "REGISTER_KERNEL(PAD, Pad)";
107     case circle::BuiltinOperator_PADV2:
108       return "REGISTER_KERNEL(PADV2, PadV2)";
109     case circle::BuiltinOperator_PRELU:
110       return "REGISTER_KERNEL(PRELU, PRelu)";
111     case circle::BuiltinOperator_QUANTIZE:
112       return "REGISTER_KERNEL(QUANTIZE, Quantize)";
113     case circle::BuiltinOperator_RESHAPE:
114       return "REGISTER_KERNEL(RESHAPE, Reshape)";
115     case circle::BuiltinOperator_RESIZE_BILINEAR:
116       return "REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)";
117     case circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
118       return "REGISTER_KERNEL(RESIZE_NEAREST_NEIGHBOR, ResizeNearestNeighbor)";
119     case circle::BuiltinOperator_RSQRT:
120       return "REGISTER_KERNEL(RSQRT, Rsqrt)";
121     case circle::BuiltinOperator_SHAPE:
122       return "REGISTER_KERNEL(SHAPE, Shape)";
123     case circle::BuiltinOperator_SOFTMAX:
124       return "REGISTER_KERNEL(SOFTMAX, Softmax)";
125     case circle::BuiltinOperator_SPACE_TO_BATCH_ND:
126       return "REGISTER_KERNEL(SPACE_TO_BATCH_ND, SpaceToBatchND)";
127     case circle::BuiltinOperator_SPACE_TO_DEPTH:
128       return "REGISTER_KERNEL(SPACE_TO_DEPTH, SpaceToDepth)";
129     case circle::BuiltinOperator_STRIDED_SLICE:
130       return "REGISTER_KERNEL(STRIDED_SLICE, StridedSlice)";
131     case circle::BuiltinOperator_SQRT:
132       return "REGISTER_KERNEL(SQRT, Sqrt)";
133     case circle::BuiltinOperator_SQUARE:
134       return "REGISTER_KERNEL(SQUARE, Square)";
135     case circle::BuiltinOperator_SQUARED_DIFFERENCE:
136       return "REGISTER_KERNEL(SQUARED_DIFFERENCE, SquaredDifference)";
137     case circle::BuiltinOperator_SQUEEZE:
138       return "REGISTER_KERNEL(SQUEEZE, Squeeze)";
139     case circle::BuiltinOperator_SUB:
140       return "REGISTER_KERNEL(SUB, Sub)";
141     case circle::BuiltinOperator_SVDF:
142       return "REGISTER_KERNEL(SVDF, SVDF)";
143     case circle::BuiltinOperator_TANH:
144       return "REGISTER_KERNEL(TANH, Tanh)";
145     case circle::BuiltinOperator_TRANSPOSE:
146       return "REGISTER_KERNEL(TRANSPOSE, Transpose)";
147     case circle::BuiltinOperator_TRANSPOSE_CONV:
148       return "REGISTER_KERNEL(TRANSPOSE_CONV, TransposeConv)";
149     case circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
150       return "REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)";
151     default:
152       assert(false && "Not supported kernel");
153   }
154 }
155
/**
 * @brief Read the entire file at @p path into memory as raw bytes.
 *
 * The file is opened in binary mode, so bytes are returned unmodified.
 *
 * @param path file to read.
 * @return the file contents (empty for an empty file).
 * @throws std::runtime_error if the file cannot be opened, its size cannot be
 *         determined, or it cannot be read completely.
 */
std::vector<char> loadFile(const std::string &path)
{
  std::ifstream file(path, std::ios::binary | std::ios::in);
  if (!file.good())
    throw std::runtime_error("Failed to open file: " + path);

  // Determine the size by seeking to the end. tellg() yields -1 on failure,
  // which must not reach the vector constructor (it would become a huge size).
  file.seekg(0, std::ios::end);
  const std::streamoff file_size = file.tellg();
  if (file_size < 0)
    throw std::runtime_error("Failed to determine size of file: " + path);
  file.seekg(0, std::ios::beg);

  std::vector<char> data(static_cast<std::size_t>(file_size));

  // Unformatted read(): the skipws flag is irrelevant here, so it is not touched.
  file.read(data.data(), static_cast<std::streamsize>(file_size));
  if (!file)
    throw std::runtime_error("Failed to read file: " + path);

  return data;
}
182
183 // Parse model and write to std::ofstream &os models operations
184 void run(std::ofstream &os, const circle::Model *model)
185 {
186   luci_interpreter::CircleReader reader;
187   reader.parse(model);
188   const uint32_t subgraph_size = reader.num_subgraph();
189
190   // Set to avoid duplication in generated list
191   std::set<circle::BuiltinOperator> operations_set;
192
193   for (uint32_t g = 0; g < subgraph_size; g++)
194   {
195     reader.select_subgraph(g);
196     auto ops = reader.operators();
197     for (uint32_t i = 0; i < ops.size(); ++i)
198     {
199       const auto op = ops.at(i);
200       auto op_builtin_operator = reader.builtin_code(op);
201
202       auto result = operations_set.insert(op_builtin_operator);
203       if (result.second)
204       {
205         os << get_register_kernel_str(op_builtin_operator) << std::endl;
206       }
207     }
208   }
209 }
210
211 int main(int argc, char **argv)
212 {
213   if (argc != 3)
214   {
215     assert(false && "Should be 2 arguments: circle model path, and path for generated model\n");
216   }
217
218   std::string model_file(argv[1]);
219   std::string generated_file_path(argv[2]);
220
221   std::vector<char> model_data = loadFile(model_file);
222   const circle::Model *circle_model = circle::GetModel(model_data.data());
223
224   if (circle_model == nullptr)
225   {
226     std::cerr << "ERROR: Failed to load circle '" << model_file << "'" << std::endl;
227     return 255;
228   }
229
230   // Open or create file
231   std::ofstream out;
232   out.open(generated_file_path);
233
234   if (out.is_open())
235     run(out, circle_model);
236   else
237     std::cout << "SMTH GOES WRONG WHILE OPEN FILE" << std::endl;
238   return 0;
239 }