Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / core / reader / CircleMicroReader.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci_interpreter/core/reader/CircleMicroReader.h"
18 #include "luci_interpreter/core/reader/CircleMicroReaderHelper.h"
19
20 #include <algorithm>
21
22 namespace luci_interpreter
23 {
24
25 // TODO check can we remove it
26 DataType luci_datatype(const circle::TensorType type)
27 {
28   switch (type)
29   {
30     case circle::TensorType_FLOAT32:
31       return DataType::FLOAT32;
32     case circle::TensorType_FLOAT16:
33       return DataType::FLOAT16;
34     case circle::TensorType_INT32:
35       return DataType::S32;
36     case circle::TensorType_UINT8:
37       return DataType::U8;
38     case circle::TensorType_INT64:
39       return DataType::S64;
40     case circle::TensorType_BOOL:
41       return DataType::BOOL;
42     case circle::TensorType_INT16:
43       return DataType::S16;
44     case circle::TensorType_COMPLEX64:
45       break;
46     case circle::TensorType_INT8:
47       return DataType::S8;
48     default:
49       break;
50   }
51   assert(false);
52   return DataType::Unknown;
53 }
54
55 FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
56 {
57   switch (type)
58   {
59     case circle::ActivationFunctionType::ActivationFunctionType_NONE:
60       return FusedActFunc::NONE;
61     case circle::ActivationFunctionType::ActivationFunctionType_RELU:
62       return FusedActFunc::RELU;
63     case circle::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
64       return FusedActFunc::RELU_N1_TO_1;
65     case circle::ActivationFunctionType::ActivationFunctionType_RELU6:
66       return FusedActFunc::RELU6;
67     case circle::ActivationFunctionType::ActivationFunctionType_TANH:
68       return FusedActFunc::TANH;
69     case circle::ActivationFunctionType::ActivationFunctionType_SIGN_BIT:
70       return FusedActFunc::SIGN_BIT;
71     default:
72       break;
73   }
74   assert(false);
75   return FusedActFunc::UNDEFINED;
76 }
77
78 Padding luci_padding(const circle::Padding padding)
79 {
80   switch (padding)
81   {
82     case circle::Padding::Padding_SAME:
83       return Padding::SAME;
84     case circle::Padding::Padding_VALID:
85       return Padding::VALID;
86   }
87   assert(false);
88   return Padding::UNDEFINED;
89 }
90
91 MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
92 {
93   switch (mode)
94   {
95     case circle::MirrorPadMode::MirrorPadMode_REFLECT:
96       return MirrorPadMode::REFLECT;
97     case circle::MirrorPadMode::MirrorPadMode_SYMMETRIC:
98       return MirrorPadMode::SYMMETRIC;
99   }
100   assert(false);
101   return MirrorPadMode::UNDEFINED;
102 }
103
104 circle::BuiltinOperator CircleReader::builtin_code(const circle::Operator *op) const
105 {
106   assert(op != nullptr);
107
108   const auto op_codes = opcodes();
109   uint32_t index = op->opcode_index();
110   assert(index < op_codes.size());
111   const auto opcode = op_codes[index];
112   assert(opcode != nullptr);
113
114   return circle::builtin_code_neutral(opcode);
115 }
116
117 bool CircleReader::parse(const circle::Model *model)
118 {
119   assert(model != nullptr);
120
121   // for direct pointer access
122   _model = model;
123
124   return true;
125 }
126
127 bool CircleReader::select_subgraph(uint32_t sgindex)
128 {
129   if (num_subgraph() <= sgindex)
130   {
131     assert(false);
132     return false;
133   }
134
135   // for direct pointer access
136   auto subgraphs = _model->subgraphs();
137   assert(subgraphs != nullptr);
138
139   _current_subgraph = subgraphs->Get(sgindex);
140   assert(_current_subgraph != nullptr);
141
142   _current_subgraph_index = sgindex;
143
144   return true;
145 }
146
147 template <typename T>
148 VectorWrapper<T>::VectorWrapper(const flatbuffers::Vector<T> *ptr) : _vector(ptr)
149 {
150   // Do nothing
151 }
152
153 template <typename T> uint32_t VectorWrapper<T>::size() const
154 {
155   return null() ? 0 : _vector->size();
156 }
157
158 template <typename T> const T *VectorWrapper<T>::data() const
159 {
160   return null() ? nullptr : _vector->data();
161 }
162
163 template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::begin() const
164 {
165   return null() ? iterator(nullptr, 0) : _vector->begin();
166 }
167
168 template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::end() const
169 {
170   return null() ? begin() : _vector->end();
171 }
172
173 template <typename T> typename VectorWrapper<T>::value_type VectorWrapper<T>::at(uint32_t i) const
174 {
175   if (i >= size())
176   {
177     // TODO find better error message
178     assert(false && "Access to prohibited vector element");
179   }
180
181   return _vector->Get(i);
182 }
183
184 template <typename T>
185 typename VectorWrapper<T>::value_type VectorWrapper<T>::operator[](uint32_t i) const
186 {
187   return at(i);
188 }
189
190 template <typename T> bool VectorWrapper<T>::null() const { return _vector == nullptr; }
191 template <typename T> bool VectorWrapper<T>::empty() const { return size() == 0; }
192
193 #define REGISTER_WRAPPER(T) template class VectorWrapper<T>
194 REGISTER_WRAPPER(flatbuffers::Offset<circle::SubGraph>);
195 REGISTER_WRAPPER(flatbuffers::Offset<circle::Buffer>);
196 REGISTER_WRAPPER(flatbuffers::Offset<circle::Tensor>);
197 REGISTER_WRAPPER(flatbuffers::Offset<circle::Operator>);
198 REGISTER_WRAPPER(flatbuffers::Offset<circle::OperatorCode>);
199 REGISTER_WRAPPER(flatbuffers::Offset<circle::Metadata>);
200 REGISTER_WRAPPER(int32_t);
201 REGISTER_WRAPPER(uint8_t);
202 #undef REGISTER_WRAPPER
203
204 } // namespace luci_interpreter