Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / import / src / CircleImportMetadata.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CircleImportMetadata.h"
18
19 #include <vector>
20
21 namespace
22 {
23
24 uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx)
25 {
26   uint32_t val = 0;
27   val += (buffer.at(idx + 0) << 0 * 8);
28   val += (buffer.at(idx + 1) << 1 * 8);
29   val += (buffer.at(idx + 2) << 2 * 8);
30   val += (buffer.at(idx + 3) << 3 * 8);
31   return val;
32 }
33
34 } // namespace
35
36 namespace
37 {
38
39 // 'source_table' is decoded to std::map<uint32_t, std::string> format.
40 const std::map<uint32_t, std::string>
41 decoded_source_table(const std::vector<uint8_t> &source_table_data)
42 {
43   std::map<uint32_t, std::string> source_id_name_map;
44   uint32_t idx = 0;
45
46   if (source_table_data.size() < 4)
47     throw std::runtime_error("Source table decode error : invalid entry number");
48
49   uint32_t entry_number = read_u32(source_table_data, idx);
50   idx += sizeof(uint32_t);
51
52   while (idx < source_table_data.size())
53   {
54     if (idx + 2 * sizeof(uint32_t) > source_table_data.size())
55       throw std::runtime_error("Source table decode error : invalid entry item");
56
57     uint32_t id = read_u32(source_table_data, idx);
58     idx += sizeof(uint32_t);
59
60     uint32_t length = read_u32(source_table_data, idx);
61     idx += sizeof(uint32_t);
62
63     if (idx + sizeof(char) * length > source_table_data.size())
64       throw std::runtime_error("Source table decode error : invalid entry data");
65
66     // The last character of name is '\0'.
67     // However, as std::string do not use '\0' for finding the end of string,
68     // we ignore the character and do not include it in the string.
69     std::string origin_name;
70     for (uint32_t j = 0; j < length - 1; ++j)
71       origin_name += source_table_data.at(idx + j);
72     assert(source_table_data.at(idx + length - 1) == '\0');
73     idx += sizeof(char) * length;
74
75     if (source_id_name_map.insert({id, origin_name}).second == false)
76       throw std::runtime_error("Source table decode error : duplicated origin ID");
77   }
78
79   if (idx != source_table_data.size())
80     throw std::runtime_error("Source table decode error : data size invalid");
81
82   if (source_id_name_map.size() != entry_number)
83     throw std::runtime_error("Source table decode error : result size mismatch");
84
85   return source_id_name_map;
86 }
87
88 // 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format.
89 const std::map<uint32_t, std::set<uint32_t>>
90 decoded_op_table(const std::vector<uint8_t> &op_table_data)
91 {
92   std::map<uint32_t, std::set<uint32_t>> node_source_ids_map;
93   uint32_t idx = 0;
94
95   if (op_table_data.size() < 4)
96     throw std::runtime_error("Op table decode error : invalid entry number");
97
98   uint32_t entry_number = read_u32(op_table_data, idx);
99   idx += sizeof(uint32_t);
100
101   while (idx < op_table_data.size())
102   {
103     if (idx + 2 * sizeof(uint32_t) > op_table_data.size())
104       throw std::runtime_error("Op table decode error : invalid entry item");
105
106     uint32_t id = read_u32(op_table_data, idx);
107     idx += sizeof(uint32_t);
108
109     uint32_t node_num = read_u32(op_table_data, idx);
110     idx += sizeof(uint32_t);
111
112     if (idx + sizeof(uint32_t) * node_num > op_table_data.size())
113       throw std::runtime_error("Source table decode error : invalid entry data");
114
115     std::set<uint32_t> source_ids;
116     for (uint32_t j = 0; j < node_num; ++j)
117     {
118       uint32_t origin = read_u32(op_table_data, idx);
119       idx += sizeof(uint32_t);
120
121       source_ids.insert(origin);
122     }
123
124     if (node_source_ids_map.insert({id, source_ids}).second == false)
125       throw std::runtime_error("Op table decode error : duplicated origin ID");
126   }
127
128   if (idx != op_table_data.size())
129     throw std::runtime_error("Op table decode error : data size invalid");
130
131   if (node_source_ids_map.size() != entry_number)
132     throw std::runtime_error("Op table decode error : entry number invalid");
133
134   return node_source_ids_map;
135 }
136
137 // 'execution_plan_table' is decoded to std::map<uint32_t, std::vector<uint32_t>> format.
138 const luci::ExecutionPlanTable
139 decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data)
140 {
141   luci::ExecutionPlanTable execution_plan_table;
142   uint32_t idx = 0;
143
144   if (execution_plan_data.size() < 4)
145     throw std::runtime_error("Op table decode error : invalid entry number");
146
147   uint32_t entry_number = read_u32(execution_plan_data, idx);
148   idx += sizeof(uint32_t);
149
150   while (idx < execution_plan_data.size())
151   {
152     if (idx + 2 * sizeof(uint32_t) > execution_plan_data.size())
153       throw std::runtime_error("Op table decode error : invalid entry item");
154
155     uint32_t id = read_u32(execution_plan_data, idx);
156     idx += sizeof(uint32_t);
157
158     uint32_t size = read_u32(execution_plan_data, idx);
159     idx += sizeof(uint32_t);
160
161     if (idx + sizeof(uint32_t) * size > execution_plan_data.size())
162       throw std::runtime_error("Source table decode error : invalid entry data");
163
164     std::vector<uint32_t> execution_plan_vector;
165     for (uint32_t j = 0; j < size; ++j)
166     {
167       uint32_t execution_plan_inform = read_u32(execution_plan_data, idx);
168       idx += sizeof(uint32_t);
169
170       execution_plan_vector.push_back(execution_plan_inform);
171     }
172
173     if (execution_plan_table.insert({id, execution_plan_vector}).second == false)
174       throw std::runtime_error("Op table decode error : duplicated origin ID");
175   }
176
177   if (idx != execution_plan_data.size())
178     throw std::runtime_error("Op table decode error : data size invalid");
179
180   if (execution_plan_table.size() != entry_number)
181     throw std::runtime_error("Op table decode error : entry number invalid");
182
183   return execution_plan_table;
184 }
185
186 } // namespace
187
188 namespace luci
189 {
190
191 CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader)
192 {
193   const auto &metadata = reader.metadata();
194   for (uint32_t i = 0; i < metadata.size(); ++i)
195   {
196     const circle::MetadataT &meta = *metadata[i];
197
198     assert(meta.buffer < reader.buffers().size());
199     const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data;
200
201     if (meta.name.compare("ONE_op_table") == 0)
202       _op_table = decoded_op_table(buffer);
203     else if (meta.name.compare("ONE_source_table") == 0)
204       _source_table = decoded_source_table(buffer);
205     else if (meta.name.compare("ONE_execution_plan_table") == 0)
206       _execution_plan_table = decoded_execution_plan(buffer);
207   }
208 }
209
210 const OriginTable CircleImportMetadata::origin_table(void)
211 {
212   OriginTable origin_table;
213
214   if (_op_table.size() > 0 && _source_table.size() > 0)
215   {
216     for (auto &kv : _op_table)
217     {
218       const auto node_id = kv.first;
219       const auto &source_ids = kv.second;
220
221       std::vector<std::shared_ptr<CircleNodeOrigin>> origins;
222       for (auto source_id : source_ids)
223       {
224         const auto source_name = _source_table.at(source_id);
225         origins.push_back(single_origin(source_id, source_name));
226       }
227
228       auto origin = composite_origin(origins);
229       origin_table.emplace(node_id, origin);
230     }
231   }
232
233   return origin_table;
234 }
235
236 } // namespace luci