/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "luci/Import/CircleReader.h"
26 bool is_valid(const circle::OperatorCodeT &opcode)
28 circle::BuiltinOperator code = opcode.builtin_code;
29 return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
32 bool is_custom(const circle::OperatorCodeT &opcode)
34 circle::BuiltinOperator code = opcode.builtin_code;
35 return (code == circle::BuiltinOperator_CUSTOM);
38 std::string opcode_name(const circle::OperatorCodeT &opcode)
40 if (!is_valid(opcode))
42 std::ostringstream oss;
47 if (is_custom(opcode))
49 if (opcode.custom_code.empty())
50 return "(invalid custom)";
52 return opcode.custom_code;
55 circle::BuiltinOperator code = opcode.builtin_code;
56 return circle::EnumNameBuiltinOperator(code);
59 const char *tensor_name(const circle::TensorT &tensor)
61 static const char *kEmptyTensorName = "(noname)";
63 if (!tensor.name.empty())
64 return tensor.name.c_str();
66 return kEmptyTensorName;
69 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
71 return tensor.quantization.get();
74 loco::DataType luci_datatype(const circle::TensorType type)
78 case circle::TensorType_FLOAT32:
79 return loco::DataType::FLOAT32;
80 case circle::TensorType_FLOAT16:
81 return loco::DataType::FLOAT16;
82 case circle::TensorType_INT32:
83 return loco::DataType::S32;
84 case circle::TensorType_UINT8:
85 return loco::DataType::U8;
86 case circle::TensorType_INT64:
87 return loco::DataType::S64;
88 case circle::TensorType_STRING:
89 return loco::DataType::STRING;
90 case circle::TensorType_BOOL:
91 return loco::DataType::BOOL;
92 case circle::TensorType_INT16:
93 return loco::DataType::S16;
94 case circle::TensorType_COMPLEX64:
96 case circle::TensorType_INT8:
97 return loco::DataType::S8;
102 return loco::DataType::Unknown;
105 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
109 case circle::ActivationFunctionType::ActivationFunctionType_NONE:
110 return luci::FusedActFunc::NONE;
111 case circle::ActivationFunctionType::ActivationFunctionType_RELU:
112 return luci::FusedActFunc::RELU;
113 case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
114 return luci::FusedActFunc::RELU_N1_TO_1;
115 case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
116 return luci::FusedActFunc::RELU6;
117 case circle::ActivationFunctionType::ActivationFunctionType_TANH:
118 return luci::FusedActFunc::TANH;
119 case circle::ActivationFunctionType::ActivationFunctionType_SIGN_BIT:
120 return luci::FusedActFunc::SIGN_BIT;
125 return luci::FusedActFunc::UNDEFINED;
128 Padding luci_padding(const circle::Padding padding)
132 case circle::Padding::Padding_SAME:
133 return Padding::SAME;
134 case circle::Padding::Padding_VALID:
135 return Padding::VALID;
138 return Padding::UNDEFINED;
141 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
145 case circle::MirrorPadMode::MirrorPadMode_REFLECT:
146 return MirrorPadMode::REFLECT;
147 case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
148 return MirrorPadMode::SYMMETRIC;
151 return MirrorPadMode::UNDEFINED;
154 luci::CircleFullyConnected::WeightsFormat
155 luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format)
157 switch (weights_format)
159 case circle::FullyConnectedOptionsWeightsFormat_DEFAULT:
160 return luci::CircleFullyConnected::WeightsFormat::DEFAULT;
161 case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
162 return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8;
163 case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32:
164 return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32;
166 throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat");
170 DimensionType luci_dim_type(const circle::DimensionType dim_type)
174 case circle::DimensionType_DENSE:
175 return DimensionType::DENSE;
176 case circle::DimensionType_SPARSE_CSR:
177 return DimensionType::SPARSE_CSR;
179 throw std::runtime_error("Invalid DimensionType");
184 luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vector)
186 switch (sparse_index_vector.type)
188 case circle::SparseIndexVector_NONE:
189 return SparseIndexVector{SparseIndexVectorType::NONE, nullptr};
190 case circle::SparseIndexVector_Int32Vector:
192 const auto const_vec_ptr =
193 static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
194 return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr};
196 case circle::SparseIndexVector_Uint16Vector:
198 const auto const_vec_ptr =
199 static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
200 return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr};
202 case circle::SparseIndexVector_Uint8Vector:
204 const auto const_vec_ptr =
205 static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
206 return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr};
209 throw std::runtime_error("Invalid SparseIndexVector type");
213 std::unique_ptr<CircleQuantParam>
214 luci_quantparam(const circle::QuantizationParametersT *quantization)
216 const auto &min = quantization->min;
217 const auto &max = quantization->max;
218 const auto &scale = quantization->scale;
219 const auto &zero_point = quantization->zero_point;
220 const auto &quantized_dimension = quantization->quantized_dimension;
222 if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
224 auto quantparam = std::make_unique<CircleQuantParam>();
226 quantparam->min = min;
227 quantparam->max = max;
228 quantparam->scale = scale;
229 quantparam->zerop = zero_point;
230 quantparam->quantized_dimension = quantized_dimension;
238 std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
241 const auto &traversal_order = sparsity->traversal_order;
242 const auto &block_map = sparsity->block_map;
243 const auto &dim_metadata = sparsity->dim_metadata;
245 // TODO find a condition that should return nullptr
246 auto sparsityparam = std::make_unique<SparsityParam>();
248 sparsityparam->traversal_order = traversal_order;
249 sparsityparam->block_map = block_map;
250 for (const auto &dm : dim_metadata)
252 sparsityparam->dim_metadata.emplace_back(luci_dim_type(dm->format), dm->dense_size,
253 luci_sparse_index_vector(dm->array_segments),
254 luci_sparse_index_vector(dm->array_indices));
257 return sparsityparam;
260 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
262 node->name(tensor_name(tensor));
263 node->dtype(luci_datatype(tensor.type));
265 assert(tensor.shape_signature.size() == 0 ||
266 tensor.shape_signature.size() == tensor.shape.size());
268 std::vector<int32_t> dims = tensor.shape; // in NHWC
269 node->rank(dims.size());
270 for (uint32_t r = 0; r < dims.size(); ++r)
272 if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
273 node->dim(r).unset();
275 node->dim(r).set(dims[r]);
278 const auto *quantization = tensor.quantization.get();
279 if (quantization != nullptr)
281 auto quantparam = luci_quantparam(quantization);
283 node->quantparam(std::move(quantparam));
286 const auto *sparsity = tensor.sparsity.get();
287 if (sparsity != nullptr)
289 auto sparsityparam = luci_sparsityparam(sparsity);
291 node->sparsityparam(std::move(sparsityparam));
295 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
297 const auto &op_codes = opcodes();
298 uint32_t index = op.opcode_index;
299 assert(index < op_codes.size());
300 const circle::OperatorCodeT &opcode = *op_codes[index];
302 return opcode.builtin_code;
305 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
307 const auto &op_codes = opcodes();
308 uint32_t index = op.opcode_index;
309 assert(index < op_codes.size());
310 const circle::OperatorCodeT &opcode = *op_codes[index];
312 if (!is_valid(opcode))
314 std::ostringstream oss;
315 oss << "(invalid: " << index << ")";
319 return ::luci::opcode_name(opcode);
322 bool CircleReader::parse(const circle::Model *model)
324 assert(model != nullptr);
326 _model.reset(model->UnPack());
328 // for direct pointer access
334 bool CircleReader::select_subgraph(uint32_t sgindex)
336 if (_model->subgraphs.size() <= sgindex)
342 _current_subgraph = _model->subgraphs[sgindex].get();
344 // for direct pointer access
345 auto subgraphs = _model_ptr->subgraphs();
346 const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
348 _tensors_ptr = subgraph->tensors();