Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / CircleReader.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/CircleReader.h"
18
19 #include <memory>
20 #include <sstream>
21 #include <string>
22
23 namespace luci
24 {
25
26 bool is_valid(const circle::OperatorCodeT &opcode)
27 {
28   circle::BuiltinOperator code = opcode.builtin_code;
29   return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
30 }
31
32 bool is_custom(const circle::OperatorCodeT &opcode)
33 {
34   circle::BuiltinOperator code = opcode.builtin_code;
35   return (code == circle::BuiltinOperator_CUSTOM);
36 }
37
38 std::string opcode_name(const circle::OperatorCodeT &opcode)
39 {
40   if (!is_valid(opcode))
41   {
42     std::ostringstream oss;
43     oss << "(invalid)";
44     return oss.str();
45   }
46
47   if (is_custom(opcode))
48   {
49     if (opcode.custom_code.empty())
50       return "(invalid custom)";
51
52     return opcode.custom_code;
53   }
54
55   circle::BuiltinOperator code = opcode.builtin_code;
56   return circle::EnumNameBuiltinOperator(code);
57 }
58
59 const char *tensor_name(const circle::TensorT &tensor)
60 {
61   static const char *kEmptyTensorName = "(noname)";
62
63   if (!tensor.name.empty())
64     return tensor.name.c_str();
65
66   return kEmptyTensorName;
67 }
68
69 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
70 {
71   return tensor.quantization.get();
72 }
73
74 loco::DataType luci_datatype(const circle::TensorType type)
75 {
76   switch (type)
77   {
78     case circle::TensorType_FLOAT32:
79       return loco::DataType::FLOAT32;
80     case circle::TensorType_FLOAT16:
81       return loco::DataType::FLOAT16;
82     case circle::TensorType_INT32:
83       return loco::DataType::S32;
84     case circle::TensorType_UINT8:
85       return loco::DataType::U8;
86     case circle::TensorType_INT64:
87       return loco::DataType::S64;
88     case circle::TensorType_STRING:
89       break;
90     case circle::TensorType_BOOL:
91       return loco::DataType::BOOL;
92     case circle::TensorType_INT16:
93       return loco::DataType::S16;
94     case circle::TensorType_COMPLEX64:
95       break;
96     case circle::TensorType_INT8:
97       return loco::DataType::S8;
98     default:
99       break;
100   }
101   assert(false);
102   return loco::DataType::Unknown;
103 }
104
105 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
106 {
107   switch (type)
108   {
109     case circle::ActivationFunctionType::ActivationFunctionType_NONE:
110       return luci::FusedActFunc::NONE;
111     case circle::ActivationFunctionType::ActivationFunctionType_RELU:
112       return luci::FusedActFunc::RELU;
113     case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
114       return luci::FusedActFunc::RELU_N1_TO_1;
115     case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
116       return luci::FusedActFunc::RELU6;
117     case circle::ActivationFunctionType::ActivationFunctionType_TANH:
118       break;
119     default:
120       break;
121   }
122   assert(false);
123   return luci::FusedActFunc::UNDEFINED;
124 }
125
126 Padding luci_padding(const circle::Padding padding)
127 {
128   switch (padding)
129   {
130     case circle::Padding::Padding_SAME:
131       return Padding::SAME;
132     case circle::Padding::Padding_VALID:
133       return Padding::VALID;
134   }
135   assert(false);
136   return Padding::UNDEFINED;
137 }
138
139 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
140 {
141   switch (mode)
142   {
143     case circle::MirrorPadMode::MirrorPadMode_REFLECT:
144       return MirrorPadMode::REFLECT;
145     case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
146       return MirrorPadMode::SYMMETRIC;
147   }
148   assert(false);
149   return MirrorPadMode::UNDEFINED;
150 }
151
152 std::unique_ptr<CircleQuantParam>
153 luci_quantparam(const circle::QuantizationParametersT *quantization)
154 {
155   const auto &min = quantization->min;
156   const auto &max = quantization->max;
157   const auto &scale = quantization->scale;
158   const auto &zero_point = quantization->zero_point;
159
160   if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
161   {
162     auto quantparam = std::make_unique<CircleQuantParam>();
163
164     quantparam->min = min;
165     quantparam->max = max;
166     quantparam->scale = scale;
167     quantparam->zerop = zero_point;
168
169     return quantparam;
170   }
171
172   return nullptr;
173 }
174
175 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
176 {
177   node->name(tensor_name(tensor));
178   node->dtype(luci_datatype(tensor.type));
179
180   std::vector<int32_t> dims = tensor.shape; // in NHWC
181   node->rank(dims.size());
182   for (uint32_t r = 0; r < dims.size(); ++r)
183   {
184     node->dim(r) = loco::Dimension(dims[r]);
185   }
186
187   const auto *quantization = tensor.quantization.get();
188   if (quantization != nullptr)
189   {
190     auto quantparam = luci_quantparam(quantization);
191     if (quantparam)
192       node->quantparam(std::move(quantparam));
193   }
194 }
195
196 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
197 {
198   const auto &op_codes = opcodes();
199   uint32_t index = op.opcode_index;
200   assert(index < op_codes.size());
201   const circle::OperatorCodeT &opcode = *op_codes[index];
202
203   return opcode.builtin_code;
204 }
205
206 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
207 {
208   const auto &op_codes = opcodes();
209   uint32_t index = op.opcode_index;
210   assert(index < op_codes.size());
211   const circle::OperatorCodeT &opcode = *op_codes[index];
212
213   if (!is_valid(opcode))
214   {
215     std::ostringstream oss;
216     oss << "(invalid: " << index << ")";
217     return oss.str();
218   }
219
220   return ::luci::opcode_name(opcode);
221 }
222
223 bool CircleReader::parse(const circle::Model *model)
224 {
225   assert(model != nullptr);
226
227   _model.reset(model->UnPack());
228
229   // for direct pointer access
230   _model_ptr = model;
231
232   return true;
233 }
234
235 bool CircleReader::select_subgraph(uint32_t sgindex)
236 {
237   if (_model->subgraphs.size() <= sgindex)
238   {
239     assert(false);
240     return false;
241   }
242
243   _current_subgraph = _model->subgraphs[sgindex].get();
244
245   // for direct pointer access
246   auto subgraphs = _model_ptr->subgraphs();
247   const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
248
249   _tensors_ptr = subgraph->tensors();
250
251   return true;
252 }
253
254 } // namespace luci