2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "tflchef/ModelChef.h"
18 #include <souschef/RangedArguments.h>
19 #include <souschef/Registry.h>
23 #include <souschef/DataChefs.h>
28 #include <souschef/Dataset.h>
29 #include <souschef/Dims.h>
45 using namespace souschef;
// GeneratedModel::Impl backed by a FlatBufferBuilder.
// The instance owns the builder, and base()/size() expose the builder's buffer
// as a (pointer, length) pair.
50 class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl
// Takes ownership of the builder by moving it into _builder.
53 GeneratedModelImpl(std::unique_ptr<flatbuffers::FlatBufferBuilder> &&builder)
54 : _builder{std::move(builder)}
// The returned pointer refers to memory held by _builder.
60 const char *base(void) const override
62 // Return the base address of generated flatbuffer model
63 return reinterpret_cast<const char *>(_builder->GetBufferPointer());
67 size_t size(void) const override
69 // Return the size of generated flatbuffer model
70 return _builder->GetSize();
// Builder whose buffer is what base() and size() describe.
74 std::unique_ptr<flatbuffers::FlatBufferBuilder> _builder;
// Registry specialised for DataChefFactory entries.
// data_chef_registry() keeps one instance of this type per tensor type.
82 struct DataChefRegistry final : public Registry<DataChefFactory>
// Returns a reference to the DataChefRegistry that belongs to the given tensor type.
//
// Each registry below is a function-local static, so it is constructed once and the
// same object is handed out on later calls.
//
// @param type  tensor type selecting the registry
// @throws std::runtime_error ("Unknown tensor type") - presumably when `type` matches
//         none of the case labels; confirm against the full switch.
86 DataChefRegistry &data_chef_registry(const tflchef::TensorType &type)
88 static DataChefRegistry s32;
89 static DataChefRegistry s64;
90 static DataChefRegistry fp32;
91 static DataChefRegistry u8;
92 static DataChefRegistry string;
93 static DataChefRegistry boolean;
94 static DataChefRegistry s16;
102 case tflchef::FLOAT32:
106 case tflchef::STRING:
116 throw std::runtime_error{"Unknown tensor type"};
// Registry specialised for OpChefFactory entries; accessed through op_chef_registry().
119 struct OpChefRegistry final : public Registry<OpChefFactory>
// Accessor for the OpChefRegistry instance, which is held in a function-local static
// and therefore constructed once.
123 OpChefRegistry &op_chef_registry(void)
125 static OpChefRegistry registry;
129 /// @brief This will prepare a map of unique builtin codes in the model recipe
///
/// Visits the operations of the main recipe first and then those of every entry in
/// model_recipe.graph(). For each builtin code the stored version is the highest
/// operation.version() seen for that code.
///
/// @param model_recipe  recipe whose operations are scanned
/// @return map from BuiltinOperator to operator version (per the declared return type)
130 std::map<tflite::BuiltinOperator, int32_t>
131 gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
133 // Key and value of the map are BuiltinOperator and operator version
134 std::map<tflite::BuiltinOperator, int32_t> builtin_map;
136 for (const auto &operation : model_recipe.operation())
// An op chef is created only to learn which BuiltinOperator the operation maps to.
138 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// NOTE(review): custom operators are presumably excluded from this map (they are
// collected by gather_customcode_set) - confirm.
139 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
142 // Various operation version is unified as the highest version among them
// Insert on first sight of the code, or overwrite when this operation has a higher version.
143 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
144 builtin_map[op_chef->code()] < operation.version())
145 builtin_map[op_chef->code()] = operation.version();
148 // Add ops used in Graphs(subgraphs)
149 for (int g = 0; g < model_recipe.graph_size(); ++g)
151 const auto &graph = model_recipe.graph(g);
152 for (const auto &operation : graph.operation())
154 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
155 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
158 // Various operation version is unified as the highest version among them
// Same rule as for the main recipe above.
159 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
160 builtin_map[op_chef->code()] < operation.version())
161 builtin_map[op_chef->code()] = operation.version();
168 /// @brief This will prepare a set of unique custom codes in the model recipe
///
/// An operation counts as custom when its op chef reports
/// tflite::BuiltinOperator_CUSTOM; the recipe's operation.type() string is what gets
/// stored. Operations of the main recipe and of every model_recipe.graph() entry are
/// scanned.
///
/// @param model_recipe  recipe whose operations are scanned
/// @return set of operation type names of the custom operators found
169 std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_recipe)
171 std::set<std::string> customcode_set;
172 for (const auto &operation : model_recipe.operation())
// An op chef is created only to learn which BuiltinOperator the operation maps to.
174 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
175 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
176 customcode_set.insert(operation.type());
179 // Add ops used in Graphs(subgraphs)
180 for (int g = 0; g < model_recipe.graph_size(); ++g)
182 const auto &graph = model_recipe.graph(g);
183 for (const auto &operation : graph.operation())
185 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
186 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
187 customcode_set.insert(operation.type());
191 return customcode_set;
// State shared by cook() and cook_graph(). All members below are references, so the
// referenced objects must outlive the CookParams instance.
// Buffers of the whole model; cook_graph() appends to it.
201 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec;
// Operator codes of the whole model, filled by cook().
202 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec;
// One entry per cooked graph; cook_graph() appends to it.
203 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec;
// Builder used for every FlatBuffer object that is created.
204 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder;
// Builtin operator -> version; its iteration order defines the builtin opcode indices.
205 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map;
// Custom operator names; their opcode indices follow the builtin ones.
206 std::vector<std::string> &custom_code_vec;
// Cooks one graph recipe into a tflite SubGraph and appends it to cp.subgraph_vec.
//
// Steps, in order:
//  1. reserve one empty buffer per graph input, then one per graph output;
//  2. for each operand, create its Tensor (shape, buffer, quantization, sparsity,
//     shape signature) and register name -> tensor index in a per-graph symbol table;
//  3. for each operation, create its Operator, resolving operand names through the
//     symbol table and the opcode index through builtin_code_map / custom_code_vec;
//  4. assemble the SubGraph from the collected tensors, operators and graph inputs/outputs.
//
// @tparam T     recipe type; cook() instantiates this with ::tflchef::ModelRecipe and
//               ::tflchef::Graph
// @param graph  recipe to cook
// @param cp     shared cooking state; cp.buffer_vec and cp.subgraph_vec are appended to
210 template <typename T> void cook_graph(const T &graph, CookParams &cp)
// Local aliases for the shared state carried in cp.
214 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = cp.buffer_vec;
215 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = cp.code_vec;
216 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = cp.subgraph_vec;
217 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = cp.flatbuffer_builder;
218 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = cp.builtin_code_map;
219 std::vector<std::string> &custom_code_vec = cp.custom_code_vec;
// Tensors of this graph, in operand order.
222 std::vector<flatbuffers::Offset<::tflite::Tensor>> tensor_vec;
// Operators of this graph, in operation order.
225 std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;
227 // default name for graph
228 std::string graph_name = cp.noname;
// A name given in the recipe overrides the default supplied by the caller.
229 if (graph.has_name())
230 graph_name = graph.name();
232 // Tensor Name -> Tensor ID mapping (per Graph)
233 std::map<std::string, int32_t> symbol_table;
// Resolves a tensor name to its index in symbol_table. For a name that is not
// registered, see the -1 result and the runtime_error below.
235 auto lookup = [&symbol_table, &graph_name](const std::string &name) {
236 if (symbol_table.find(name) != symbol_table.end())
237 return symbol_table.at(name);
239 return -1; // -1 in TFLite means that optional input tensor is empty.
242 std::string msg = "tflchef : input not found in " + graph_name + " graph";
243 throw std::runtime_error(msg.c_str());
// Index in buffer_vec at which this graph's input/output buffers begin.
247 int32_t buffer_start = buffer_vec.size();
248 int32_t buffer_index = 0;
250 // Create buffer(s) 1~n(I) for input(s)
251 const auto size_input = graph.input_size();
252 for (int ci = 0; ci < size_input; ++ci)
254 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
255 buffer_vec.emplace_back(buffer_builder.Finish());
257 // Create buffer(s) n(I)+1~n(I)+n(O) for output(s)
258 const auto size_output = graph.output_size();
259 for (int co = 0; co < size_output; ++co)
261 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
262 buffer_vec.emplace_back(buffer_builder.Finish());
// Names of the graph inputs/outputs, used below to match operands to the reserved buffers.
265 auto input_names = as_dataset(graph.input()).vectorize();
266 auto output_names = as_dataset(graph.output()).vectorize();
// One Tensor is created per operand.
268 for (const auto &operand : graph.operand())
270 assert(operand.has_name());
272 assert(operand.has_type());
// shape keeps its default-constructed Offset when the operand has no shape.
274 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
275 std::vector<int32_t> dims;
276 if (operand.has_shape())
278 dims = as_dims(operand.shape());
279 shape = flatbuffer_builder->CreateVector(dims);
282 auto name = flatbuffer_builder->CreateString(operand.name());
286 // Create Buffer if filler is specified
287 if (operand.has_filler())
289 const auto &filler = operand.filler();
291 assert(filler.has_tag());
// The data chef is selected by the operand's tensor type and the filler tag.
293 auto args = ranged_arguments(filler.arg().begin(), filler.arg().end());
294 auto chef = data_chef_registry(operand.type()).lookup(filler.tag()).create(args);
296 assert(chef != nullptr);
// Use the element count implied by the shape when it is positive; otherwise fall
// back to the number of filler arguments.
299 int32_t count = (element_count(dims) > 0) ? element_count(dims) : filler.arg_size();
300 auto data_vec = chef->generate(count);
301 auto data = flatbuffer_builder->CreateVector(data_vec);
304 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
305 buffer_builder.add_data(data);
306 auto buffer = buffer_builder.Finish();
308 // Update Buffer Index & Vector
309 buffer_index = buffer_vec.size();
310 buffer_vec.emplace_back(buffer);
314 // if this is input or output, assign to that buffer_index
// A graph input uses the buffer reserved for it: buffer_start + position among the inputs.
316 for (auto it = input_names.begin(); it != input_names.end(); ++it, ++idx)
318 if (*it == operand.name())
320 buffer_index = buffer_start + idx;
// Still unassigned: try the graph outputs, whose buffers follow the input buffers.
324 if (buffer_index == 0)
327 for (auto it = output_names.begin(); it != output_names.end(); ++it, ++idx)
329 if (*it == operand.name())
331 buffer_index = buffer_start + size_input + idx;
336 if (buffer_index == 0)
338 // we couldn't find the buffer; create an empty buffer for this tensor
339 buffer_index = buffer_vec.size();
341 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
342 buffer_vec.emplace_back(buffer_builder.Finish());
// Every tensor must have been given a buffer index other than 0 by this point.
345 assert(buffer_index != 0);
347 flatbuffers::Offset<tflite::QuantizationParameters> quant_index;
349 // Create QuantizationParameters if quant is specified
350 if (operand.has_quant())
352 const auto &quant = operand.quant();
354 // Create each parameters
355 // NOTE if some parameters are not given, those will be set to default value
// Each vector is sized from the matching repeated field of the recipe.
356 std::vector<float> quant_max_vec(quant.max_size());
357 std::vector<float> quant_min_vec(quant.min_size());
358 std::vector<float> quant_scale_vec(quant.scale_size());
359 std::vector<int64_t> quant_zero_point_vec(quant.zero_point_size());
// Element-wise copy from the recipe into the vectors above.
361 for (uint32_t i = 0; i < quant.max_size(); ++i)
362 quant_max_vec.at(i) = quant.max(i);
363 for (uint32_t i = 0; i < quant.min_size(); ++i)
364 quant_min_vec.at(i) = quant.min(i);
365 for (uint32_t i = 0; i < quant.scale_size(); ++i)
366 quant_scale_vec.at(i) = quant.scale(i);
367 for (uint32_t i = 0; i < quant.zero_point_size(); ++i)
368 quant_zero_point_vec.at(i) = quant.zero_point(i);
// The vectors are serialized here, before the QuantizationParametersBuilder is created.
370 auto quant_max = flatbuffer_builder->CreateVector(quant_max_vec);
371 auto quant_min = flatbuffer_builder->CreateVector(quant_min_vec);
372 auto quant_scale = flatbuffer_builder->CreateVector(quant_scale_vec);
373 auto quant_zero_point = flatbuffer_builder->CreateVector(quant_zero_point_vec);
375 // Create QuantizationParameters
376 tflite::QuantizationParametersBuilder quant_builder{*flatbuffer_builder};
377 quant_builder.add_max(quant_max);
378 quant_builder.add_min(quant_min);
379 quant_builder.add_scale(quant_scale);
380 quant_builder.add_zero_point(quant_zero_point);
381 quant_builder.add_quantized_dimension(quant.quantized_dimension());
383 // Update QuantizationParameters Index
384 quant_index = quant_builder.Finish();
// sparsity_index keeps its default-constructed Offset when the operand has no sparsity.
387 flatbuffers::Offset<tflite::SparsityParameters> sparsity_index;
389 if (operand.has_sparsity())
391 const auto &sparsity = operand.sparsity();
393 // Create traversal order
394 std::vector<int> traversal_order_vec{sparsity.traversal_order().dim().begin(),
395 sparsity.traversal_order().dim().end()};
396 auto traversal_order = flatbuffer_builder->CreateVector(traversal_order_vec);
// Create block map
399 std::vector<int> block_map_vec{sparsity.block_map().dim().begin(),
400 sparsity.block_map().dim().end()};
401 auto block_map = flatbuffer_builder->CreateVector(block_map_vec);
403 // Create dimension metadata
404 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> dim_metadata_vec;
405 auto recipe_dim_metadata = sparsity.dim_metadata();
406 for (const auto &dm : recipe_dim_metadata)
408 // Create array segments
409 auto tflite_array_segments =
410 as_tflite_sparse_index_vec(*flatbuffer_builder, dm.array_segments());
412 // Create array indices
413 auto tflite_array_indices =
414 as_tflite_sparse_index_vec(*flatbuffer_builder, dm.array_indices());
// Both index vectors are created above, before the DimensionMetadataBuilder is opened.
416 auto tflite_dim_metadata_builder = tflite::DimensionMetadataBuilder{*flatbuffer_builder};
417 tflite_dim_metadata_builder.add_format(as_tflite_dimensiontype(dm.format()));
418 tflite_dim_metadata_builder.add_dense_size(dm.dense_size());
419 tflite_dim_metadata_builder.add_array_segments(tflite_array_segments);
420 tflite_dim_metadata_builder.add_array_segments_type(
421 as_tflite_sparse_idx_vec_type(dm.array_segments().type()));
422 tflite_dim_metadata_builder.add_array_indices(tflite_array_indices);
423 tflite_dim_metadata_builder.add_array_indices_type(
424 as_tflite_sparse_idx_vec_type(dm.array_indices().type()));
425 auto tflite_dim_metadata = tflite_dim_metadata_builder.Finish();
426 dim_metadata_vec.emplace_back(tflite_dim_metadata);
428 auto dim_metadata = flatbuffer_builder->CreateVector(dim_metadata_vec);
430 sparsity_index = tflite::CreateSparsityParameters(*flatbuffer_builder, traversal_order,
431 block_map, dim_metadata);
434 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
435 if (operand.has_shape_signature())
437 auto signature = as_dims(operand.shape_signature());
438 shape_signature = flatbuffer_builder->CreateVector(signature);
// Create Tensor from the pieces prepared above.
442 tflite::TensorBuilder tensor_builder{*flatbuffer_builder};
444 tensor_builder.add_shape(shape);
445 tensor_builder.add_type(as_tflite_tensortype(operand.type()));
446 tensor_builder.add_buffer(buffer_index);
447 tensor_builder.add_name(name);
448 tensor_builder.add_is_variable(operand.is_variable());
449 if (operand.has_quant())
450 tensor_builder.add_quantization(quant_index);
// NOTE(review): unlike quantization and shape_signature, sparsity is added without a
// has_sparsity() guard; sparsity_index is a default-constructed Offset when the operand
// has no sparsity - confirm this is intended.
451 tensor_builder.add_sparsity(sparsity_index);
452 if (operand.has_shape_signature())
453 tensor_builder.add_shape_signature(shape_signature);
456 tensor_vec.emplace_back(tensor_builder.Finish());
458 // Update Tensor Name -> Tensor Index Map
// The index assigned is the number of symbols registered so far.
459 int32_t tensor_index = symbol_table.size();
460 const auto &tensor_name = operand.name();
462 INFO(l) << "Symbol [" << tensor_name << "] = Tensor " << tensor_index << std::endl;
464 symbol_table[tensor_name] = tensor_index;
// One Operator is created per operation.
468 for (const auto &operation : graph.operation())
470 assert(operation.has_type());
472 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Input tensor names are resolved to tensor indices through lookup.
475 std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
476 auto inputs = flatbuffer_builder->CreateVector(input_vec);
// Output tensor names are resolved the same way.
479 std::vector<int32_t> output_vec = as_dataset(operation.output()).map(lookup).vectorize();
480 auto outputs = flatbuffer_builder->CreateVector(output_vec);
// Builtin options are produced by the op chef.
483 auto options = op_chef->value(*flatbuffer_builder);
485 // Create Custom option
486 auto circle_custom_options = op_chef->custom_value(*flatbuffer_builder);
489 tflite::OperatorBuilder op_builder{*flatbuffer_builder};
491 // Note that opcode_index is an index into the operator_codes vector.
492 // operator_codes consists of builtin_code and custom_code, which is inserted sequentially.
493 uint32_t opcode_index = 0;
494 auto op_it = builtin_code_map.find(op_chef->code());
// Builtin operator: the index is the position of its code in builtin_code_map's iteration order.
496 if (op_it != builtin_code_map.end())
498 opcode_index = std::distance(builtin_code_map.begin(), op_it);
// Custom operator: the index is the number of builtin codes plus the position of the
// operation type in custom_code_vec.
503 auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type());
// NOTE(review): a missing custom code is caught by this assert only; the index is not
// otherwise validated.
504 assert(op_it != custom_code_vec.end());
505 opcode_index = builtin_code_map.size();
506 opcode_index += std::distance(custom_code_vec.begin(), op_it);
509 op_builder.add_opcode_index(opcode_index);
510 op_builder.add_inputs(inputs);
511 op_builder.add_outputs(outputs);
512 op_builder.add_builtin_options_type(op_chef->type());
513 op_builder.add_builtin_options(options);
514 op_builder.add_custom_options(circle_custom_options);
515 op_builder.add_custom_options_format(tflite::CustomOptionsFormat_FLEXBUFFERS);
517 operator_vec.emplace_back(op_builder.Finish());
520 // Create network input/output vector
521 std::vector<int32_t> input_vec = as_dataset(graph.input()).map(lookup).vectorize();
522 std::vector<int32_t> output_vec = as_dataset(graph.output()).map(lookup).vectorize();
524 // Create "SubGraph" arguments
525 auto tensors = flatbuffer_builder->CreateVector(tensor_vec);
526 auto inputs = flatbuffer_builder->CreateVector(input_vec);
527 auto outputs = flatbuffer_builder->CreateVector(output_vec);
528 auto operators = flatbuffer_builder->CreateVector(operator_vec);
529 auto name = flatbuffer_builder->CreateString(graph_name);
// Create "SubGraph" and append it to the shared subgraph vector.
531 tflite::SubGraphBuilder subgraph_builder{*flatbuffer_builder};
533 subgraph_builder.add_tensors(tensors);
534 subgraph_builder.add_inputs(inputs);
535 subgraph_builder.add_outputs(outputs);
536 subgraph_builder.add_operators(operators);
537 subgraph_builder.add_name(name);
539 subgraph_vec.emplace_back(subgraph_builder.Finish());
548 * @brief Generate a (in-memory) TensorFlow Lite model from a given model recipe
// Cooks a model recipe into a serialized model held by the returned GeneratedModel.
//
// Order of work: register op/data chefs, create the FlatBufferBuilder, emit operator
// codes (builtin codes first, then custom codes), add buffer 0, cook the main graph
// and then each entry of model_recipe.graph(), and finally build and finish the Model.
//
// @param model_recipe  recipe describing operands, operations and optional sub-graphs
// @return GeneratedModel wrapping a GeneratedModelImpl that takes over the builder
550 GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
// NOTE(review): both registries are populated on every call to cook(); confirm that
// Registry::add tolerates a name being added again.
552 // Initialize Op Chef Registry
553 #define OP_CHEF(NAME, FACTORY_CLASS) \
554 op_chef_registry().add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
555 #include "OpChef.def"
558 // Initialize Data Chef Registry
559 #define DATA_CHEF(TYPE, NAME, FACTORY_CLASS) \
560 data_chef_registry(::tflchef::TYPE) \
561 .add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
562 #include "DataChef.def"
566 // Create FlatBufferBuilder
// 1024 is the argument passed to the FlatBufferBuilder constructor.
568 auto flatbuffer_builder =
569 std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));
// Model-wide buffers, shared by all graphs.
572 std::vector<flatbuffers::Offset<::tflite::Buffer>> buffer_vec;
// Model-wide operator codes.
575 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
// One entry per cooked graph.
578 std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
580 // Create OperatorCode with Builtin Operator
581 auto builtin_code_map = gather_builtincode_map(model_recipe);
// Codes are emitted in the map's iteration order; cook_graph() derives builtin opcode
// indices from that same order.
582 for (auto const &opcode : builtin_code_map)
584 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
585 // TODO support for opcode.first >= 127
// NOTE(review): this guard is only an assert; with NDEBUG defined a builtin code >= 127
// reaches add_deprecated_builtin_code() unchecked.
586 assert(opcode.first < 127);
587 code_builder.add_deprecated_builtin_code(opcode.first);
588 code_builder.add_version(opcode.second);
589 code_builder.add_builtin_code(opcode.first);
590 auto code = code_builder.Finish();
591 // Update OperatorCode vector
592 code_vec.emplace_back(code);
595 // Create OperatorCode with Custom Operator
// The set is copied into a vector so that each custom code has a position;
// cook_graph() uses that position, offset by the number of builtin codes.
596 std::set<std::string> custom_code_set = gather_customcode_set(model_recipe);
597 std::vector<std::string> custom_code_vec{custom_code_set.begin(), custom_code_set.end()};
599 for (auto opcode : custom_code_vec)
// The string is created before the OperatorCodeBuilder is opened.
601 auto custom_code = flatbuffer_builder->CreateString(opcode);
602 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
603 code_builder.add_deprecated_builtin_code(tflite::BuiltinOperator_CUSTOM);
604 code_builder.add_custom_code(custom_code);
605 code_builder.add_builtin_code(tflite::BuiltinOperator_CUSTOM);
606 auto code = code_builder.Finish();
607 // Update OperatorCode vector
608 code_vec.emplace_back(code);
611 // Create an Empty Buffer
613 // Buffer 0 SHOULD be an empty buffer in TensorFlow Lite model file
614 // (Please refer to the comment for Tensor.buffer field in schema)
616 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
617 buffer_vec.emplace_back(buffer_builder.Finish());
// Create the main graph; "main" is the default graph name passed to cook_graph().
623 CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
624 builtin_code_map, custom_code_vec, "main"};
626 cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
629 // Create subgraphs if exist
631 for (int g = 0; g < model_recipe.graph_size(); ++g)
633 const auto &graph = model_recipe.graph(g);
// Default name for this sub-graph: "sub_" followed by its 1-based position.
635 std::ostringstream stringStream;
636 stringStream << "sub_" << (g + 1);
638 CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
639 builtin_code_map, custom_code_vec, stringStream.str()};
641 cook_graph<::tflchef::Graph>(graph, cp);
644 // Create "Model" arguments
645 auto buffers = flatbuffer_builder->CreateVector(buffer_vec);
646 auto operator_codes = flatbuffer_builder->CreateVector(code_vec);
647 auto subgraphs = flatbuffer_builder->CreateVector(subgraph_vec);
648 auto description = flatbuffer_builder->CreateString("Generated by tflchef");
// Create "Model"
651 tflite::ModelBuilder model_builder{*flatbuffer_builder};
// The model's version field is set to 3.
653 model_builder.add_version(3);
654 model_builder.add_operator_codes(operator_codes);
655 model_builder.add_subgraphs(subgraphs);
656 model_builder.add_description(description);
657 model_builder.add_buffers(buffers);
659 auto model = model_builder.Finish();
// Finish the buffer with `model` as its root object.
662 ::tflite::FinishModelBuffer(*flatbuffer_builder, model);
664 // Return "GenerateModel"
// The builder is moved into the impl object, which becomes its owner.
665 return GeneratedModel{
666 std::unique_ptr<GeneratedModelImpl>(new GeneratedModelImpl(std::move(flatbuffer_builder)))};
669 } // namespace tflchef