Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / helpers / GenerateKernelsListHelper.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "luci_interpreter/core/reader/CircleMicroReader.h"

#include <circle-generated/circle/schema_generated.h>

#include <cassert>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
24
25 std::string get_register_kernel_str(const circle::BuiltinOperator builtin_operator)
26 {
27   switch (builtin_operator)
28   {
29     case circle::BuiltinOperator_ADD:
30       return "REGISTER_KERNEL(ADD, Add)";
31     case circle::BuiltinOperator_ARG_MAX:
32       return "REGISTER_KERNEL(ARG_MAX, ArgMax)";
33     case circle::BuiltinOperator_AVERAGE_POOL_2D:
34       return "REGISTER_KERNEL(AVERAGE_POOL_2D, AveragePool2D)";
35     case circle::BuiltinOperator_BATCH_TO_SPACE_ND:
36       return "REGISTER_KERNEL(BATCH_TO_SPACE_ND, BatchToSpaceND)";
37     case circle::BuiltinOperator_CAST:
38       return "REGISTER_KERNEL(CAST, Cast)";
39     case circle::BuiltinOperator_CONCATENATION:
40       return "REGISTER_KERNEL(CONCATENATION, Concatenation)";
41     case circle::BuiltinOperator_CONV_2D:
42       return "REGISTER_KERNEL(CONV_2D, Conv2D)";
43     case circle::BuiltinOperator_DEPTH_TO_SPACE:
44       return "REGISTER_KERNEL(DEPTH_TO_SPACE, DepthToSpace)";
45     case circle::BuiltinOperator_DEPTHWISE_CONV_2D:
46       return "REGISTER_KERNEL(DEPTHWISE_CONV_2D, DepthwiseConv2D)";
47     case circle::BuiltinOperator_DEQUANTIZE:
48       return "REGISTER_KERNEL(DEQUANTIZE, Dequantize)";
49     case circle::BuiltinOperator_DIV:
50       return "REGISTER_KERNEL(DIV, Div)";
51     case circle::BuiltinOperator_ELU:
52       return "REGISTER_KERNEL(ELU, Elu)";
53     case circle::BuiltinOperator_EXP:
54       return "REGISTER_KERNEL(EXP, Exp)";
55     case circle::BuiltinOperator_EXPAND_DIMS:
56       return "REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)";
57     case circle::BuiltinOperator_FILL:
58       return "REGISTER_KERNEL(FILL, Fill)";
59     case circle::BuiltinOperator_FLOOR:
60       return "REGISTER_KERNEL(FLOOR, Floor)";
61     case circle::BuiltinOperator_FLOOR_DIV:
62       return "REGISTER_KERNEL(FLOOR_DIV, FloorDiv)";
63     case circle::BuiltinOperator_EQUAL:
64       return "REGISTER_KERNEL(EQUAL, Equal)";
65     case circle::BuiltinOperator_FULLY_CONNECTED:
66       return "REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)";
67     case circle::BuiltinOperator_GREATER:
68       return "REGISTER_KERNEL(GREATER, Greater)";
69     case circle::BuiltinOperator_GREATER_EQUAL:
70       return "REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)";
71     case circle::BuiltinOperator_INSTANCE_NORM:
72       return "REGISTER_KERNEL(INSTANCE_NORM, InstanceNorm)";
73     case circle::BuiltinOperator_L2_NORMALIZATION:
74       return "REGISTER_KERNEL(L2_NORMALIZATION, L2Normalize)";
75     case circle::BuiltinOperator_L2_POOL_2D:
76       return "REGISTER_KERNEL(L2_POOL_2D, L2Pool2D)";
77     case circle::BuiltinOperator_LEAKY_RELU:
78       return "REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)";
79     case circle::BuiltinOperator_LESS:
80       return "REGISTER_KERNEL(LESS, Less)";
81     case circle::BuiltinOperator_LESS_EQUAL:
82       return "REGISTER_KERNEL(LESS_EQUAL, LessEqual)";
83     case circle::BuiltinOperator_LOGICAL_AND:
84       return "REGISTER_KERNEL(LOGICAL_AND, LogicalAnd)";
85     case circle::BuiltinOperator_LOGICAL_NOT:
86       return "REGISTER_KERNEL(LOGICAL_NOT, LogicalNot)";
87     case circle::BuiltinOperator_LOGICAL_OR:
88       return "REGISTER_KERNEL(LOGICAL_OR, LogicalOr)";
89     case circle::BuiltinOperator_LOGISTIC:
90       return "REGISTER_KERNEL(LOGISTIC, Logistic)";
91     case circle::BuiltinOperator_GATHER:
92       return "REGISTER_KERNEL(GATHER, Gather)";
93     case circle::BuiltinOperator_MAXIMUM:
94       return "REGISTER_KERNEL(MAXIMUM, Maximum)";
95     case circle::BuiltinOperator_MAX_POOL_2D:
96       return "REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)";
97     case circle::BuiltinOperator_MINIMUM:
98       return "REGISTER_KERNEL(MINIMUM, Minimum)";
99     case circle::BuiltinOperator_MIRROR_PAD:
100       return "REGISTER_KERNEL(MIRROR_PAD, MirrorPad)";
101     case circle::BuiltinOperator_MUL:
102       return "REGISTER_KERNEL(MUL, Mul)";
103     case circle::BuiltinOperator_NEG:
104       return "REGISTER_KERNEL(NEG, Neg)";
105     case circle::BuiltinOperator_NOT_EQUAL:
106       return "REGISTER_KERNEL(NOT_EQUAL, NotEqual)";
107     case circle::BuiltinOperator_PAD:
108       return "REGISTER_KERNEL(PAD, Pad)";
109     case circle::BuiltinOperator_PADV2:
110       return "REGISTER_KERNEL(PADV2, PadV2)";
111     case circle::BuiltinOperator_PACK:
112       return "REGISTER_KERNEL(PACK, Pack)";
113     case circle::BuiltinOperator_PRELU:
114       return "REGISTER_KERNEL(PRELU, PRelu)";
115     case circle::BuiltinOperator_QUANTIZE:
116       return "REGISTER_KERNEL(QUANTIZE, Quantize)";
117     case circle::BuiltinOperator_REDUCE_PROD:
118       return "REGISTER_KERNEL(REDUCE_PROD, ReduceCommon)";
119     case circle::BuiltinOperator_RESHAPE:
120       return "REGISTER_KERNEL(RESHAPE, Reshape)";
121     case circle::BuiltinOperator_RESIZE_BILINEAR:
122       return "REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)";
123     case circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
124       return "REGISTER_KERNEL(RESIZE_NEAREST_NEIGHBOR, ResizeNearestNeighbor)";
125     case circle::BuiltinOperator_RSQRT:
126       return "REGISTER_KERNEL(RSQRT, Rsqrt)";
127     case circle::BuiltinOperator_SHAPE:
128       return "REGISTER_KERNEL(SHAPE, Shape)";
129     case circle::BuiltinOperator_SOFTMAX:
130       return "REGISTER_KERNEL(SOFTMAX, Softmax)";
131     case circle::BuiltinOperator_SPACE_TO_BATCH_ND:
132       return "REGISTER_KERNEL(SPACE_TO_BATCH_ND, SpaceToBatchND)";
133     case circle::BuiltinOperator_SPACE_TO_DEPTH:
134       return "REGISTER_KERNEL(SPACE_TO_DEPTH, SpaceToDepth)";
135     case circle::BuiltinOperator_SLICE:
136       return "REGISTER_KERNEL(SLICE, Slice)";
137     case circle::BuiltinOperator_STRIDED_SLICE:
138       return "REGISTER_KERNEL(STRIDED_SLICE, StridedSlice)";
139     case circle::BuiltinOperator_SQRT:
140       return "REGISTER_KERNEL(SQRT, Sqrt)";
141     case circle::BuiltinOperator_SQUARE:
142       return "REGISTER_KERNEL(SQUARE, Square)";
143     case circle::BuiltinOperator_SQUARED_DIFFERENCE:
144       return "REGISTER_KERNEL(SQUARED_DIFFERENCE, SquaredDifference)";
145     case circle::BuiltinOperator_SQUEEZE:
146       return "REGISTER_KERNEL(SQUEEZE, Squeeze)";
147     case circle::BuiltinOperator_SUB:
148       return "REGISTER_KERNEL(SUB, Sub)";
149     case circle::BuiltinOperator_SVDF:
150       return "REGISTER_KERNEL(SVDF, SVDF)";
151     case circle::BuiltinOperator_SPLIT:
152       return "REGISTER_KERNEL(SPLIT, Split)";
153     case circle::BuiltinOperator_SPLIT_V:
154       return "REGISTER_KERNEL(SPLIT_V, SplitV)";
155     case circle::BuiltinOperator_TANH:
156       return "REGISTER_KERNEL(TANH, Tanh)";
157     case circle::BuiltinOperator_TRANSPOSE:
158       return "REGISTER_KERNEL(TRANSPOSE, Transpose)";
159     case circle::BuiltinOperator_TRANSPOSE_CONV:
160       return "REGISTER_KERNEL(TRANSPOSE_CONV, TransposeConv)";
161     case circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
162       return "REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)";
163     case circle::BuiltinOperator_WHILE:
164       return "REGISTER_KERNEL(WHILE, While)";
165     default:
166       assert(false && "Not supported kernel");
167   }
168 }
169
// Reads the entire file at `path` into memory as raw bytes.
//
// Throws std::runtime_error if the file cannot be opened, its size cannot be
// determined, or its contents cannot be fully read. These failures are reported in
// every build type, not only when assertions are enabled.
std::vector<char> loadFile(const std::string &path)
{
  std::ifstream file(path, std::ios::binary | std::ios::in);
  if (!file.good())
  {
    throw std::runtime_error("Failed to open file: " + path);
  }

  // Measure the file by seeking to its end; tellg() reports -1 on failure, which would
  // otherwise be converted into an enormous vector size below.
  file.seekg(0, std::ios::end);
  const std::streamoff file_size = file.tellg();
  if (file_size < 0)
  {
    throw std::runtime_error("Failed to determine size of file: " + path);
  }
  file.seekg(0, std::ios::beg);

  std::vector<char> data(static_cast<std::size_t>(file_size));

  // An empty file has nothing to read (and data() may be null), so skip read().
  if (!data.empty())
  {
    file.read(data.data(), static_cast<std::streamsize>(data.size()));
    if (file.fail())
    {
      throw std::runtime_error("Failed to read file: " + path);
    }
  }

  return data;
}
196
197 // Parse model and write to std::ofstream &os models operations
198 void run(std::ofstream &os, const circle::Model *model)
199 {
200   luci_interpreter::CircleReader reader;
201   reader.parse(model);
202   const uint32_t subgraph_size = reader.num_subgraph();
203
204   // Set to avoid duplication in generated list
205   std::set<circle::BuiltinOperator> operations_set;
206
207   for (uint32_t g = 0; g < subgraph_size; g++)
208   {
209     reader.select_subgraph(g);
210     auto ops = reader.operators();
211     for (uint32_t i = 0; i < ops.size(); ++i)
212     {
213       const auto op = ops.at(i);
214       auto op_builtin_operator = reader.builtin_code(op);
215
216       auto result = operations_set.insert(op_builtin_operator);
217       if (result.second)
218       {
219         os << get_register_kernel_str(op_builtin_operator) << std::endl;
220       }
221     }
222   }
223 }
224
225 int main(int argc, char **argv)
226 {
227   if (argc != 3)
228   {
229     assert(false && "Should be 2 arguments: circle model path, and path for generated model\n");
230   }
231
232   std::string model_file(argv[1]);
233   std::string generated_file_path(argv[2]);
234
235   std::vector<char> model_data = loadFile(model_file);
236   const circle::Model *circle_model = circle::GetModel(model_data.data());
237
238   if (circle_model == nullptr)
239   {
240     std::cerr << "ERROR: Failed to load circle '" << model_file << "'" << std::endl;
241     return 255;
242   }
243
244   // Open or create file
245   std::ofstream out;
246   out.open(generated_file_path);
247
248   if (out.is_open())
249     run(out, circle_model);
250   else
251     std::cout << "SMTH GOES WRONG WHILE OPEN FILE" << std::endl;
252   return 0;
253 }