/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
22 #include "CircleModel.h"
23 #include "DataLookup.h"
25 #include <mio_tflite280/Helper.h>
27 namespace tflite2circle
30 template <> void Offset<MetaDataBufferLink>::build(const TFLFlatBufVec *tflite_flatbuffer_vec)
32 if (tflite_flatbuffer_vec == nullptr)
34 std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
35 tflite_flatbuffer_vec->end()};
36 _circle_flatbuffer_vec_offset = _fb->CreateVector(metadata_buffer_vec);
39 template <> void Offset<BufferLink>::build(const TFLFlatBufVec *tflite_flatbuffer_vec)
41 std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
43 for (auto it : *tflite_flatbuffer_vec)
45 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
48 std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
49 buffer_data = _fb->CreateVector(data_vec);
51 circle::BufferBuilder circle_buffer_builder{*_fb};
52 circle_buffer_builder.add_data(buffer_data);
53 auto circle_buffers = circle_buffer_builder.Finish();
54 buffers_vec.emplace_back(circle_buffers);
56 _circle_flatbuffer_vec_offset = _fb->CreateVector(buffers_vec);
59 template <> void Offset<SubGraphLink>::build(const TFLFlatBufVec *tflite_flatbuffer_vec)
61 std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
63 int32_t subgraph_index = 0;
65 for (auto it_sg : *tflite_flatbuffer_vec)
67 // tensors of subgraph
68 std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
70 auto tflite_tensors = it_sg->tensors();
71 for (auto it : *tflite_tensors)
74 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
77 auto shape_vec = std::vector<int32_t>({it->shape()->begin(), it->shape()->end()});
78 shape = _fb->CreateVector(shape_vec);
81 flatbuffers::Offset<flatbuffers::String> name;
83 name = _fb->CreateString(it->name()->str());
85 flatbuffers::Offset<circle::QuantizationParameters> quantization;
86 if (it->quantization())
88 std::vector<float> tfmin;
89 std::vector<float> tfmax;
90 std::vector<float> tfscale;
91 std::vector<int64_t> tfzerop;
92 flatbuffers::Offset<flatbuffers::Vector<float>> min;
93 flatbuffers::Offset<flatbuffers::Vector<float>> max;
94 flatbuffers::Offset<flatbuffers::Vector<float>> scale;
95 flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
96 int32_t quantized_dimension = it->quantization()->quantized_dimension();
98 if (it->quantization()->min() && it->quantization()->max())
100 auto rmin = it->quantization()->min();
101 auto rmax = it->quantization()->max();
102 tfmin = std::vector<float>{rmin->begin(), rmin->end()};
103 tfmax = std::vector<float>{rmax->begin(), rmax->end()};
104 min = _fb->CreateVector(tfmin);
105 max = _fb->CreateVector(tfmax);
108 if (it->quantization()->scale() && it->quantization()->zero_point())
110 auto rs = it->quantization()->scale();
111 auto rz = it->quantization()->zero_point();
112 tfscale = std::vector<float>{rs->begin(), rs->end()};
113 tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
114 scale = _fb->CreateVector(tfscale);
115 zero_point = _fb->CreateVector(tfzerop);
118 quantization = circle::CreateQuantizationParameters(*_fb, min, max, scale, zero_point,
119 circle::QuantizationDetails_NONE, 0,
120 quantized_dimension);
123 bool is_variable = it->is_variable();
125 flatbuffers::Offset<circle::SparsityParameters> sparsity;
129 flatbuffers::Offset<flatbuffers::Vector<int32_t>> traversal_order;
130 flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_map;
131 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::DimensionMetadata>>>
135 if (it->sparsity()->traversal_order())
137 auto traversal_order_vec = std::vector<int32_t>{
138 it->sparsity()->traversal_order()->begin(), it->sparsity()->traversal_order()->end()};
139 traversal_order = _fb->CreateVector(traversal_order_vec);
143 if (it->sparsity()->block_map())
145 auto block_map_vec = std::vector<int32_t>{it->sparsity()->block_map()->begin(),
146 it->sparsity()->block_map()->end()};
147 block_map = _fb->CreateVector(block_map_vec);
151 std::vector<flatbuffers::Offset<circle::DimensionMetadata>> dim_metadata_vec;
152 auto tflite_dim_metadata = it->sparsity()->dim_metadata();
153 for (auto it : *tflite_dim_metadata)
156 auto tflite_array_segments_type = it->array_segments_type();
157 auto circle_array_segments =
158 get_circle_sparse_index_vector(*_fb, it->array_segments(), tflite_array_segments_type);
159 auto circle_array_segments_type =
160 get_circle_sparse_index_vector_type(tflite_array_segments_type);
163 auto tflite_array_indices_type = it->array_indices_type();
164 auto circle_array_indices =
165 get_circle_sparse_index_vector(*_fb, it->array_indices(), tflite_array_indices_type);
166 auto circle_array_indices_type =
167 get_circle_sparse_index_vector_type(tflite_array_indices_type);
169 auto circle_dim_metadata_builder = circle::DimensionMetadataBuilder{*_fb};
171 circle_dim_metadata_builder.add_format(get_circle_dimension_type(it->format()));
172 circle_dim_metadata_builder.add_dense_size(it->dense_size());
173 circle_dim_metadata_builder.add_array_segments(circle_array_segments);
174 circle_dim_metadata_builder.add_array_segments_type(circle_array_segments_type);
175 circle_dim_metadata_builder.add_array_indices(circle_array_indices);
176 circle_dim_metadata_builder.add_array_indices_type(circle_array_indices_type);
177 auto dim_metadata = circle_dim_metadata_builder.Finish();
178 dim_metadata_vec.emplace_back(dim_metadata);
180 dim_metadata = _fb->CreateVector(dim_metadata_vec);
182 sparsity = circle::CreateSparsityParameters(*_fb, traversal_order, block_map, dim_metadata);
186 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
187 if (it->shape_signature())
189 auto shape_signature_vec =
190 std::vector<int32_t>({it->shape_signature()->begin(), it->shape_signature()->end()});
191 shape_signature = _fb->CreateVector(shape_signature_vec);
194 circle::TensorBuilder tensor_builder{*_fb};
195 tensor_builder.add_shape(shape);
196 tensor_builder.add_type(get_circle_tensortype(it->type()));
197 tensor_builder.add_buffer(it->buffer());
198 tensor_builder.add_name(name);
199 tensor_builder.add_quantization(quantization);
200 tensor_builder.add_is_variable(is_variable);
201 tensor_builder.add_sparsity(sparsity);
202 tensor_builder.add_shape_signature(shape_signature);
203 auto tensor = tensor_builder.Finish();
204 tensor_vec.emplace_back(tensor);
206 auto circle_tensors = _fb->CreateVector(tensor_vec);
208 // inputs of subgraph
209 auto tflite_inputs = it_sg->inputs();
210 std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
212 // apply signature_def to input tensor index so that input orders follow like tensorflow lite
213 // interpreter._get_full_signature_list() method, which is ordered(sorted) in name
214 // NOTE we do not need this when circle format supports signature_def
215 if (_tfl_signature_def_offsets != nullptr)
217 for (auto it_signdef : *_tfl_signature_def_offsets)
219 if (it_signdef->subgraph_index() == subgraph_index)
221 auto inputs = it_signdef->inputs();
222 assert(inputs->size() == input_vec.size());
224 std::map<std::string, uint32_t> map_name_index;
225 for (auto it_tm : *inputs)
227 map_name_index[it_tm->name()->str()] = it_tm->tensor_index();
229 uint32_t input_vec_idx = 0;
230 for (auto &item : map_name_index)
232 input_vec[input_vec_idx++] = item.second;
238 auto circle_inputs = _fb->CreateVector(input_vec);
240 // outputs of subgraph
241 auto tflite_outputs = it_sg->outputs();
242 std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
244 if (_tfl_signature_def_offsets != nullptr)
246 // apply SignatureDef
247 for (auto it_signdef : *_tfl_signature_def_offsets)
249 if (it_signdef->subgraph_index() == subgraph_index)
251 auto outputs = it_signdef->outputs();
252 assert(outputs->size() == output_vec.size());
254 std::map<std::string, uint32_t> map_name_index;
255 for (auto it_tm : *outputs)
257 map_name_index[it_tm->name()->str()] = it_tm->tensor_index();
259 uint32_t output_vec_idx = 0;
260 for (auto &item : map_name_index)
262 output_vec[output_vec_idx++] = item.second;
268 auto circle_outputs = _fb->CreateVector(output_vec);
270 // operators of subgraph
271 std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
273 auto tflite_operators = it_sg->operators();
274 if (tflite_operators != nullptr)
276 for (auto it : *tflite_operators)
279 std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
280 auto circle_inputs = _fb->CreateVector(input_vec);
282 std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
283 auto circle_outputs = _fb->CreateVector(output_vec);
285 auto circle_builtin_options = get_circle_builtin_options(*_fb, it);
286 auto circle_builtin_options_type = get_circle_builtin_options_type(it);
288 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
289 if (it->custom_options())
291 std::vector<uint8_t> custom_options_vec{it->custom_options()->begin(),
292 it->custom_options()->end()};
293 circle_custom_options = _fb->CreateVector(custom_options_vec);
295 // custom options format
296 // TODO Make get_circle_custom_options_format
297 assert(it->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS);
298 auto circle_custom_options_format = circle::CustomOptionsFormat_FLEXBUFFERS;
300 circle::OperatorBuilder operator_builder{*_fb};
301 operator_builder.add_opcode_index(it->opcode_index());
302 operator_builder.add_inputs(circle_inputs);
303 operator_builder.add_outputs(circle_outputs);
304 operator_builder.add_builtin_options(circle_builtin_options);
305 operator_builder.add_builtin_options_type(circle_builtin_options_type);
306 operator_builder.add_custom_options(circle_custom_options);
307 operator_builder.add_custom_options_format(circle_custom_options_format);
308 // TODO mutating_variable_inputs
309 auto opeartor = operator_builder.Finish();
310 operator_vec.emplace_back(opeartor);
313 auto circle_operators = _fb->CreateVector(operator_vec);
316 auto subgraphs_name = _fb->CreateString(it_sg->name());
319 auto circle_subgraph_builder = circle::SubGraphBuilder{*_fb};
321 circle_subgraph_builder.add_tensors(circle_tensors);
322 circle_subgraph_builder.add_inputs(circle_inputs);
323 circle_subgraph_builder.add_outputs(circle_outputs);
324 circle_subgraph_builder.add_operators(circle_operators);
325 circle_subgraph_builder.add_name(subgraphs_name);
326 circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
328 auto circle_subgraph = circle_subgraph_builder.Finish();
329 subgprahs_vec.emplace_back(circle_subgraph);
332 subgraph_index = subgraph_index + 1;
334 _circle_flatbuffer_vec_offset = _fb->CreateVector(subgprahs_vec);
337 template <> void Offset<OperatorCodeLink>::build(const TFLFlatBufVec *tflite_flatbuffer_vec)
339 std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
341 for (auto it : *tflite_flatbuffer_vec)
343 auto custom_code = _fb->CreateString(it->custom_code());
344 circle::OperatorCodeBuilder operator_code_builder{*_fb};
345 auto de_code = it->deprecated_builtin_code();
346 auto bt_code = it->builtin_code();
347 auto cir_de_code = get_circle_builtin_code(de_code);
348 auto cir_bt_code = get_circle_builtin_code(bt_code);
349 // correct bt_code where bt_code == 0 for old tflite format
350 if (cir_bt_code == 0)
351 cir_bt_code = static_cast<circle::BuiltinOperator>(cir_de_code);
352 operator_code_builder.add_deprecated_builtin_code(cir_de_code);
353 operator_code_builder.add_builtin_code(cir_bt_code);
354 operator_code_builder.add_custom_code(custom_code);
355 operator_code_builder.add_version(it->version());
356 auto code = operator_code_builder.Finish();
357 operator_code_vec.emplace_back(code);
359 _circle_flatbuffer_vec_offset = _fb->CreateVector(operator_code_vec);
362 CircleModel::CircleModel(FlatBufBuilder &fb)
363 : _version{0}, _description{fb->CreateString("ONE-tflite2circle")}, _fb{fb}
368 void CircleModel::load_offsets(const tflite::Model *tfl_model)
370 _operator_codes_offset = std::make_unique<Offset<OperatorCodeLink>>(_fb);
371 _subGraphs_offset = std::make_unique<Offset<SubGraphLink>>(_fb);
372 _buffers_offset = std::make_unique<Offset<BufferLink>>(_fb);
373 _metadata_buffer_offset = std::make_unique<Offset<MetaDataBufferLink>>(_fb);
375 _subGraphs_offset->set_signature_defs(tfl_model->signature_defs());
377 _operator_codes_offset->build(tfl_model->operator_codes());
378 _subGraphs_offset->build(tfl_model->subgraphs());
379 _buffers_offset->build(tfl_model->buffers());
380 _metadata_buffer_offset->build(tfl_model->metadata_buffer());
383 void CircleModel::model_build(void) const
385 circle::ModelBuilder model_builder{*_fb};
387 model_builder.add_version(_version);
388 model_builder.add_description(_description);
389 model_builder.add_operator_codes(_operator_codes_offset->offset());
390 model_builder.add_subgraphs(_subGraphs_offset->offset());
391 model_builder.add_buffers(_buffers_offset->offset());
392 model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
394 auto model = model_builder.Finish();
395 circle::FinishModelBuffer(*_fb, model);
398 const char *CircleModel::base(void) const
400 return reinterpret_cast<const char *>(_fb->GetBufferPointer());
403 size_t CircleModel::size(void) const { return _fb->GetSize(); }
405 } // namespace tflite2circle