2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
20 #include "CircleModel.h"
21 #include "DataLookup.h"
23 namespace tflite2circle
27 Offset<MetaDataBufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
29 if (tflite_flatbuffer_vec == nullptr)
31 std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
32 tflite_flatbuffer_vec->end()};
33 _circle_flatbuffer_vec_offset = fb->CreateVector(metadata_buffer_vec);
37 Offset<BufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
39 std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
41 for (auto it : *tflite_flatbuffer_vec)
43 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
46 std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
47 buffer_data = fb->CreateVector(data_vec);
49 circle::BufferBuilder circle_buffer_builder{*fb};
50 circle_buffer_builder.add_data(buffer_data);
51 auto circle_buffers = circle_buffer_builder.Finish();
52 buffers_vec.emplace_back(circle_buffers);
54 _circle_flatbuffer_vec_offset = fb->CreateVector(buffers_vec);
58 Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
60 std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
62 for (auto it_sg : *tflite_flatbuffer_vec)
64 // tensors of subgraph
65 std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
67 auto tflite_tensors = it_sg->tensors();
68 for (auto it : *tflite_tensors)
71 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
74 auto shape_vec = std::vector<int32_t>({it->shape()->begin(), it->shape()->end()});
75 shape = fb->CreateVector(shape_vec);
78 flatbuffers::Offset<flatbuffers::String> name;
80 name = fb->CreateString(it->name()->str());
82 flatbuffers::Offset<circle::QuantizationParameters> quantization;
83 if (it->quantization())
85 std::vector<float> tfmin;
86 std::vector<float> tfmax;
87 std::vector<float> tfscale;
88 std::vector<int64_t> tfzerop;
89 flatbuffers::Offset<flatbuffers::Vector<float>> min;
90 flatbuffers::Offset<flatbuffers::Vector<float>> max;
91 flatbuffers::Offset<flatbuffers::Vector<float>> scale;
92 flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
93 int32_t quantized_dimension = it->quantization()->quantized_dimension();
95 if (it->quantization()->min() && it->quantization()->max())
97 auto rmin = it->quantization()->min();
98 auto rmax = it->quantization()->max();
99 tfmin = std::vector<float>{rmin->begin(), rmin->end()};
100 tfmax = std::vector<float>{rmax->begin(), rmax->end()};
101 min = fb->CreateVector(tfmin);
102 max = fb->CreateVector(tfmax);
105 if (it->quantization()->scale() && it->quantization()->zero_point())
107 auto rs = it->quantization()->scale();
108 auto rz = it->quantization()->zero_point();
109 tfscale = std::vector<float>{rs->begin(), rs->end()};
110 tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
111 scale = fb->CreateVector(tfscale);
112 zero_point = fb->CreateVector(tfzerop);
115 quantization = circle::CreateQuantizationParameters(*fb, min, max, scale, zero_point,
116 circle::QuantizationDetails_NONE, 0,
117 quantized_dimension);
120 bool is_variable = it->is_variable();
122 flatbuffers::Offset<circle::SparsityParameters> sparsity;
126 flatbuffers::Offset<flatbuffers::Vector<int32_t>> traversal_order;
127 flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_map;
128 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::DimensionMetadata>>>
132 if (it->sparsity()->traversal_order())
134 auto traversal_order_vec = std::vector<int32_t>{
135 it->sparsity()->traversal_order()->begin(), it->sparsity()->traversal_order()->end()};
136 traversal_order = fb->CreateVector(traversal_order_vec);
140 if (it->sparsity()->block_map())
142 auto block_map_vec = std::vector<int32_t>{it->sparsity()->block_map()->begin(),
143 it->sparsity()->block_map()->end()};
144 block_map = fb->CreateVector(block_map_vec);
148 std::vector<flatbuffers::Offset<circle::DimensionMetadata>> dim_metadata_vec;
149 auto tflite_dim_metadata = it->sparsity()->dim_metadata();
150 for (auto it : *tflite_dim_metadata)
153 auto tflite_array_segments_type = it->array_segments_type();
154 auto circle_array_segments =
155 get_circle_sparse_index_vector(*fb, it, tflite_array_segments_type);
156 auto circle_array_segments_type =
157 get_circle_sparse_index_vector_type(tflite_array_segments_type);
160 auto tflite_array_indices_type = it->array_indices_type();
161 auto circle_array_indices =
162 get_circle_sparse_index_vector(*fb, it, tflite_array_indices_type);
163 auto circle_array_indices_type =
164 get_circle_sparse_index_vector_type(tflite_array_indices_type);
166 auto circle_dim_metadata_builder = circle::DimensionMetadataBuilder{*fb};
168 circle_dim_metadata_builder.add_format(get_circle_dimension_type(it->format()));
169 circle_dim_metadata_builder.add_dense_size(it->dense_size());
170 circle_dim_metadata_builder.add_array_segments(circle_array_segments);
171 circle_dim_metadata_builder.add_array_segments_type(circle_array_segments_type);
172 circle_dim_metadata_builder.add_array_indices(circle_array_indices);
173 circle_dim_metadata_builder.add_array_indices_type(circle_array_indices_type);
174 auto dim_metadata = circle_dim_metadata_builder.Finish();
175 dim_metadata_vec.emplace_back(dim_metadata);
177 dim_metadata = fb->CreateVector(dim_metadata_vec);
179 sparsity = circle::CreateSparsityParameters(*fb, traversal_order, block_map, dim_metadata);
183 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
184 if (it->shape_signature())
186 auto shape_signature_vec =
187 std::vector<int32_t>({it->shape_signature()->begin(), it->shape_signature()->end()});
188 shape_signature = fb->CreateVector(shape_signature_vec);
191 circle::TensorBuilder tensor_builder{*fb};
192 tensor_builder.add_shape(shape);
193 tensor_builder.add_type(get_circle_tensortype(it->type()));
194 tensor_builder.add_buffer(it->buffer());
195 tensor_builder.add_name(name);
196 tensor_builder.add_quantization(quantization);
197 tensor_builder.add_is_variable(is_variable);
198 tensor_builder.add_sparsity(sparsity);
199 tensor_builder.add_shape_signature(shape_signature);
200 auto tensor = tensor_builder.Finish();
201 tensor_vec.emplace_back(tensor);
203 auto circle_tensors = fb->CreateVector(tensor_vec);
205 // inputs of subgraph
206 auto tflite_inputs = it_sg->inputs();
207 std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
209 auto circle_inputs = fb->CreateVector(input_vec);
211 // outputs of subgraph
212 auto tflite_outputs = it_sg->outputs();
213 std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
215 auto circle_outputs = fb->CreateVector(output_vec);
217 // operators of subgraph
218 std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
220 auto tflite_operators = it_sg->operators();
221 for (auto it : *tflite_operators)
224 std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
225 auto circle_inputs = fb->CreateVector(input_vec);
227 std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
228 auto circle_outputs = fb->CreateVector(output_vec);
230 auto circle_builtin_options = get_circle_builtin_options(*fb, it);
231 auto circle_builtin_options_type = get_circle_builtin_options_type(it);
233 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
234 if (it->custom_options())
236 std::vector<uint8_t> custom_options_vec{it->custom_options()->begin(),
237 it->custom_options()->end()};
238 circle_custom_options = fb->CreateVector(custom_options_vec);
240 // custom options format
241 // TODO Make get_circle_custom_options_format
242 assert(it->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS);
243 auto circle_custom_options_format = circle::CustomOptionsFormat_FLEXBUFFERS;
245 circle::OperatorBuilder operator_builder{*fb};
246 operator_builder.add_opcode_index(it->opcode_index());
247 operator_builder.add_inputs(circle_inputs);
248 operator_builder.add_outputs(circle_outputs);
249 operator_builder.add_builtin_options(circle_builtin_options);
250 operator_builder.add_builtin_options_type(circle_builtin_options_type);
251 operator_builder.add_custom_options(circle_custom_options);
252 operator_builder.add_custom_options_format(circle_custom_options_format);
253 // TODO mutating_variable_inputs
254 auto opeartor = operator_builder.Finish();
255 operator_vec.emplace_back(opeartor);
257 auto circle_operators = fb->CreateVector(operator_vec);
260 auto subgraphs_name = fb->CreateString(it_sg->name());
263 auto circle_subgraph_builder = circle::SubGraphBuilder{*fb};
265 circle_subgraph_builder.add_tensors(circle_tensors);
266 circle_subgraph_builder.add_inputs(circle_inputs);
267 circle_subgraph_builder.add_outputs(circle_outputs);
268 circle_subgraph_builder.add_operators(circle_operators);
269 circle_subgraph_builder.add_name(subgraphs_name);
270 circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
272 auto circle_subgraph = circle_subgraph_builder.Finish();
273 subgprahs_vec.emplace_back(circle_subgraph);
275 _circle_flatbuffer_vec_offset = fb->CreateVector(subgprahs_vec);
279 Offset<OperatorCodeLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
281 std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
283 for (auto it : *tflite_flatbuffer_vec)
285 auto custom_code = fb->CreateString(it->custom_code());
286 circle::OperatorCodeBuilder operator_code_builder{*fb};
287 operator_code_builder.add_builtin_code(get_circle_builtin_code(it->builtin_code()));
288 operator_code_builder.add_custom_code(custom_code);
289 operator_code_builder.add_version(it->version());
290 auto code = operator_code_builder.Finish();
291 operator_code_vec.emplace_back(code);
293 _circle_flatbuffer_vec_offset = fb->CreateVector(operator_code_vec);
296 CircleModel::CircleModel(FlatBufBuilder &fb, TFLModel &model)
297 : _version{0}, _description{fb->CreateString("nnpackage")}, _fb{fb}
299 const tflite::Model *tfl_model = model.load_model();
300 // verify flatbuffers
301 flatbuffers::Verifier verifier{reinterpret_cast<const uint8_t *>(model._data.data()),
303 if (!tflite::VerifyModelBuffer(verifier))
305 throw std::runtime_error("ERROR: Failed to verify tflite");
308 _operator_codes_offset =
309 std::make_unique<Offset<OperatorCodeLink>>(fb, tfl_model->operator_codes());
310 _subGraphs_offset = std::make_unique<Offset<SubGraphLink>>(fb, tfl_model->subgraphs());
311 _buffers_offset = std::make_unique<Offset<BufferLink>>(fb, tfl_model->buffers());
312 _metadata_buffer_offset =
313 std::make_unique<Offset<MetaDataBufferLink>>(fb, tfl_model->metadata_buffer());
317 void CircleModel::model_build(void) const
319 circle::ModelBuilder model_builder{*_fb};
321 model_builder.add_version(_version);
322 model_builder.add_description(_description);
323 model_builder.add_operator_codes(_operator_codes_offset->offset());
324 model_builder.add_subgraphs(_subGraphs_offset->offset());
325 model_builder.add_buffers(_buffers_offset->offset());
326 model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
328 auto model = model_builder.Finish();
329 circle::FinishModelBuffer(*_fb, model);
332 const char *CircleModel::base(void) const
334 return reinterpret_cast<const char *>(_fb->GetBufferPointer());
337 size_t CircleModel::size(void) const { return _fb->GetSize(); }
339 } // namespace tflite2circle