/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "luci/Import/CircleReader.h"
26 bool is_valid(const circle::OperatorCodeT &opcode)
28 circle::BuiltinOperator code = opcode.builtin_code;
29 return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
32 bool is_custom(const circle::OperatorCodeT &opcode)
34 circle::BuiltinOperator code = opcode.builtin_code;
35 return (code == circle::BuiltinOperator_CUSTOM);
38 std::string opcode_name(const circle::OperatorCodeT &opcode)
40 if (!is_valid(opcode))
42 std::ostringstream oss;
47 if (is_custom(opcode))
49 if (opcode.custom_code.empty())
50 return "(invalid custom)";
52 return opcode.custom_code;
55 circle::BuiltinOperator code = opcode.builtin_code;
56 return circle::EnumNameBuiltinOperator(code);
59 const char *tensor_name(const circle::TensorT &tensor)
61 static const char *kEmptyTensorName = "(noname)";
63 if (!tensor.name.empty())
64 return tensor.name.c_str();
66 return kEmptyTensorName;
69 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
71 return tensor.quantization.get();
74 loco::DataType luci_datatype(const circle::TensorType type)
78 case circle::TensorType_FLOAT32:
79 return loco::DataType::FLOAT32;
80 case circle::TensorType_FLOAT16:
81 return loco::DataType::FLOAT16;
82 case circle::TensorType_INT32:
83 return loco::DataType::S32;
84 case circle::TensorType_UINT8:
85 return loco::DataType::U8;
86 case circle::TensorType_INT64:
87 return loco::DataType::S64;
88 case circle::TensorType_STRING:
90 case circle::TensorType_BOOL:
91 return loco::DataType::BOOL;
92 case circle::TensorType_INT16:
93 return loco::DataType::S16;
94 case circle::TensorType_COMPLEX64:
96 case circle::TensorType_INT8:
97 return loco::DataType::S8;
102 return loco::DataType::Unknown;
105 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
109 case circle::ActivationFunctionType::ActivationFunctionType_NONE:
110 return luci::FusedActFunc::NONE;
111 case circle::ActivationFunctionType::ActivationFunctionType_RELU:
112 return luci::FusedActFunc::RELU;
113 case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
114 return luci::FusedActFunc::RELU_N1_TO_1;
115 case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
116 return luci::FusedActFunc::RELU6;
117 case circle::ActivationFunctionType::ActivationFunctionType_TANH:
123 return luci::FusedActFunc::UNDEFINED;
126 Padding luci_padding(const circle::Padding padding)
130 case circle::Padding::Padding_SAME:
131 return Padding::SAME;
132 case circle::Padding::Padding_VALID:
133 return Padding::VALID;
136 return Padding::UNDEFINED;
139 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
143 case circle::MirrorPadMode::MirrorPadMode_REFLECT:
144 return MirrorPadMode::REFLECT;
145 case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
146 return MirrorPadMode::SYMMETRIC;
149 return MirrorPadMode::UNDEFINED;
152 std::unique_ptr<CircleQuantParam>
153 luci_quantparam(const circle::QuantizationParametersT *quantization)
155 const auto &min = quantization->min;
156 const auto &max = quantization->max;
157 const auto &scale = quantization->scale;
158 const auto &zero_point = quantization->zero_point;
160 if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
162 auto quantparam = std::make_unique<CircleQuantParam>();
164 quantparam->min = min;
165 quantparam->max = max;
166 quantparam->scale = scale;
167 quantparam->zerop = zero_point;
175 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
177 node->name(tensor_name(tensor));
178 node->dtype(luci_datatype(tensor.type));
180 std::vector<int32_t> dims = tensor.shape; // in NHWC
181 node->rank(dims.size());
182 for (uint32_t r = 0; r < dims.size(); ++r)
184 node->dim(r) = loco::Dimension(dims[r]);
187 const auto *quantization = tensor.quantization.get();
188 if (quantization != nullptr)
190 auto quantparam = luci_quantparam(quantization);
192 node->quantparam(std::move(quantparam));
196 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
198 const auto &op_codes = opcodes();
199 uint32_t index = op.opcode_index;
200 assert(index < op_codes.size());
201 const circle::OperatorCodeT &opcode = *op_codes[index];
203 return opcode.builtin_code;
206 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
208 const auto &op_codes = opcodes();
209 uint32_t index = op.opcode_index;
210 assert(index < op_codes.size());
211 const circle::OperatorCodeT &opcode = *op_codes[index];
213 if (!is_valid(opcode))
215 std::ostringstream oss;
216 oss << "(invalid: " << index << ")";
220 return ::luci::opcode_name(opcode);
223 bool CircleReader::parse(const circle::Model *model)
225 assert(model != nullptr);
227 _model.reset(model->UnPack());
229 // for direct pointer access
235 bool CircleReader::select_subgraph(uint32_t sgindex)
237 if (_model->subgraphs.size() <= sgindex)
243 _current_subgraph = _model->subgraphs[sgindex].get();
245 // for direct pointer access
246 auto subgraphs = _model_ptr->subgraphs();
247 const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
249 _tensors_ptr = subgraph->tensors();