Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / CircleReader.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/CircleReader.h"
18
19 #include <memory>
20 #include <sstream>
21 #include <string>
22
23 namespace luci
24 {
25
26 bool is_valid(const circle::OperatorCodeT &opcode)
27 {
28   circle::BuiltinOperator code = opcode.builtin_code;
29   return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
30 }
31
32 bool is_valid(const circle::OperatorCode *opcode)
33 {
34   assert(opcode != nullptr);
35   circle::BuiltinOperator code = opcode->builtin_code();
36   return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
37 }
38
39 bool is_custom(const circle::OperatorCodeT &opcode)
40 {
41   circle::BuiltinOperator code = opcode.builtin_code;
42   return (code == circle::BuiltinOperator_CUSTOM);
43 }
44
45 bool is_custom(const circle::OperatorCode *opcode)
46 {
47   assert(opcode != nullptr);
48   circle::BuiltinOperator code = opcode->builtin_code();
49   return (code == circle::BuiltinOperator_CUSTOM);
50 }
51
52 std::string opcode_name(const circle::OperatorCodeT &opcode)
53 {
54   if (!is_valid(opcode))
55   {
56     std::ostringstream oss;
57     oss << "(invalid)";
58     return oss.str();
59   }
60
61   if (is_custom(opcode))
62   {
63     if (opcode.custom_code.empty())
64       return "(invalid custom)";
65
66     return opcode.custom_code;
67   }
68
69   circle::BuiltinOperator code = opcode.builtin_code;
70   return circle::EnumNameBuiltinOperator(code);
71 }
72
73 std::string opcode_name(const circle::OperatorCode *opcode)
74 {
75   assert(opcode != nullptr);
76
77   if (!is_valid(opcode))
78   {
79     std::ostringstream oss;
80     oss << "(invalid)";
81     return oss.str();
82   }
83
84   if (is_custom(opcode))
85   {
86     auto custom_code = opcode->custom_code()->str();
87     if (custom_code.empty())
88       return "(invalid custom)";
89
90     return custom_code;
91   }
92
93   circle::BuiltinOperator code = opcode->builtin_code();
94   return circle::EnumNameBuiltinOperator(code);
95 }
96
97 const char *tensor_name(const circle::TensorT &tensor)
98 {
99   static const char *kEmptyTensorName = "(noname)";
100
101   if (!tensor.name.empty())
102     return tensor.name.c_str();
103
104   return kEmptyTensorName;
105 }
106
107 const char *tensor_name(const circle::Tensor *tensor)
108 {
109   assert(tensor != nullptr);
110
111   static const char *kEmptyTensorName = "(noname)";
112   const auto tensor_name = tensor->name()->c_str();
113
114   if (!std::string(tensor_name).empty())
115     return tensor_name;
116
117   return kEmptyTensorName;
118 }
119
120 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
121 {
122   return tensor.quantization.get();
123 }
124
125 const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
126 {
127   assert(tensor != nullptr);
128   return tensor->quantization();
129 }
130
131 loco::DataType luci_datatype(const circle::TensorType type)
132 {
133   switch (type)
134   {
135     case circle::TensorType_FLOAT32:
136       return loco::DataType::FLOAT32;
137     case circle::TensorType_FLOAT16:
138       return loco::DataType::FLOAT16;
139     case circle::TensorType_INT32:
140       return loco::DataType::S32;
141     case circle::TensorType_UINT8:
142       return loco::DataType::U8;
143     case circle::TensorType_INT64:
144       return loco::DataType::S64;
145     case circle::TensorType_STRING:
146       return loco::DataType::STRING;
147     case circle::TensorType_BOOL:
148       return loco::DataType::BOOL;
149     case circle::TensorType_INT16:
150       return loco::DataType::S16;
151     case circle::TensorType_COMPLEX64:
152       break;
153     case circle::TensorType_INT8:
154       return loco::DataType::S8;
155     default:
156       break;
157   }
158   assert(false);
159   return loco::DataType::Unknown;
160 }
161
162 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
163 {
164   switch (type)
165   {
166     case circle::ActivationFunctionType::ActivationFunctionType_NONE:
167       return luci::FusedActFunc::NONE;
168     case circle::ActivationFunctionType::ActivationFunctionType_RELU:
169       return luci::FusedActFunc::RELU;
170     case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
171       return luci::FusedActFunc::RELU_N1_TO_1;
172     case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
173       return luci::FusedActFunc::RELU6;
174     case circle::ActivationFunctionType::ActivationFunctionType_TANH:
175       return luci::FusedActFunc::TANH;
176     case circle::ActivationFunctionType::ActivationFunctionType_SIGN_BIT:
177       return luci::FusedActFunc::SIGN_BIT;
178     default:
179       break;
180   }
181   assert(false);
182   return luci::FusedActFunc::UNDEFINED;
183 }
184
185 Padding luci_padding(const circle::Padding padding)
186 {
187   switch (padding)
188   {
189     case circle::Padding::Padding_SAME:
190       return Padding::SAME;
191     case circle::Padding::Padding_VALID:
192       return Padding::VALID;
193   }
194   assert(false);
195   return Padding::UNDEFINED;
196 }
197
198 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
199 {
200   switch (mode)
201   {
202     case circle::MirrorPadMode::MirrorPadMode_REFLECT:
203       return MirrorPadMode::REFLECT;
204     case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
205       return MirrorPadMode::SYMMETRIC;
206   }
207   assert(false);
208   return MirrorPadMode::UNDEFINED;
209 }
210
211 luci::CircleFullyConnected::WeightsFormat
212 luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format)
213 {
214   switch (weights_format)
215   {
216     case circle::FullyConnectedOptionsWeightsFormat_DEFAULT:
217       return luci::CircleFullyConnected::WeightsFormat::DEFAULT;
218     case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
219       return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8;
220     case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32:
221       return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32;
222     default:
223       throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat");
224   }
225 }
226
227 DimensionType luci_dim_type(const circle::DimensionType dim_type)
228 {
229   switch (dim_type)
230   {
231     case circle::DimensionType_DENSE:
232       return DimensionType::DENSE;
233     case circle::DimensionType_SPARSE_CSR:
234       return DimensionType::SPARSE_CSR;
235     default:
236       throw std::runtime_error("Invalid DimensionType");
237   }
238 }
239
240 SparseIndexVector
241 luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vector)
242 {
243   switch (sparse_index_vector.type)
244   {
245     case circle::SparseIndexVector_NONE:
246       return SparseIndexVector{SparseIndexVectorType::NONE, nullptr};
247     case circle::SparseIndexVector_Int32Vector:
248     {
249       const auto const_vec_ptr =
250         static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
251       return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr};
252     }
253     case circle::SparseIndexVector_Uint16Vector:
254     {
255       const auto const_vec_ptr =
256         static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
257       return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr};
258     }
259     case circle::SparseIndexVector_Uint8Vector:
260     {
261       const auto const_vec_ptr =
262         static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
263       return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr};
264     }
265     default:
266       throw std::runtime_error("Invalid SparseIndexVector type");
267   }
268 }
269
270 std::unique_ptr<CircleQuantParam>
271 luci_quantparam(const circle::QuantizationParametersT *quantization)
272 {
273   const auto &min = quantization->min;
274   const auto &max = quantization->max;
275   const auto &scale = quantization->scale;
276   const auto &zero_point = quantization->zero_point;
277   const auto &quantized_dimension = quantization->quantized_dimension;
278
279   if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
280   {
281     auto quantparam = std::make_unique<CircleQuantParam>();
282
283     quantparam->min = min;
284     quantparam->max = max;
285     quantparam->scale = scale;
286     quantparam->zerop = zero_point;
287     quantparam->quantized_dimension = quantized_dimension;
288
289     return quantparam;
290   }
291
292   return nullptr;
293 }
294
295 std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParameters *qparams)
296 {
297   // create temporary unpacked API object
298   assert(qparams != nullptr);
299   circle::QuantizationParametersT quantization;
300   qparams->UnPackTo(&quantization);
301
302   return luci_quantparam(&quantization);
303 }
304
305 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
306 {
307   assert(sparsity);
308   const auto &traversal_order = sparsity->traversal_order;
309   const auto &block_map = sparsity->block_map;
310   const auto &dim_metadata = sparsity->dim_metadata;
311
312   // TODO find a condition that should return nullptr
313   auto sparsityparam = std::make_unique<SparsityParam>();
314
315   sparsityparam->traversal_order = traversal_order;
316   sparsityparam->block_map = block_map;
317   for (const auto &dm : dim_metadata)
318   {
319     sparsityparam->dim_metadata.emplace_back(luci_dim_type(dm->format), dm->dense_size,
320                                              luci_sparse_index_vector(dm->array_segments),
321                                              luci_sparse_index_vector(dm->array_indices));
322   }
323
324   return sparsityparam;
325 }
326
327 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParameters *sparparam)
328 {
329   // create temporary unpacked API object
330   assert(sparparam != nullptr);
331   circle::SparsityParametersT sparsity;
332   sparparam->UnPackTo(&sparsity);
333
334   return luci_sparsityparam(&sparsity);
335 }
336
337 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
338 {
339   node->name(tensor_name(tensor));
340   node->dtype(luci_datatype(tensor.type));
341
342   assert(tensor.shape_signature.size() == 0 ||
343          tensor.shape_signature.size() == tensor.shape.size());
344
345   std::vector<int32_t> dims = tensor.shape; // in NHWC
346   node->rank(dims.size());
347   for (uint32_t r = 0; r < dims.size(); ++r)
348   {
349     if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
350       node->dim(r).unset();
351     else
352       node->dim(r).set(dims[r]);
353   }
354
355   const auto *quantization = tensor.quantization.get();
356   if (quantization != nullptr)
357   {
358     auto quantparam = luci_quantparam(quantization);
359     if (quantparam)
360       node->quantparam(std::move(quantparam));
361   }
362
363   const auto *sparsity = tensor.sparsity.get();
364   if (sparsity != nullptr)
365   {
366     auto sparsityparam = luci_sparsityparam(sparsity);
367     if (sparsityparam)
368       node->sparsityparam(std::move(sparsityparam));
369   }
370 }
371
372 void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
373 {
374   assert(tensor != nullptr);
375
376   node->name(tensor_name(tensor));
377   node->dtype(luci_datatype(tensor->type()));
378
379   const auto tensor_shape_signature = wrap(tensor->shape_signature());
380   const auto tensor_shape = wrap(tensor->shape());
381   assert(tensor_shape_signature.size() == 0 ||
382          tensor_shape_signature.size() == tensor_shape.size());
383
384   const auto dims = tensor_shape; // in NHWC
385   node->rank(dims.size());
386   for (uint32_t r = 0; r < dims.size(); ++r)
387   {
388     if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
389       node->dim(r).unset();
390     else
391       node->dim(r).set(dims[r]);
392   }
393
394   const auto quantization = tensor->quantization();
395   if (quantization != nullptr)
396   {
397     auto quantparam = luci_quantparam(quantization);
398     if (quantparam)
399       node->quantparam(std::move(quantparam));
400   }
401
402   const auto sparsity = tensor->sparsity();
403   if (sparsity != nullptr)
404   {
405     auto sparsityparam = luci_sparsityparam(sparsity);
406     if (sparsityparam)
407       node->sparsityparam(std::move(sparsityparam));
408   }
409 }
410
411 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
412 {
413   const auto &op_codes = opcodes();
414   uint32_t index = op.opcode_index;
415   assert(index < op_codes.size());
416   const circle::OperatorCodeT &opcode = *op_codes[index];
417
418   return opcode.builtin_code;
419 }
420
421 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
422 {
423   const auto &op_codes = opcodes();
424   uint32_t index = op.opcode_index;
425   assert(index < op_codes.size());
426   const circle::OperatorCodeT &opcode = *op_codes[index];
427
428   if (!is_valid(opcode))
429   {
430     std::ostringstream oss;
431     oss << "(invalid: " << index << ")";
432     return oss.str();
433   }
434
435   return ::luci::opcode_name(opcode);
436 }
437
438 bool CircleReader::parse(const circle::Model *model)
439 {
440   assert(model != nullptr);
441
442   _model.reset(model->UnPack());
443
444   // for direct pointer access
445   _native_model = model;
446
447   return true;
448 }
449
450 bool CircleReader::select_subgraph(uint32_t sgindex)
451 {
452   if (_model->subgraphs.size() <= sgindex)
453   {
454     assert(false);
455     return false;
456   }
457
458   _current_subgraph = _model->subgraphs[sgindex].get();
459
460   // for direct pointer access
461   auto subgraphs = _native_model->subgraphs();
462   assert(subgraphs != nullptr);
463
464   _native_subgraph = subgraphs->Get(sgindex);
465   assert(_native_subgraph != nullptr);
466
467   _tensors_ptr = _native_subgraph->tensors();
468
469   return true;
470 }
471
472 template <typename T>
473 VectorWrapper<T>::VectorWrapper(const flatbuffers::Vector<T> *ptr) : _vector(ptr)
474 {
475   // Do nothing
476 }
477
478 template <typename T> uint32_t VectorWrapper<T>::size() const
479 {
480   return null() ? 0 : _vector->size();
481 }
482
483 template <typename T> const T *VectorWrapper<T>::data() const
484 {
485   return null() ? nullptr : _vector->data();
486 }
487
488 template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::begin() const
489 {
490   return null() ? iterator(nullptr, 0) : _vector->begin();
491 }
492
493 template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::end() const
494 {
495   return null() ? begin() : _vector->end();
496 }
497
498 template <typename T> typename VectorWrapper<T>::value_type VectorWrapper<T>::at(uint32_t i) const
499 {
500   if (i >= size())
501   {
502     // TODO find better error message
503     throw std::range_error("Access to prohibited vector element");
504   }
505
506   return _vector->Get(i);
507 }
508
509 template <typename T>
510 typename VectorWrapper<T>::value_type VectorWrapper<T>::operator[](uint32_t i) const
511 {
512   return at(i);
513 }
514
515 template <typename T> bool VectorWrapper<T>::null() const { return _vector == nullptr; }
516 template <typename T> bool VectorWrapper<T>::empty() const { return size() == 0; }
517
518 #define REGISTER_WRAPPER(T) template class VectorWrapper<T>
519 REGISTER_WRAPPER(flatbuffers::Offset<circle::SubGraph>);
520 REGISTER_WRAPPER(flatbuffers::Offset<circle::Buffer>);
521 REGISTER_WRAPPER(flatbuffers::Offset<circle::Tensor>);
522 REGISTER_WRAPPER(flatbuffers::Offset<circle::Operator>);
523 REGISTER_WRAPPER(flatbuffers::Offset<circle::OperatorCode>);
524 REGISTER_WRAPPER(flatbuffers::Offset<circle::Metadata>);
525 REGISTER_WRAPPER(int32_t);
526 REGISTER_WRAPPER(uint8_t);
527 #undef REGISTER_WRAPPER
528
529 } // namespace luci