068de52391325fa42e7006e9f76dd4773599d8e7
[platform/core/ml/nnfw.git] / compiler / luci / import / src / CircleReader.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/CircleReader.h"
18
19 #include <memory>
20 #include <sstream>
21 #include <string>
22
23 namespace luci
24 {
25
26 bool is_valid(const circle::OperatorCodeT &opcode)
27 {
28   circle::BuiltinOperator code = opcode.builtin_code;
29   return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
30 }
31
32 bool is_custom(const circle::OperatorCodeT &opcode)
33 {
34   circle::BuiltinOperator code = opcode.builtin_code;
35   return (code == circle::BuiltinOperator_CUSTOM);
36 }
37
38 std::string opcode_name(const circle::OperatorCodeT &opcode)
39 {
40   if (!is_valid(opcode))
41   {
42     std::ostringstream oss;
43     oss << "(invalid)";
44     return oss.str();
45   }
46
47   if (is_custom(opcode))
48   {
49     if (opcode.custom_code.empty())
50       return "(invalid custom)";
51
52     return opcode.custom_code;
53   }
54
55   circle::BuiltinOperator code = opcode.builtin_code;
56   return circle::EnumNameBuiltinOperator(code);
57 }
58
59 const char *tensor_name(const circle::TensorT &tensor)
60 {
61   static const char *kEmptyTensorName = "(noname)";
62
63   if (!tensor.name.empty())
64     return tensor.name.c_str();
65
66   return kEmptyTensorName;
67 }
68
69 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
70 {
71   return tensor.quantization.get();
72 }
73
74 loco::DataType luci_datatype(const circle::TensorType type)
75 {
76   switch (type)
77   {
78     case circle::TensorType_FLOAT32:
79       return loco::DataType::FLOAT32;
80     case circle::TensorType_FLOAT16:
81       return loco::DataType::FLOAT16;
82     case circle::TensorType_INT32:
83       return loco::DataType::S32;
84     case circle::TensorType_UINT8:
85       return loco::DataType::U8;
86     case circle::TensorType_INT64:
87       return loco::DataType::S64;
88     case circle::TensorType_STRING:
89       break;
90     case circle::TensorType_BOOL:
91       return loco::DataType::BOOL;
92     case circle::TensorType_INT16:
93       return loco::DataType::S16;
94     case circle::TensorType_COMPLEX64:
95       break;
96     case circle::TensorType_INT8:
97       return loco::DataType::S8;
98     default:
99       break;
100   }
101   assert(false);
102   return loco::DataType::Unknown;
103 }
104
105 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
106 {
107   switch (type)
108   {
109     case circle::ActivationFunctionType::ActivationFunctionType_NONE:
110       return luci::FusedActFunc::NONE;
111     case circle::ActivationFunctionType::ActivationFunctionType_RELU:
112       return luci::FusedActFunc::RELU;
113     case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
114       return luci::FusedActFunc::RELU_N1_TO_1;
115     case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
116       return luci::FusedActFunc::RELU6;
117     case circle::ActivationFunctionType::ActivationFunctionType_TANH:
118       return luci::FusedActFunc::TANH;
119     case circle::ActivationFunctionType::ActivationFunctionType_SIGN_BIT:
120       return luci::FusedActFunc::SIGN_BIT;
121     default:
122       break;
123   }
124   assert(false);
125   return luci::FusedActFunc::UNDEFINED;
126 }
127
128 Padding luci_padding(const circle::Padding padding)
129 {
130   switch (padding)
131   {
132     case circle::Padding::Padding_SAME:
133       return Padding::SAME;
134     case circle::Padding::Padding_VALID:
135       return Padding::VALID;
136   }
137   assert(false);
138   return Padding::UNDEFINED;
139 }
140
141 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
142 {
143   switch (mode)
144   {
145     case circle::MirrorPadMode::MirrorPadMode_REFLECT:
146       return MirrorPadMode::REFLECT;
147     case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
148       return MirrorPadMode::SYMMETRIC;
149   }
150   assert(false);
151   return MirrorPadMode::UNDEFINED;
152 }
153
154 DimensionType luci_dim_type(const circle::DimensionType dim_type)
155 {
156   switch (dim_type)
157   {
158     case circle::DimensionType_DENSE:
159       return DimensionType::DENSE;
160     case circle::DimensionType_SPARSE_CSR:
161       return DimensionType::SPARSE_CSR;
162     default:
163       throw std::runtime_error("Invalid DimensionType");
164   }
165 }
166
167 SparseIndexVector
168 luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vector)
169 {
170   switch (sparse_index_vector.type)
171   {
172     case circle::SparseIndexVector_NONE:
173       return SparseIndexVector{SparseIndexVectorType::NONE, nullptr};
174     case circle::SparseIndexVector_Int32Vector:
175     {
176       const auto const_vec_ptr =
177           static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
178       return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr};
179     }
180     case circle::SparseIndexVector_Uint16Vector:
181     {
182       const auto const_vec_ptr =
183           static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
184       return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr};
185     }
186     case circle::SparseIndexVector_Uint8Vector:
187     {
188       const auto const_vec_ptr =
189           static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
190       return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr};
191     }
192     default:
193       throw std::runtime_error("Invalid SparseIndexVector type");
194   }
195 }
196
197 std::unique_ptr<CircleQuantParam>
198 luci_quantparam(const circle::QuantizationParametersT *quantization)
199 {
200   const auto &min = quantization->min;
201   const auto &max = quantization->max;
202   const auto &scale = quantization->scale;
203   const auto &zero_point = quantization->zero_point;
204   const auto &quantized_dimension = quantization->quantized_dimension;
205
206   if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
207   {
208     auto quantparam = std::make_unique<CircleQuantParam>();
209
210     quantparam->min = min;
211     quantparam->max = max;
212     quantparam->scale = scale;
213     quantparam->zerop = zero_point;
214     quantparam->quantized_dimension = quantized_dimension;
215
216     return quantparam;
217   }
218
219   return nullptr;
220 }
221
222 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
223 {
224   assert(sparsity);
225   const auto &traversal_order = sparsity->traversal_order;
226   const auto &block_map = sparsity->block_map;
227   const auto &dim_metadata = sparsity->dim_metadata;
228
229   // TODO find a condition that should return nullptr
230   auto sparsityparam = std::make_unique<SparsityParam>();
231
232   sparsityparam->traversal_order = traversal_order;
233   sparsityparam->block_map = block_map;
234   for (const auto &dm : dim_metadata)
235   {
236     sparsityparam->dim_metadata.emplace_back(luci_dim_type(dm->format), dm->dense_size,
237                                              luci_sparse_index_vector(dm->array_segments),
238                                              luci_sparse_index_vector(dm->array_indices));
239   }
240
241   return sparsityparam;
242 }
243
244 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
245 {
246   node->name(tensor_name(tensor));
247   node->dtype(luci_datatype(tensor.type));
248
249   std::vector<int32_t> dims = tensor.shape; // in NHWC
250   node->rank(dims.size());
251   for (uint32_t r = 0; r < dims.size(); ++r)
252   {
253     node->dim(r) = loco::Dimension(dims[r]);
254   }
255
256   node->shape_signature(tensor.shape_signature);
257
258   const auto *quantization = tensor.quantization.get();
259   if (quantization != nullptr)
260   {
261     auto quantparam = luci_quantparam(quantization);
262     if (quantparam)
263       node->quantparam(std::move(quantparam));
264   }
265
266   const auto *sparsity = tensor.sparsity.get();
267   if (sparsity != nullptr)
268   {
269     auto sparsityparam = luci_sparsityparam(sparsity);
270     if (sparsityparam)
271       node->sparsityparam(std::move(sparsityparam));
272   }
273 }
274
275 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
276 {
277   const auto &op_codes = opcodes();
278   uint32_t index = op.opcode_index;
279   assert(index < op_codes.size());
280   const circle::OperatorCodeT &opcode = *op_codes[index];
281
282   return opcode.builtin_code;
283 }
284
285 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
286 {
287   const auto &op_codes = opcodes();
288   uint32_t index = op.opcode_index;
289   assert(index < op_codes.size());
290   const circle::OperatorCodeT &opcode = *op_codes[index];
291
292   if (!is_valid(opcode))
293   {
294     std::ostringstream oss;
295     oss << "(invalid: " << index << ")";
296     return oss.str();
297   }
298
299   return ::luci::opcode_name(opcode);
300 }
301
302 bool CircleReader::parse(const circle::Model *model)
303 {
304   assert(model != nullptr);
305
306   _model.reset(model->UnPack());
307
308   // for direct pointer access
309   _model_ptr = model;
310
311   return true;
312 }
313
314 bool CircleReader::select_subgraph(uint32_t sgindex)
315 {
316   if (_model->subgraphs.size() <= sgindex)
317   {
318     assert(false);
319     return false;
320   }
321
322   _current_subgraph = _model->subgraphs[sgindex].get();
323
324   // for direct pointer access
325   auto subgraphs = _model_ptr->subgraphs();
326   const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
327
328   _tensors_ptr = subgraph->tensors();
329
330   return true;
331 }
332
333 } // namespace luci