2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Import/CircleReader.h"
26 bool is_valid(const circle::OperatorCodeT &opcode)
28 circle::BuiltinOperator code = opcode.builtin_code;
29 return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
32 bool is_valid(const circle::OperatorCode *opcode)
34 assert(opcode != nullptr);
35 circle::BuiltinOperator code = opcode->builtin_code();
36 return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
39 bool is_custom(const circle::OperatorCodeT &opcode)
41 circle::BuiltinOperator code = opcode.builtin_code;
42 return (code == circle::BuiltinOperator_CUSTOM);
45 bool is_custom(const circle::OperatorCode *opcode)
47 assert(opcode != nullptr);
48 circle::BuiltinOperator code = opcode->builtin_code();
49 return (code == circle::BuiltinOperator_CUSTOM);
52 std::string opcode_name(const circle::OperatorCodeT &opcode)
54 if (!is_valid(opcode))
56 std::ostringstream oss;
61 if (is_custom(opcode))
63 if (opcode.custom_code.empty())
64 return "(invalid custom)";
66 return opcode.custom_code;
69 circle::BuiltinOperator code = opcode.builtin_code;
70 return circle::EnumNameBuiltinOperator(code);
73 std::string opcode_name(const circle::OperatorCode *opcode)
75 assert(opcode != nullptr);
77 if (!is_valid(opcode))
79 std::ostringstream oss;
84 if (is_custom(opcode))
86 auto custom_code = opcode->custom_code()->str();
87 if (custom_code.empty())
88 return "(invalid custom)";
93 circle::BuiltinOperator code = opcode->builtin_code();
94 return circle::EnumNameBuiltinOperator(code);
97 const char *tensor_name(const circle::TensorT &tensor)
99 static const char *kEmptyTensorName = "(noname)";
101 if (!tensor.name.empty())
102 return tensor.name.c_str();
104 return kEmptyTensorName;
107 const char *tensor_name(const circle::Tensor *tensor)
109 assert(tensor != nullptr);
111 static const char *kEmptyTensorName = "(noname)";
112 const auto tensor_name = tensor->name()->c_str();
114 if (!std::string(tensor_name).empty())
117 return kEmptyTensorName;
120 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
122 return tensor.quantization.get();
125 const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
127 assert(tensor != nullptr);
128 return tensor->quantization();
131 loco::DataType luci_datatype(const circle::TensorType type)
135 case circle::TensorType_FLOAT32:
136 return loco::DataType::FLOAT32;
137 case circle::TensorType_FLOAT16:
138 return loco::DataType::FLOAT16;
139 case circle::TensorType_INT32:
140 return loco::DataType::S32;
141 case circle::TensorType_UINT8:
142 return loco::DataType::U8;
143 case circle::TensorType_INT64:
144 return loco::DataType::S64;
145 case circle::TensorType_STRING:
146 return loco::DataType::STRING;
147 case circle::TensorType_BOOL:
148 return loco::DataType::BOOL;
149 case circle::TensorType_INT16:
150 return loco::DataType::S16;
151 case circle::TensorType_COMPLEX64:
153 case circle::TensorType_INT8:
154 return loco::DataType::S8;
159 return loco::DataType::Unknown;
162 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
166 case circle::ActivationFunctionType::ActivationFunctionType_NONE:
167 return luci::FusedActFunc::NONE;
168 case circle::ActivationFunctionType::ActivationFunctionType_RELU:
169 return luci::FusedActFunc::RELU;
170 case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
171 return luci::FusedActFunc::RELU_N1_TO_1;
172 case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
173 return luci::FusedActFunc::RELU6;
174 case circle::ActivationFunctionType::ActivationFunctionType_TANH:
175 return luci::FusedActFunc::TANH;
176 case circle::ActivationFunctionType::ActivationFunctionType_SIGN_BIT:
177 return luci::FusedActFunc::SIGN_BIT;
182 return luci::FusedActFunc::UNDEFINED;
185 Padding luci_padding(const circle::Padding padding)
189 case circle::Padding::Padding_SAME:
190 return Padding::SAME;
191 case circle::Padding::Padding_VALID:
192 return Padding::VALID;
195 return Padding::UNDEFINED;
198 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
202 case circle::MirrorPadMode::MirrorPadMode_REFLECT:
203 return MirrorPadMode::REFLECT;
204 case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
205 return MirrorPadMode::SYMMETRIC;
208 return MirrorPadMode::UNDEFINED;
211 luci::CircleFullyConnected::WeightsFormat
212 luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format)
214 switch (weights_format)
216 case circle::FullyConnectedOptionsWeightsFormat_DEFAULT:
217 return luci::CircleFullyConnected::WeightsFormat::DEFAULT;
218 case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
219 return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8;
220 case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32:
221 return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32;
223 throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat");
227 DimensionType luci_dim_type(const circle::DimensionType dim_type)
231 case circle::DimensionType_DENSE:
232 return DimensionType::DENSE;
233 case circle::DimensionType_SPARSE_CSR:
234 return DimensionType::SPARSE_CSR;
236 throw std::runtime_error("Invalid DimensionType");
241 luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vector)
243 switch (sparse_index_vector.type)
245 case circle::SparseIndexVector_NONE:
246 return SparseIndexVector{SparseIndexVectorType::NONE, nullptr};
247 case circle::SparseIndexVector_Int32Vector:
249 const auto const_vec_ptr =
250 static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
251 return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr};
253 case circle::SparseIndexVector_Uint16Vector:
255 const auto const_vec_ptr =
256 static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
257 return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr};
259 case circle::SparseIndexVector_Uint8Vector:
261 const auto const_vec_ptr =
262 static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
263 return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr};
266 throw std::runtime_error("Invalid SparseIndexVector type");
270 std::unique_ptr<CircleQuantParam>
271 luci_quantparam(const circle::QuantizationParametersT *quantization)
273 const auto &min = quantization->min;
274 const auto &max = quantization->max;
275 const auto &scale = quantization->scale;
276 const auto &zero_point = quantization->zero_point;
277 const auto &quantized_dimension = quantization->quantized_dimension;
279 if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
281 auto quantparam = std::make_unique<CircleQuantParam>();
283 quantparam->min = min;
284 quantparam->max = max;
285 quantparam->scale = scale;
286 quantparam->zerop = zero_point;
287 quantparam->quantized_dimension = quantized_dimension;
295 std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParameters *qparams)
297 // create temporary unpacked API object
298 assert(qparams != nullptr);
299 circle::QuantizationParametersT quantization;
300 qparams->UnPackTo(&quantization);
302 return luci_quantparam(&quantization);
305 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
308 const auto &traversal_order = sparsity->traversal_order;
309 const auto &block_map = sparsity->block_map;
310 const auto &dim_metadata = sparsity->dim_metadata;
312 // TODO find a condition that should return nullptr
313 auto sparsityparam = std::make_unique<SparsityParam>();
315 sparsityparam->traversal_order = traversal_order;
316 sparsityparam->block_map = block_map;
317 for (const auto &dm : dim_metadata)
319 sparsityparam->dim_metadata.emplace_back(luci_dim_type(dm->format), dm->dense_size,
320 luci_sparse_index_vector(dm->array_segments),
321 luci_sparse_index_vector(dm->array_indices));
324 return sparsityparam;
327 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParameters *sparparam)
329 // create temporary unpacked API object
330 assert(sparparam != nullptr);
331 circle::SparsityParametersT sparsity;
332 sparparam->UnPackTo(&sparsity);
334 return luci_sparsityparam(&sparsity);
337 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
339 node->name(tensor_name(tensor));
340 node->dtype(luci_datatype(tensor.type));
342 assert(tensor.shape_signature.size() == 0 ||
343 tensor.shape_signature.size() == tensor.shape.size());
345 std::vector<int32_t> dims = tensor.shape; // in NHWC
346 node->rank(dims.size());
347 for (uint32_t r = 0; r < dims.size(); ++r)
349 if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
350 node->dim(r).unset();
352 node->dim(r).set(dims[r]);
355 const auto *quantization = tensor.quantization.get();
356 if (quantization != nullptr)
358 auto quantparam = luci_quantparam(quantization);
360 node->quantparam(std::move(quantparam));
363 const auto *sparsity = tensor.sparsity.get();
364 if (sparsity != nullptr)
366 auto sparsityparam = luci_sparsityparam(sparsity);
368 node->sparsityparam(std::move(sparsityparam));
372 void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
374 assert(tensor != nullptr);
376 node->name(tensor_name(tensor));
377 node->dtype(luci_datatype(tensor->type()));
379 const auto tensor_shape_signature = wrap(tensor->shape_signature());
380 const auto tensor_shape = wrap(tensor->shape());
381 assert(tensor_shape_signature.size() == 0 ||
382 tensor_shape_signature.size() == tensor_shape.size());
384 const auto dims = tensor_shape; // in NHWC
385 node->rank(dims.size());
386 for (uint32_t r = 0; r < dims.size(); ++r)
388 if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
389 node->dim(r).unset();
391 node->dim(r).set(dims[r]);
394 const auto quantization = tensor->quantization();
395 if (quantization != nullptr)
397 auto quantparam = luci_quantparam(quantization);
399 node->quantparam(std::move(quantparam));
402 const auto sparsity = tensor->sparsity();
403 if (sparsity != nullptr)
405 auto sparsityparam = luci_sparsityparam(sparsity);
407 node->sparsityparam(std::move(sparsityparam));
411 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
413 const auto &op_codes = opcodes();
414 uint32_t index = op.opcode_index;
415 assert(index < op_codes.size());
416 const circle::OperatorCodeT &opcode = *op_codes[index];
418 return opcode.builtin_code;
421 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
423 const auto &op_codes = opcodes();
424 uint32_t index = op.opcode_index;
425 assert(index < op_codes.size());
426 const circle::OperatorCodeT &opcode = *op_codes[index];
428 if (!is_valid(opcode))
430 std::ostringstream oss;
431 oss << "(invalid: " << index << ")";
435 return ::luci::opcode_name(opcode);
438 bool CircleReader::parse(const circle::Model *model)
440 assert(model != nullptr);
442 _model.reset(model->UnPack());
444 // for direct pointer access
445 _native_model = model;
450 bool CircleReader::select_subgraph(uint32_t sgindex)
452 if (_model->subgraphs.size() <= sgindex)
458 _current_subgraph = _model->subgraphs[sgindex].get();
460 // for direct pointer access
461 auto subgraphs = _native_model->subgraphs();
462 assert(subgraphs != nullptr);
464 _native_subgraph = subgraphs->Get(sgindex);
465 assert(_native_subgraph != nullptr);
467 _tensors_ptr = _native_subgraph->tensors();
472 template <typename T>
473 VectorWrapper<T>::VectorWrapper(const flatbuffers::Vector<T> *ptr) : _vector(ptr)
478 template <typename T> uint32_t VectorWrapper<T>::size() const
480 return null() ? 0 : _vector->size();
483 template <typename T> const T *VectorWrapper<T>::data() const
485 return null() ? nullptr : _vector->data();
488 template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::begin() const
490 return null() ? iterator(nullptr, 0) : _vector->begin();
493 template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::end() const
495 return null() ? begin() : _vector->end();
498 template <typename T> typename VectorWrapper<T>::value_type VectorWrapper<T>::at(uint32_t i) const
502 // TODO find better error message
503 throw std::range_error("Access to prohibited vector element");
506 return _vector->Get(i);
509 template <typename T>
510 typename VectorWrapper<T>::value_type VectorWrapper<T>::operator[](uint32_t i) const
515 template <typename T> bool VectorWrapper<T>::null() const { return _vector == nullptr; }
516 template <typename T> bool VectorWrapper<T>::empty() const { return size() == 0; }
518 #define REGISTER_WRAPPER(T) template class VectorWrapper<T>
519 REGISTER_WRAPPER(flatbuffers::Offset<circle::SubGraph>);
520 REGISTER_WRAPPER(flatbuffers::Offset<circle::Buffer>);
521 REGISTER_WRAPPER(flatbuffers::Offset<circle::Tensor>);
522 REGISTER_WRAPPER(flatbuffers::Offset<circle::Operator>);
523 REGISTER_WRAPPER(flatbuffers::Offset<circle::OperatorCode>);
524 REGISTER_WRAPPER(flatbuffers::Offset<circle::Metadata>);
525 REGISTER_WRAPPER(int32_t);
526 REGISTER_WRAPPER(uint8_t);
527 #undef REGISTER_WRAPPER