2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include "CircleModel.h"
22 #include "DataLookup.h"
24 namespace tflite2circle
28 void Offset<MetaDataBufferLink>::build(FlatBufBuilder &fb,
29 const TFLFlatBufVec *tflite_flatbuffer_vec)
31 if (tflite_flatbuffer_vec == nullptr)
33 std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
34 tflite_flatbuffer_vec->end()};
35 _circle_flatbuffer_vec_offset = fb->CreateVector(metadata_buffer_vec);
39 void Offset<BufferLink>::build(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
41 std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
43 for (auto it : *tflite_flatbuffer_vec)
45 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
48 std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
49 buffer_data = fb->CreateVector(data_vec);
51 circle::BufferBuilder circle_buffer_builder{*fb};
52 circle_buffer_builder.add_data(buffer_data);
53 auto circle_buffers = circle_buffer_builder.Finish();
54 buffers_vec.emplace_back(circle_buffers);
56 _circle_flatbuffer_vec_offset = fb->CreateVector(buffers_vec);
60 void Offset<SubGraphLink>::build(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
62 std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
64 for (auto it_sg : *tflite_flatbuffer_vec)
66 // tensors of subgraph
67 std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
69 auto tflite_tensors = it_sg->tensors();
70 for (auto it : *tflite_tensors)
73 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
76 auto shape_vec = std::vector<int32_t>({it->shape()->begin(), it->shape()->end()});
77 shape = fb->CreateVector(shape_vec);
80 flatbuffers::Offset<flatbuffers::String> name;
82 name = fb->CreateString(it->name()->str());
84 flatbuffers::Offset<circle::QuantizationParameters> quantization;
85 if (it->quantization())
87 std::vector<float> tfmin;
88 std::vector<float> tfmax;
89 std::vector<float> tfscale;
90 std::vector<int64_t> tfzerop;
91 flatbuffers::Offset<flatbuffers::Vector<float>> min;
92 flatbuffers::Offset<flatbuffers::Vector<float>> max;
93 flatbuffers::Offset<flatbuffers::Vector<float>> scale;
94 flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
95 int32_t quantized_dimension = it->quantization()->quantized_dimension();
97 if (it->quantization()->min() && it->quantization()->max())
99 auto rmin = it->quantization()->min();
100 auto rmax = it->quantization()->max();
101 tfmin = std::vector<float>{rmin->begin(), rmin->end()};
102 tfmax = std::vector<float>{rmax->begin(), rmax->end()};
103 min = fb->CreateVector(tfmin);
104 max = fb->CreateVector(tfmax);
107 if (it->quantization()->scale() && it->quantization()->zero_point())
109 auto rs = it->quantization()->scale();
110 auto rz = it->quantization()->zero_point();
111 tfscale = std::vector<float>{rs->begin(), rs->end()};
112 tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
113 scale = fb->CreateVector(tfscale);
114 zero_point = fb->CreateVector(tfzerop);
117 quantization = circle::CreateQuantizationParameters(*fb, min, max, scale, zero_point,
118 circle::QuantizationDetails_NONE, 0,
119 quantized_dimension);
122 bool is_variable = it->is_variable();
124 flatbuffers::Offset<circle::SparsityParameters> sparsity;
128 flatbuffers::Offset<flatbuffers::Vector<int32_t>> traversal_order;
129 flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_map;
130 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::DimensionMetadata>>>
134 if (it->sparsity()->traversal_order())
136 auto traversal_order_vec = std::vector<int32_t>{
137 it->sparsity()->traversal_order()->begin(), it->sparsity()->traversal_order()->end()};
138 traversal_order = fb->CreateVector(traversal_order_vec);
142 if (it->sparsity()->block_map())
144 auto block_map_vec = std::vector<int32_t>{it->sparsity()->block_map()->begin(),
145 it->sparsity()->block_map()->end()};
146 block_map = fb->CreateVector(block_map_vec);
150 std::vector<flatbuffers::Offset<circle::DimensionMetadata>> dim_metadata_vec;
151 auto tflite_dim_metadata = it->sparsity()->dim_metadata();
152 for (auto it : *tflite_dim_metadata)
155 auto tflite_array_segments_type = it->array_segments_type();
156 auto circle_array_segments =
157 get_circle_sparse_index_vector(*fb, it->array_segments(), tflite_array_segments_type);
158 auto circle_array_segments_type =
159 get_circle_sparse_index_vector_type(tflite_array_segments_type);
162 auto tflite_array_indices_type = it->array_indices_type();
163 auto circle_array_indices =
164 get_circle_sparse_index_vector(*fb, it->array_indices(), tflite_array_indices_type);
165 auto circle_array_indices_type =
166 get_circle_sparse_index_vector_type(tflite_array_indices_type);
168 auto circle_dim_metadata_builder = circle::DimensionMetadataBuilder{*fb};
170 circle_dim_metadata_builder.add_format(get_circle_dimension_type(it->format()));
171 circle_dim_metadata_builder.add_dense_size(it->dense_size());
172 circle_dim_metadata_builder.add_array_segments(circle_array_segments);
173 circle_dim_metadata_builder.add_array_segments_type(circle_array_segments_type);
174 circle_dim_metadata_builder.add_array_indices(circle_array_indices);
175 circle_dim_metadata_builder.add_array_indices_type(circle_array_indices_type);
176 auto dim_metadata = circle_dim_metadata_builder.Finish();
177 dim_metadata_vec.emplace_back(dim_metadata);
179 dim_metadata = fb->CreateVector(dim_metadata_vec);
181 sparsity = circle::CreateSparsityParameters(*fb, traversal_order, block_map, dim_metadata);
185 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
186 if (it->shape_signature())
188 auto shape_signature_vec =
189 std::vector<int32_t>({it->shape_signature()->begin(), it->shape_signature()->end()});
190 shape_signature = fb->CreateVector(shape_signature_vec);
193 circle::TensorBuilder tensor_builder{*fb};
194 tensor_builder.add_shape(shape);
195 tensor_builder.add_type(get_circle_tensortype(it->type()));
196 tensor_builder.add_buffer(it->buffer());
197 tensor_builder.add_name(name);
198 tensor_builder.add_quantization(quantization);
199 tensor_builder.add_is_variable(is_variable);
200 tensor_builder.add_sparsity(sparsity);
201 tensor_builder.add_shape_signature(shape_signature);
202 auto tensor = tensor_builder.Finish();
203 tensor_vec.emplace_back(tensor);
205 auto circle_tensors = fb->CreateVector(tensor_vec);
207 // inputs of subgraph
208 auto tflite_inputs = it_sg->inputs();
209 std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
211 auto circle_inputs = fb->CreateVector(input_vec);
213 // outputs of subgraph
214 auto tflite_outputs = it_sg->outputs();
215 std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
217 auto circle_outputs = fb->CreateVector(output_vec);
219 // operators of subgraph
220 std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
222 auto tflite_operators = it_sg->operators();
223 if (tflite_operators != nullptr)
225 for (auto it : *tflite_operators)
228 std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
229 auto circle_inputs = fb->CreateVector(input_vec);
231 std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
232 auto circle_outputs = fb->CreateVector(output_vec);
234 auto circle_builtin_options = get_circle_builtin_options(*fb, it);
235 auto circle_builtin_options_type = get_circle_builtin_options_type(it);
237 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
238 if (it->custom_options())
240 std::vector<uint8_t> custom_options_vec{it->custom_options()->begin(),
241 it->custom_options()->end()};
242 circle_custom_options = fb->CreateVector(custom_options_vec);
244 // custom options format
245 // TODO Make get_circle_custom_options_format
246 assert(it->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS);
247 auto circle_custom_options_format = circle::CustomOptionsFormat_FLEXBUFFERS;
249 circle::OperatorBuilder operator_builder{*fb};
250 operator_builder.add_opcode_index(it->opcode_index());
251 operator_builder.add_inputs(circle_inputs);
252 operator_builder.add_outputs(circle_outputs);
253 operator_builder.add_builtin_options(circle_builtin_options);
254 operator_builder.add_builtin_options_type(circle_builtin_options_type);
255 operator_builder.add_custom_options(circle_custom_options);
256 operator_builder.add_custom_options_format(circle_custom_options_format);
257 // TODO mutating_variable_inputs
258 auto opeartor = operator_builder.Finish();
259 operator_vec.emplace_back(opeartor);
262 auto circle_operators = fb->CreateVector(operator_vec);
265 auto subgraphs_name = fb->CreateString(it_sg->name());
268 auto circle_subgraph_builder = circle::SubGraphBuilder{*fb};
270 circle_subgraph_builder.add_tensors(circle_tensors);
271 circle_subgraph_builder.add_inputs(circle_inputs);
272 circle_subgraph_builder.add_outputs(circle_outputs);
273 circle_subgraph_builder.add_operators(circle_operators);
274 circle_subgraph_builder.add_name(subgraphs_name);
275 circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
277 auto circle_subgraph = circle_subgraph_builder.Finish();
278 subgprahs_vec.emplace_back(circle_subgraph);
280 _circle_flatbuffer_vec_offset = fb->CreateVector(subgprahs_vec);
283 tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
285 assert(opcode != nullptr);
286 int8_t dp_code = opcode->deprecated_builtin_code();
287 // 127 is max of int8_t which is upper bound of v3 builtin_code
288 // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
289 if (dp_code < 127 && dp_code >= 0)
290 return tflite::BuiltinOperator(dp_code);
291 return opcode->builtin_code();
295 void Offset<OperatorCodeLink>::build(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
297 std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
299 for (auto it : *tflite_flatbuffer_vec)
301 auto custom_code = fb->CreateString(it->custom_code());
302 circle::OperatorCodeBuilder operator_code_builder{*fb};
303 // TODO support circle deprecated_builtin_code
304 auto bt_code = builtin_code_neutral(it);
305 operator_code_builder.add_builtin_code(get_circle_builtin_code(bt_code));
306 operator_code_builder.add_custom_code(custom_code);
307 operator_code_builder.add_version(it->version());
308 auto code = operator_code_builder.Finish();
309 operator_code_vec.emplace_back(code);
311 _circle_flatbuffer_vec_offset = fb->CreateVector(operator_code_vec);
// Constructs the circle model from a TFLite model.
//
// fb        - flatbuffer builder; used here to create the description string, stored in
//             _fb, and handed to each Offset holder.
// tfl_model - source TFLite model; dereferenced without a null check.
//
// NOTE(review): no call to model_build() is visible in this constructor from this view;
// confirm where the buffer gets finished.
314 CircleModel::CircleModel(FlatBufBuilder &fb, const tflite::Model *tfl_model)
315 : _version{0}, _description{fb->CreateString("ONE-tflite2circle")}, _fb{fb}
// Allocate one Offset holder per top-level section of the model.
317 _operator_codes_offset = std::make_unique<Offset<OperatorCodeLink>>(fb);
318 _subGraphs_offset = std::make_unique<Offset<SubGraphLink>>(fb);
319 _buffers_offset = std::make_unique<Offset<BufferLink>>(fb);
320 _metadata_buffer_offset = std::make_unique<Offset<MetaDataBufferLink>>(fb);
// Convert each section in turn: operator codes, subgraphs, buffers, metadata buffer.
322 _operator_codes_offset->build(fb, tfl_model->operator_codes());
323 _subGraphs_offset->build(fb, tfl_model->subgraphs());
324 _buffers_offset->build(fb, tfl_model->buffers());
325 _metadata_buffer_offset->build(fb, tfl_model->metadata_buffer());
330 void CircleModel::model_build(void) const
332 circle::ModelBuilder model_builder{*_fb};
334 model_builder.add_version(_version);
335 model_builder.add_description(_description);
336 model_builder.add_operator_codes(_operator_codes_offset->offset());
337 model_builder.add_subgraphs(_subGraphs_offset->offset());
338 model_builder.add_buffers(_buffers_offset->offset());
339 model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
341 auto model = model_builder.Finish();
342 circle::FinishModelBuffer(*_fb, model);
345 const char *CircleModel::base(void) const
347 return reinterpret_cast<const char *>(_fb->GetBufferPointer());
350 size_t CircleModel::size(void) const { return _fb->GetSize(); }
352 } // namespace tflite2circle