Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / CircleReader.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/CircleReader.h"
18
19 #include <memory>
20 #include <sstream>
21 #include <string>
22
23 namespace luci
24 {
25
26 bool is_valid(const circle::OperatorCodeT &opcode)
27 {
28   circle::BuiltinOperator code = opcode.builtin_code;
29   return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
30 }
31
32 bool is_custom(const circle::OperatorCodeT &opcode)
33 {
34   circle::BuiltinOperator code = opcode.builtin_code;
35   return (code == circle::BuiltinOperator_CUSTOM);
36 }
37
38 std::string opcode_name(const circle::OperatorCodeT &opcode)
39 {
40   if (!is_valid(opcode))
41   {
42     std::ostringstream oss;
43     oss << "(invalid)";
44     return oss.str();
45   }
46
47   if (is_custom(opcode))
48   {
49     if (opcode.custom_code.empty())
50       return "(invalid custom)";
51
52     return opcode.custom_code;
53   }
54
55   circle::BuiltinOperator code = opcode.builtin_code;
56   return circle::EnumNameBuiltinOperator(code);
57 }
58
59 const char *tensor_name(const circle::TensorT &tensor)
60 {
61   static const char *kEmptyTensorName = "(noname)";
62
63   if (!tensor.name.empty())
64     return tensor.name.c_str();
65
66   return kEmptyTensorName;
67 }
68
69 const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
70 {
71   return tensor.quantization.get();
72 }
73
74 loco::DataType luci_datatype(const circle::TensorType type)
75 {
76   switch (type)
77   {
78     case circle::TensorType_FLOAT32:
79       return loco::DataType::FLOAT32;
80     case circle::TensorType_FLOAT16:
81       return loco::DataType::FLOAT16;
82     case circle::TensorType_INT32:
83       return loco::DataType::S32;
84     case circle::TensorType_UINT8:
85       return loco::DataType::U8;
86     case circle::TensorType_INT64:
87       return loco::DataType::S64;
88     case circle::TensorType_STRING:
89       break;
90     case circle::TensorType_BOOL:
91       return loco::DataType::BOOL;
92     case circle::TensorType_INT16:
93       return loco::DataType::S16;
94     case circle::TensorType_COMPLEX64:
95       break;
96     case circle::TensorType_INT8:
97       return loco::DataType::S8;
98     default:
99       break;
100   }
101   assert(false);
102   return loco::DataType::Unknown;
103 }
104
105 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
106 {
107   switch (type)
108   {
109     case circle::ActivationFunctionType::ActivationFunctionType_NONE:
110       return luci::FusedActFunc::NONE;
111     case circle::ActivationFunctionType::ActivationFunctionType_RELU:
112       return luci::FusedActFunc::RELU;
113     case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
114       return luci::FusedActFunc::RELU_N1_TO_1;
115     case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
116       return luci::FusedActFunc::RELU6;
117     case circle::ActivationFunctionType::ActivationFunctionType_TANH:
118       break;
119     default:
120       break;
121   }
122   assert(false);
123   return luci::FusedActFunc::UNDEFINED;
124 }
125
126 Padding luci_padding(const circle::Padding padding)
127 {
128   switch (padding)
129   {
130     case circle::Padding::Padding_SAME:
131       return Padding::SAME;
132     case circle::Padding::Padding_VALID:
133       return Padding::VALID;
134   }
135   assert(false);
136   return Padding::UNDEFINED;
137 }
138
139 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
140 {
141   switch (mode)
142   {
143     case circle::MirrorPadMode::MirrorPadMode_REFLECT:
144       return MirrorPadMode::REFLECT;
145     case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
146       return MirrorPadMode::SYMMETRIC;
147   }
148   assert(false);
149   return MirrorPadMode::UNDEFINED;
150 }
151
152 std::unique_ptr<CircleQuantParam>
153 luci_quantparam(const circle::QuantizationParametersT *quantization)
154 {
155   const auto &min = quantization->min;
156   const auto &max = quantization->max;
157   const auto &scale = quantization->scale;
158   const auto &zero_point = quantization->zero_point;
159   const auto &quantized_dimension = quantization->quantized_dimension;
160
161   if ((!min.empty() && !max.empty()) || (!scale.empty() && !zero_point.empty()))
162   {
163     auto quantparam = std::make_unique<CircleQuantParam>();
164
165     quantparam->min = min;
166     quantparam->max = max;
167     quantparam->scale = scale;
168     quantparam->zerop = zero_point;
169     quantparam->quantized_dimension = quantized_dimension;
170
171     return quantparam;
172   }
173
174   return nullptr;
175 }
176
177 void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
178 {
179   node->name(tensor_name(tensor));
180   node->dtype(luci_datatype(tensor.type));
181
182   std::vector<int32_t> dims = tensor.shape; // in NHWC
183   node->rank(dims.size());
184   for (uint32_t r = 0; r < dims.size(); ++r)
185   {
186     node->dim(r) = loco::Dimension(dims[r]);
187   }
188
189   const auto *quantization = tensor.quantization.get();
190   if (quantization != nullptr)
191   {
192     auto quantparam = luci_quantparam(quantization);
193     if (quantparam)
194       node->quantparam(std::move(quantparam));
195   }
196 }
197
198 circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
199 {
200   const auto &op_codes = opcodes();
201   uint32_t index = op.opcode_index;
202   assert(index < op_codes.size());
203   const circle::OperatorCodeT &opcode = *op_codes[index];
204
205   return opcode.builtin_code;
206 }
207
208 std::string CircleReader::opcode_name(const circle::OperatorT &op) const
209 {
210   const auto &op_codes = opcodes();
211   uint32_t index = op.opcode_index;
212   assert(index < op_codes.size());
213   const circle::OperatorCodeT &opcode = *op_codes[index];
214
215   if (!is_valid(opcode))
216   {
217     std::ostringstream oss;
218     oss << "(invalid: " << index << ")";
219     return oss.str();
220   }
221
222   return ::luci::opcode_name(opcode);
223 }
224
225 bool CircleReader::parse(const circle::Model *model)
226 {
227   assert(model != nullptr);
228
229   _model.reset(model->UnPack());
230
231   // for direct pointer access
232   _model_ptr = model;
233
234   return true;
235 }
236
237 bool CircleReader::select_subgraph(uint32_t sgindex)
238 {
239   if (_model->subgraphs.size() <= sgindex)
240   {
241     assert(false);
242     return false;
243   }
244
245   _current_subgraph = _model->subgraphs[sgindex].get();
246
247   // for direct pointer access
248   auto subgraphs = _model_ptr->subgraphs();
249   const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
250
251   _tensors_ptr = subgraph->tensors();
252
253   return true;
254 }
255
256 } // namespace luci