Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / include / luci_interpreter / core / reader / CircleMicroReader.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_MICRO_INTERPRETER_MICRO_READER_H__
18 #define __LUCI_MICRO_INTERPRETER_MICRO_READER_H__
19
20 #include "luci_interpreter/core/ParamsType.h"
21 #include "luci_interpreter/core/DataType.h"
22
23 #include <circle-generated/circle/schema_generated.h>
24
25 #include <map>
26 #include <memory>
27 #include <vector>
28
29 namespace luci_interpreter
30 {
31
32 #ifdef USE_STATIC_ALLOC
33 namespace
34 {
35
36 using ExecutionPlanTable = std::map<uint32_t, std::vector<uint32_t>>;
37
38 template <typename VECTORTYPE> uint32_t read_u32(const VECTORTYPE &buffer, uint32_t idx)
39 {
40   static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
41
42   uint32_t val = 0;
43   val += (buffer.at(idx + 0) << 0 * 8);
44   val += (buffer.at(idx + 1) << 1 * 8);
45   val += (buffer.at(idx + 2) << 2 * 8);
46   val += (buffer.at(idx + 3) << 3 * 8);
47   return val;
48 }
49
50 } // namespace
51
52 namespace read_metadata
53 {
54
55 template <typename VECTORTYPE>
56 ExecutionPlanTable decode_execution_plan(const VECTORTYPE &execution_plan_data)
57 {
58   static_assert(std::is_same<typename VECTORTYPE::value_type, uint8_t>::value, "Types mismatch!");
59
60   ExecutionPlanTable execution_plan_table;
61   uint32_t idx = 0;
62
63   if (execution_plan_data.size() < 4)
64     assert(false && "Op table decode error : invalid entry number");
65
66   uint32_t entry_number = read_u32(execution_plan_data, idx);
67   idx += sizeof(uint32_t);
68
69   while (idx < execution_plan_data.size())
70   {
71     if (idx + 2 * sizeof(uint32_t) > execution_plan_data.size())
72       assert(false && "Op table decode error : invalid entry item");
73
74     uint32_t id = read_u32(execution_plan_data, idx);
75     idx += sizeof(uint32_t);
76
77     uint32_t size = read_u32(execution_plan_data, idx);
78
79     if (size == 0)
80       assert(false && "Op table decode error : empty execution plan entry");
81
82     idx += sizeof(uint32_t);
83
84     if (idx + sizeof(uint32_t) * size > execution_plan_data.size())
85       assert(false && "Source table decode error : invalid entry data");
86
87     std::vector<uint32_t> execution_plan_vector;
88     uint32_t position = read_u32(execution_plan_data, idx);
89     idx += sizeof(uint32_t);
90
91     for (uint32_t j = 1; j < size; ++j)
92     {
93       uint32_t execution_plan_inform = read_u32(execution_plan_data, idx);
94       idx += sizeof(uint32_t);
95
96       execution_plan_vector.push_back(execution_plan_inform);
97     }
98
99     if (!execution_plan_table.insert({position, execution_plan_vector}).second)
100       assert(false && "Op table decode error : duplicated origin ID");
101   }
102
103   if (idx != execution_plan_data.size())
104     assert(false && "Op table decode error : data size invalid");
105
106   if (execution_plan_table.size() != entry_number)
107     assert(false && "Op table decode error : entry number invalid");
108
109   return execution_plan_table;
110 }
111
112 } // namespace read_metadata
113 #endif
114
115 DataType luci_datatype(circle::TensorType type);
116 FusedActFunc luci_actfunc(circle::ActivationFunctionType type);
117 Padding luci_padding(circle::Padding padding);
118 MirrorPadMode luci_mirrorpad_mode(circle::MirrorPadMode mode);
119
120 /**
121  * @brief Wrapper to use flatbuffers::Vector pointer as std::vector entity
122  */
123 template <typename T> class VectorWrapper
124 {
125 public:
126   explicit VectorWrapper(const flatbuffers::Vector<T> *ptr);
127
128   const T *data() const;
129   uint32_t size() const;
130
131   using iterator = typename flatbuffers::Vector<T>::const_iterator;
132   iterator begin() const;
133   iterator end() const;
134
135   using value_type = typename flatbuffers::Vector<T>::return_type;
136   value_type at(uint32_t i) const;
137   value_type operator[](uint32_t i) const;
138
139   bool null() const;
140   bool empty() const;
141
142 private:
143   const flatbuffers::Vector<T> *_vector;
144 };
145
146 template <typename T> VectorWrapper<T> wrap(const flatbuffers::Vector<T> *vec)
147 {
148   return VectorWrapper<T>(vec);
149 }
150
151 /**
152  * @brief Loads Circle file and provides helpers to access attributes
153  */
154 class CircleReader
155 {
156 public:
157   using CircleBuffers = VectorWrapper<flatbuffers::Offset<circle::Buffer>>;
158   using CircleTensors = VectorWrapper<flatbuffers::Offset<circle::Tensor>>;
159   using CircleOperators = VectorWrapper<flatbuffers::Offset<circle::Operator>>;
160   using CircleOperatorCodes = VectorWrapper<flatbuffers::Offset<circle::OperatorCode>>;
161   using CircleMetadataSet = VectorWrapper<flatbuffers::Offset<circle::Metadata>>;
162
163 public:
164   CircleReader() = default;
165
166 public: // direct API
167   CircleOperatorCodes opcodes() const { return wrap(_model->operator_codes()); }
168   CircleBuffers buffers() const { return wrap(_model->buffers()); }
169   CircleTensors tensors() const { return wrap(_current_subgraph->tensors()); }
170   CircleOperators operators() const { return wrap(_current_subgraph->operators()); }
171   VectorWrapper<int32_t> inputs() const { return wrap(_current_subgraph->inputs()); }
172   VectorWrapper<int32_t> outputs() const { return wrap(_current_subgraph->outputs()); }
173   circle::DataFormat data_format() const { return _current_subgraph->data_format(); }
174   CircleMetadataSet metadata() const { return wrap(_model->metadata()); }
175
176   uint32_t num_subgraph() const { return wrap(_model->subgraphs()).size(); }
177   circle::BuiltinOperator builtin_code(const circle::Operator *op) const;
178
179 public:
180   bool parse(const circle::Model *model);
181   bool select_subgraph(uint32_t subgraph);
182   uint32_t get_current_subgraph_index() const { return _current_subgraph_index; }
183
184 private:
185   const circle::Model *_model{nullptr};
186   const circle::SubGraph *_current_subgraph{nullptr};
187   uint32_t _current_subgraph_index{0};
188 };
189
190 } // namespace luci_interpreter
191
192 #endif // __LUCI_MICRO_INTERPRETER_MICRO_READER_H__