2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "luci_interpreter/core/reader/CircleMicroReader.h"

#include <circle-generated/circle/schema_generated.h>

#include <cstdlib>
#include <exception>
#include <stdexcept>
25 std::string get_register_kernel_str(const circle::BuiltinOperator builtin_operator)
27 switch (builtin_operator)
29 case circle::BuiltinOperator_ADD:
30 return "REGISTER_KERNEL(ADD, Add)";
31 case circle::BuiltinOperator_ARG_MAX:
32 return "REGISTER_KERNEL(ARG_MAX, ArgMax)";
33 case circle::BuiltinOperator_AVERAGE_POOL_2D:
34 return "REGISTER_KERNEL(AVERAGE_POOL_2D, AveragePool2D)";
35 case circle::BuiltinOperator_BATCH_TO_SPACE_ND:
36 return "REGISTER_KERNEL(BATCH_TO_SPACE_ND, BatchToSpaceND)";
37 case circle::BuiltinOperator_CAST:
38 return "REGISTER_KERNEL(CAST, Cast)";
39 case circle::BuiltinOperator_CONCATENATION:
40 return "REGISTER_KERNEL(CONCATENATION, Concatenation)";
41 case circle::BuiltinOperator_CONV_2D:
42 return "REGISTER_KERNEL(CONV_2D, Conv2D)";
43 case circle::BuiltinOperator_DEPTH_TO_SPACE:
44 return "REGISTER_KERNEL(DEPTH_TO_SPACE, DepthToSpace)";
45 case circle::BuiltinOperator_DEPTHWISE_CONV_2D:
46 return "REGISTER_KERNEL(DEPTHWISE_CONV_2D, DepthwiseConv2D)";
47 case circle::BuiltinOperator_DEQUANTIZE:
48 return "REGISTER_KERNEL(DEQUANTIZE, Dequantize)";
49 case circle::BuiltinOperator_DIV:
50 return "REGISTER_KERNEL(DIV, Div)";
51 case circle::BuiltinOperator_ELU:
52 return "REGISTER_KERNEL(ELU, Elu)";
53 case circle::BuiltinOperator_EXP:
54 return "REGISTER_KERNEL(EXP, Exp)";
55 case circle::BuiltinOperator_EXPAND_DIMS:
56 return "REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)";
57 case circle::BuiltinOperator_FILL:
58 return "REGISTER_KERNEL(FILL, Fill)";
59 case circle::BuiltinOperator_FLOOR:
60 return "REGISTER_KERNEL(FLOOR, Floor)";
61 case circle::BuiltinOperator_FLOOR_DIV:
62 return "REGISTER_KERNEL(FLOOR_DIV, FloorDiv)";
63 case circle::BuiltinOperator_EQUAL:
64 return "REGISTER_KERNEL(EQUAL, Equal)";
65 case circle::BuiltinOperator_FULLY_CONNECTED:
66 return "REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)";
67 case circle::BuiltinOperator_GREATER:
68 return "REGISTER_KERNEL(GREATER, Greater)";
69 case circle::BuiltinOperator_GREATER_EQUAL:
70 return "REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)";
71 case circle::BuiltinOperator_INSTANCE_NORM:
72 return "REGISTER_KERNEL(INSTANCE_NORM, InstanceNorm)";
73 case circle::BuiltinOperator_L2_NORMALIZATION:
74 return "REGISTER_KERNEL(L2_NORMALIZATION, L2Normalize)";
75 case circle::BuiltinOperator_L2_POOL_2D:
76 return "REGISTER_KERNEL(L2_POOL_2D, L2Pool2D)";
77 case circle::BuiltinOperator_LEAKY_RELU:
78 return "REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)";
79 case circle::BuiltinOperator_LESS:
80 return "REGISTER_KERNEL(LESS, Less)";
81 case circle::BuiltinOperator_LESS_EQUAL:
82 return "REGISTER_KERNEL(LESS_EQUAL, LessEqual)";
83 case circle::BuiltinOperator_LOGICAL_AND:
84 return "REGISTER_KERNEL(LOGICAL_AND, LogicalAnd)";
85 case circle::BuiltinOperator_LOGICAL_NOT:
86 return "REGISTER_KERNEL(LOGICAL_NOT, LogicalNot)";
87 case circle::BuiltinOperator_LOGICAL_OR:
88 return "REGISTER_KERNEL(LOGICAL_OR, LogicalOr)";
89 case circle::BuiltinOperator_LOGISTIC:
90 return "REGISTER_KERNEL(LOGISTIC, Logistic)";
91 case circle::BuiltinOperator_MAXIMUM:
92 return "REGISTER_KERNEL(MAXIMUM, Maximum)";
93 case circle::BuiltinOperator_MAX_POOL_2D:
94 return "REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)";
95 case circle::BuiltinOperator_MINIMUM:
96 return "REGISTER_KERNEL(MINIMUM, Minimum)";
97 case circle::BuiltinOperator_MIRROR_PAD:
98 return "REGISTER_KERNEL(MIRROR_PAD, MirrorPad)";
99 case circle::BuiltinOperator_MUL:
100 return "REGISTER_KERNEL(MUL, Mul)";
101 case circle::BuiltinOperator_NEG:
102 return "REGISTER_KERNEL(NEG, Neg)";
103 case circle::BuiltinOperator_NOT_EQUAL:
104 return "REGISTER_KERNEL(NOT_EQUAL, NotEqual)";
105 case circle::BuiltinOperator_PAD:
106 return "REGISTER_KERNEL(PAD, Pad)";
107 case circle::BuiltinOperator_PADV2:
108 return "REGISTER_KERNEL(PADV2, PadV2)";
109 case circle::BuiltinOperator_PRELU:
110 return "REGISTER_KERNEL(PRELU, PRelu)";
111 case circle::BuiltinOperator_QUANTIZE:
112 return "REGISTER_KERNEL(QUANTIZE, Quantize)";
113 case circle::BuiltinOperator_RESHAPE:
114 return "REGISTER_KERNEL(RESHAPE, Reshape)";
115 case circle::BuiltinOperator_RESIZE_BILINEAR:
116 return "REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)";
117 case circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
118 return "REGISTER_KERNEL(RESIZE_NEAREST_NEIGHBOR, ResizeNearestNeighbor)";
119 case circle::BuiltinOperator_RSQRT:
120 return "REGISTER_KERNEL(RSQRT, Rsqrt)";
121 case circle::BuiltinOperator_SHAPE:
122 return "REGISTER_KERNEL(SHAPE, Shape)";
123 case circle::BuiltinOperator_SOFTMAX:
124 return "REGISTER_KERNEL(SOFTMAX, Softmax)";
125 case circle::BuiltinOperator_SPACE_TO_BATCH_ND:
126 return "REGISTER_KERNEL(SPACE_TO_BATCH_ND, SpaceToBatchND)";
127 case circle::BuiltinOperator_SPACE_TO_DEPTH:
128 return "REGISTER_KERNEL(SPACE_TO_DEPTH, SpaceToDepth)";
129 case circle::BuiltinOperator_STRIDED_SLICE:
130 return "REGISTER_KERNEL(STRIDED_SLICE, StridedSlice)";
131 case circle::BuiltinOperator_SQRT:
132 return "REGISTER_KERNEL(SQRT, Sqrt)";
133 case circle::BuiltinOperator_SQUARE:
134 return "REGISTER_KERNEL(SQUARE, Square)";
135 case circle::BuiltinOperator_SQUARED_DIFFERENCE:
136 return "REGISTER_KERNEL(SQUARED_DIFFERENCE, SquaredDifference)";
137 case circle::BuiltinOperator_SQUEEZE:
138 return "REGISTER_KERNEL(SQUEEZE, Squeeze)";
139 case circle::BuiltinOperator_SUB:
140 return "REGISTER_KERNEL(SUB, Sub)";
141 case circle::BuiltinOperator_SVDF:
142 return "REGISTER_KERNEL(SVDF, SVDF)";
143 case circle::BuiltinOperator_TANH:
144 return "REGISTER_KERNEL(TANH, Tanh)";
145 case circle::BuiltinOperator_TRANSPOSE:
146 return "REGISTER_KERNEL(TRANSPOSE, Transpose)";
147 case circle::BuiltinOperator_TRANSPOSE_CONV:
148 return "REGISTER_KERNEL(TRANSPOSE_CONV, TransposeConv)";
149 case circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
150 return "REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)";
152 assert(false && "Not supported kernel");
/**
 * @brief Reads the whole file at @p path into memory as raw bytes.
 *
 * @param path  Path of the file to read; it is opened in binary mode.
 * @return The file contents (an empty vector for an empty file).
 * @throws std::runtime_error if the file cannot be opened, its size cannot be
 *         determined, or the full contents cannot be read. Exceptions are used
 *         instead of assert() so failures are still detected under NDEBUG.
 */
std::vector<char> loadFile(const std::string &path)
{
  std::ifstream file(path, std::ios::binary | std::ios::in);
  if (!file.is_open())
  {
    throw std::runtime_error("Failed to open file: " + path);
  }

  // Measure the file by seeking to its end. tellg() reports -1 on failure,
  // which must never be converted into a vector size.
  file.seekg(0, std::ios::end);
  const std::streamoff file_size = file.tellg();
  file.seekg(0, std::ios::beg);
  if (!file || file_size < 0)
  {
    throw std::runtime_error("Failed to read file: " + path);
  }

  std::vector<char> data(static_cast<std::vector<char>::size_type>(file_size));

  // read the data (nothing to do for an empty file)
  if (!data.empty() && !file.read(data.data(), static_cast<std::streamsize>(file_size)))
  {
    throw std::runtime_error("Failed to read file: " + path);
  }

  return data;
}
183 // Parse model and write to std::ofstream &os models operations
184 void run(std::ofstream &os, const circle::Model *model)
186 luci_interpreter::CircleReader reader;
188 const uint32_t subgraph_size = reader.num_subgraph();
190 // Set to avoid duplication in generated list
191 std::set<circle::BuiltinOperator> operations_set;
193 for (uint32_t g = 0; g < subgraph_size; g++)
195 reader.select_subgraph(g);
196 auto ops = reader.operators();
197 for (uint32_t i = 0; i < ops.size(); ++i)
199 const auto op = ops.at(i);
200 auto op_builtin_operator = reader.builtin_code(op);
202 auto result = operations_set.insert(op_builtin_operator);
205 os << get_register_kernel_str(op_builtin_operator) << std::endl;
211 int main(int argc, char **argv)
215 assert(false && "Should be 2 arguments: circle model path, and path for generated model\n");
218 std::string model_file(argv[1]);
219 std::string generated_file_path(argv[2]);
221 std::vector<char> model_data = loadFile(model_file);
222 const circle::Model *circle_model = circle::GetModel(model_data.data());
224 if (circle_model == nullptr)
226 std::cerr << "ERROR: Failed to load circle '" << model_file << "'" << std::endl;
230 // Open or create file
232 out.open(generated_file_path);
235 run(out, circle_model);
237 std::cout << "SMTH GOES WRONG WHILE OPEN FILE" << std::endl;