2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "CircleImportMetadata.h"
// Reads a 4-byte little-endian unsigned integer from 'buffer' starting at 'idx'
// (byte idx is the least significant byte).
// Throws std::out_of_range (from std::vector::at) if fewer than 4 bytes remain.
uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx)
{
  // Each byte is widened to uint32_t before shifting. Left as uint8_t it would be
  // promoted to signed int, and shifting a byte >= 0x80 left by 24 bits overflows
  // int, which is undefined behavior before C++20.
  uint32_t val = 0;
  val |= static_cast<uint32_t>(buffer.at(idx + 0)) << (0 * 8);
  val |= static_cast<uint32_t>(buffer.at(idx + 1)) << (1 * 8);
  val |= static_cast<uint32_t>(buffer.at(idx + 2)) << (2 * 8);
  val |= static_cast<uint32_t>(buffer.at(idx + 3)) << (3 * 8);
  return val;
}
39 // 'source_table' is decoded to std::map<uint32_t, std::string> format.
40 const std::map<uint32_t, std::string>
41 decoded_source_table(const std::vector<uint8_t> &source_table_data)
43 std::map<uint32_t, std::string> source_id_name_map;
46 if (source_table_data.size() < 4)
47 throw std::runtime_error("Source table decode error : invalid entry number");
49 uint32_t entry_number = read_u32(source_table_data, idx);
50 idx += sizeof(uint32_t);
52 while (idx < source_table_data.size())
54 if (idx + 2 * sizeof(uint32_t) > source_table_data.size())
55 throw std::runtime_error("Source table decode error : invalid entry item");
57 uint32_t id = read_u32(source_table_data, idx);
58 idx += sizeof(uint32_t);
60 uint32_t length = read_u32(source_table_data, idx);
61 idx += sizeof(uint32_t);
63 if (idx + sizeof(char) * length > source_table_data.size())
64 throw std::runtime_error("Source table decode error : invalid entry data");
66 // The last character of name is '\0'.
67 // However, as std::string do not use '\0' for finding the end of string,
68 // we ignore the character and do not include it in the string.
69 std::string origin_name;
70 for (uint32_t j = 0; j < length - 1; ++j)
71 origin_name += source_table_data.at(idx + j);
72 assert(source_table_data.at(idx + length - 1) == '\0');
73 idx += sizeof(char) * length;
75 if (source_id_name_map.insert({id, origin_name}).second == false)
76 throw std::runtime_error("Source table decode error : duplicated origin ID");
79 if (idx != source_table_data.size())
80 throw std::runtime_error("Source table decode error : data size invalid");
82 if (source_id_name_map.size() != entry_number)
83 throw std::runtime_error("Source table decode error : result size mismatch");
85 return source_id_name_map;
88 // 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format.
89 const std::map<uint32_t, std::set<uint32_t>>
90 decoded_op_table(const std::vector<uint8_t> &op_table_data)
92 std::map<uint32_t, std::set<uint32_t>> node_source_ids_map;
95 if (op_table_data.size() < 4)
96 throw std::runtime_error("Op table decode error : invalid entry number");
98 uint32_t entry_number = read_u32(op_table_data, idx);
99 idx += sizeof(uint32_t);
101 while (idx < op_table_data.size())
103 if (idx + 2 * sizeof(uint32_t) > op_table_data.size())
104 throw std::runtime_error("Op table decode error : invalid entry item");
106 uint32_t id = read_u32(op_table_data, idx);
107 idx += sizeof(uint32_t);
109 uint32_t node_num = read_u32(op_table_data, idx);
110 idx += sizeof(uint32_t);
112 if (idx + sizeof(uint32_t) * node_num > op_table_data.size())
113 throw std::runtime_error("Source table decode error : invalid entry data");
115 std::set<uint32_t> source_ids;
116 for (uint32_t j = 0; j < node_num; ++j)
118 uint32_t origin = read_u32(op_table_data, idx);
119 idx += sizeof(uint32_t);
121 source_ids.insert(origin);
124 if (node_source_ids_map.insert({id, source_ids}).second == false)
125 throw std::runtime_error("Op table decode error : duplicated origin ID");
128 if (idx != op_table_data.size())
129 throw std::runtime_error("Op table decode error : data size invalid");
131 if (node_source_ids_map.size() != entry_number)
132 throw std::runtime_error("Op table decode error : entry number invalid");
134 return node_source_ids_map;
137 // 'execution_plan_table' is decoded to std::map<uint32_t, std::vector<uint32_t>> format.
138 const luci::ExecutionPlanTable
139 decoded_execution_plan(const std::vector<uint8_t> &execution_plan_data)
141 luci::ExecutionPlanTable execution_plan_table;
144 if (execution_plan_data.size() < 4)
145 throw std::runtime_error("Op table decode error : invalid entry number");
147 uint32_t entry_number = read_u32(execution_plan_data, idx);
148 idx += sizeof(uint32_t);
150 while (idx < execution_plan_data.size())
152 if (idx + 2 * sizeof(uint32_t) > execution_plan_data.size())
153 throw std::runtime_error("Op table decode error : invalid entry item");
155 uint32_t id = read_u32(execution_plan_data, idx);
156 idx += sizeof(uint32_t);
158 uint32_t size = read_u32(execution_plan_data, idx);
159 idx += sizeof(uint32_t);
161 if (idx + sizeof(uint32_t) * size > execution_plan_data.size())
162 throw std::runtime_error("Source table decode error : invalid entry data");
164 std::vector<uint32_t> execution_plan_vector;
165 for (uint32_t j = 0; j < size; ++j)
167 uint32_t execution_plan_inform = read_u32(execution_plan_data, idx);
168 idx += sizeof(uint32_t);
170 execution_plan_vector.push_back(execution_plan_inform);
173 if (execution_plan_table.insert({id, execution_plan_vector}).second == false)
174 throw std::runtime_error("Op table decode error : duplicated origin ID");
177 if (idx != execution_plan_data.size())
178 throw std::runtime_error("Op table decode error : data size invalid");
180 if (execution_plan_table.size() != entry_number)
181 throw std::runtime_error("Op table decode error : entry number invalid");
183 return execution_plan_table;
191 CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader)
193 const auto &metadata = reader.metadata();
194 for (uint32_t i = 0; i < metadata.size(); ++i)
196 const circle::MetadataT &meta = *metadata[i];
198 assert(meta.buffer < reader.buffers().size());
199 const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data;
201 if (meta.name.compare("ONE_op_table") == 0)
202 _op_table = decoded_op_table(buffer);
203 else if (meta.name.compare("ONE_source_table") == 0)
204 _source_table = decoded_source_table(buffer);
205 else if (meta.name.compare("ONE_execution_plan_table") == 0)
206 _execution_plan_table = decoded_execution_plan(buffer);
210 const OriginTable CircleImportMetadata::origin_table(void)
212 OriginTable origin_table;
214 if (_op_table.size() > 0 && _source_table.size() > 0)
216 for (auto &kv : _op_table)
218 const auto node_id = kv.first;
219 const auto &source_ids = kv.second;
221 std::vector<std::shared_ptr<CircleNodeOrigin>> origins;
222 for (auto source_id : source_ids)
224 const auto source_name = _source_table.at(source_id);
225 origins.push_back(single_origin(source_id, source_name));
228 auto origin = composite_origin(origins);
229 origin_table.emplace(node_id, origin);