2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "tflchef/ModelChef.h"
18 #include <souschef/RangedArguments.h>
19 #include <souschef/Registry.h>
23 #include <souschef/DataChefs.h>
28 #include <souschef/Dataset.h>
29 #include <souschef/Dims.h>
45 using namespace souschef;
50 class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl
53 GeneratedModelImpl(std::unique_ptr<flatbuffers::FlatBufferBuilder> &&builder)
54 : _builder{std::move(builder)}
60 const char *base(void) const override
62 // Return the base address of generated flatbuffer model
63 return reinterpret_cast<const char *>(_builder->GetBufferPointer());
67 size_t size(void) const override
69 // Return the size of generated flatbuffer model
70 return _builder->GetSize();
74 std::unique_ptr<flatbuffers::FlatBufferBuilder> _builder;
// Registry type for data chef factories (a Registry specialised on DataChefFactory).
// data_chef_registry() keeps one instance of it per tensor type.
82 struct DataChefRegistry final : public Registry<DataChefFactory>
/**
 * @brief Select the data chef registry that belongs to a tensor type
 *
 * One function-local static registry exists per supported tensor type.
 *
 * @param type tensor type used to select the registry
 * @throws std::runtime_error ("Unknown tensor type") when control reaches the
 *         end of the function without a registry having been selected
 *
 * NOTE(review): only some case labels are visible in this excerpt; presumably
 * each case returns the static registry of the matching type - confirm.
 */
86 DataChefRegistry &data_chef_registry(const tflchef::TensorType &type)
// One registry instance per tensor type
88 static DataChefRegistry s32;
89 static DataChefRegistry s64;
90 static DataChefRegistry fp32;
91 static DataChefRegistry u8;
92 static DataChefRegistry string;
93 static DataChefRegistry boolean;
94 static DataChefRegistry s16;
95 static DataChefRegistry fp16;
96 static DataChefRegistry s8;
104 case tflchef::FLOAT32:
106 case tflchef::FLOAT16:
110 case tflchef::STRING:
// Fallback for a type that none of the cases handled
122 throw std::runtime_error{"Unknown tensor type"};
// Registry type for operator chef factories (a Registry specialised on OpChefFactory).
125 struct OpChefRegistry final : public Registry<OpChefFactory>
/**
 * @brief Access the operator chef registry
 *
 * The registry is a function-local static, so there is a single instance.
 * NOTE(review): the return statement is not visible in this excerpt;
 * presumably it returns `registry` - confirm.
 */
129 OpChefRegistry &op_chef_registry(void)
131 static OpChefRegistry registry;
135 /// @brief This will prepare a map of unique builtin codes in the model recipe
136 std::map<tflite::BuiltinOperator, int32_t>
137 gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
139 // Key and value of the map are BuiltinOperator and operator version
140 std::map<tflite::BuiltinOperator, int32_t> builtin_map;
142 for (const auto &operation : model_recipe.operation())
144 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
145 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
148 // Various operation version is unified as the highest version among them
149 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
150 builtin_map[op_chef->code()] < operation.version())
151 builtin_map[op_chef->code()] = operation.version();
154 // Add ops used in Graphs(subgraphs)
155 for (int g = 0; g < model_recipe.graph_size(); ++g)
157 const auto &graph = model_recipe.graph(g);
158 for (const auto &operation : graph.operation())
160 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
161 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
164 // Various operation version is unified as the highest version among them
165 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
166 builtin_map[op_chef->code()] < operation.version())
167 builtin_map[op_chef->code()] = operation.version();
174 /// @brief This will prepare a set of unique custom codes in the mode recipe
175 std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_recipe)
177 std::set<std::string> customcode_set;
178 for (const auto &operation : model_recipe.operation())
180 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
181 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
182 customcode_set.insert(operation.type());
185 // Add ops used in Graphs(subgraphs)
186 for (int g = 0; g < model_recipe.graph_size(); ++g)
188 const auto &graph = model_recipe.graph(g);
189 for (const auto &operation : graph.operation())
191 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
192 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
193 customcode_set.insert(operation.type());
197 return customcode_set;
// State shared between cook() and cook_graph(); all members are references.
// NOTE(review): the enclosing declaration is not visible in this excerpt;
// cook_graph() receives this state as `CookParams &cp`.

// Buffers of the whole model; cook_graph() appends to it
207 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec;
// Operator codes of the whole model
208 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec;
// Receives one SubGraph for every graph that is cooked
209 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec;
// Builder used to serialize every object of the model
210 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder;
// Builtin operator -> version; an operator's position in this map is its opcode index
211 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map;
// Custom operator names; their opcode indices follow those of the builtin operators
212 std::vector<std::string> &custom_code_vec;
/**
 * @brief Create one tflite::DimensionMetadata per dimension of a sparse tensor
 *
 * dim_metadata_src is read two entries per dimension: entry 2*i and entry 2*i+1.
 *
 * @param flatbuffer_builder  builder used to create the flatbuffer objects
 * @param dims_count          number of dimensions to emit
 * @param traversal_order_vec traversal_order_vec[i] selects the entry of format_vec used for dimension i
 * @param format_vec          format (dense or sparse CSR) per dimension
 * @param dim_metadata_src    source metadata, two vectors per dimension
 * @return vector of dims_count DimensionMetadata offsets
 */
216 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>>
217 make_dim_metadata_vec(flatbuffers::FlatBufferBuilder *flatbuffer_builder, int32_t dims_count,
218 const std::vector<int> &traversal_order_vec,
219 const std::vector<sparsity::TfLiteDimensionType> &format_vec,
220 const std::vector<std::vector<int32_t>> &dim_metadata_src)
222 // Build sparsity parameter.
223 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> dim_metadata_vec(dims_count);
224 for (int32_t i = 0; i < dims_count; i++)
// Dimension i owns source entries metadata_idx and metadata_idx + 1
226 const int32_t metadata_idx = 2 * i;
227 if (format_vec[traversal_order_vec[i]] == sparsity::kTfLiteDimSparseCSR)
// Sparse CSR dimension: entry metadata_idx becomes the array segments
229 auto array_segments =
230 tflite::CreateInt32Vector(*flatbuffer_builder,
231 flatbuffer_builder->CreateVector(dim_metadata_src[metadata_idx]))
// Entry metadata_idx + 1 is serialized the same way and is used as array_indices below.
// NOTE(review): the declaration line of array_indices is not visible in this excerpt.
234 tflite::CreateInt32Vector(
235 *flatbuffer_builder, flatbuffer_builder->CreateVector(dim_metadata_src[metadata_idx + 1]))
237 dim_metadata_vec[i] =
238 tflite::CreateDimensionMetadata(*flatbuffer_builder, tflite::DimensionType_SPARSE_CSR, 0,
239 tflite::SparseIndexVector_Int32Vector, array_segments,
240 tflite::SparseIndexVector_Int32Vector, array_indices);
// Dense dimension: only the first element of entry metadata_idx is used
// (presumably the dense size - confirm against CreateDimensionMetadata)
244 dim_metadata_vec[i] = tflite::CreateDimensionMetadata(
245 *flatbuffer_builder, tflite::DimensionType_DENSE, dim_metadata_src[metadata_idx][0]);
248 return dim_metadata_vec;
/**
 * @brief Cook one graph of the recipe into a tflite SubGraph
 *
 * For every operand a Tensor (with its Buffer and, when specified, its
 * quantization, sparsity and shape signature) is created; for every operation
 * an Operator is created. The resulting SubGraph is appended to cp.subgraph_vec
 * and all buffers are appended to cp.buffer_vec.
 *
 * @tparam T    recipe type offering input(), output(), operand(), operation()
 *              and has_name()/name()
 * @param graph recipe of the graph to cook
 * @param cp    state shared with the caller (builder, buffer/subgraph vectors,
 *              opcode tables and the default graph name)
 * @return tensor name -> tensor index map.
 *         NOTE(review): the return statement is not visible in this excerpt;
 *         presumably symbol_table is returned - confirm.
 * @throws std::runtime_error for a make_sparse operand whose type is handled by
 *         neither the FLOAT32 nor the FLOAT16 branch, and from the tensor name
 *         lookup ("input not found")
 */
251 template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, CookParams &cp)
// Local aliases for the shared cooking state
255 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = cp.buffer_vec;
256 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = cp.code_vec;
257 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = cp.subgraph_vec;
258 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = cp.flatbuffer_builder;
259 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = cp.builtin_code_map;
260 std::vector<std::string> &custom_code_vec = cp.custom_code_vec;
// Tensors of this graph, in creation order
263 std::vector<flatbuffers::Offset<::tflite::Tensor>> tensor_vec;
// Operators of this graph, in creation order
266 std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;
268 // default name for graph
269 std::string graph_name = cp.noname;
270 if (graph.has_name())
271 graph_name = graph.name();
273 // Tensor Name -> Tensor ID mapping (per Graph)
274 std::map<std::string, int32_t> symbol_table;
// Translate a tensor name into its tensor index.
// NOTE(review): the condition guarding the -1 result is not visible in this
// excerpt; the final throw reports a name that could not be resolved.
276 auto lookup = [&symbol_table, &graph_name](const std::string &name) {
277 if (symbol_table.find(name) != symbol_table.end())
278 return symbol_table.at(name);
280 return -1; // -1 in TFLite means that optional input tensor is empty.
283 std::string msg = "tflchef : input not found in " + graph_name + " graph";
284 throw std::runtime_error(msg.c_str());
// Position in buffer_vec where the input/output buffers of this graph begin
288 int32_t buffer_start = buffer_vec.size();
// 0 marks "no buffer assigned yet"; see the buffer_index == 0 checks below
289 int32_t buffer_index = 0;
291 // Create buffer(s) 1~n(I) for input(s)
292 const auto size_input = graph.input_size();
293 for (int ci = 0; ci < size_input; ++ci)
295 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
296 buffer_vec.emplace_back(buffer_builder.Finish());
298 // Create buffer(s) n(I)+1~n(I)+n(O) for output(s)
299 const auto size_output = graph.output_size();
300 for (int co = 0; co < size_output; ++co)
302 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
303 buffer_vec.emplace_back(buffer_builder.Finish());
// Names of the graph inputs/outputs, used to match operands with the buffers created above
306 auto input_names = as_dataset(graph.input()).vectorize();
307 auto output_names = as_dataset(graph.output()).vectorize();
// Create one Tensor per operand
309 for (const auto &operand : graph.operand())
311 assert(operand.has_name());
313 assert(operand.has_type());
// Stays a default-constructed offset unless make_sparse or an explicit sparsity sets it
315 flatbuffers::Offset<tflite::SparsityParameters> sparsity_index;
317 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
318 std::vector<int32_t> dims;
319 if (operand.has_shape())
321 dims = as_dims(operand.shape());
322 shape = flatbuffer_builder->CreateVector(dims);
325 auto name = flatbuffer_builder->CreateString(operand.name());
329 // Create Buffer if filler is specified
330 if (operand.has_filler())
332 const auto &filler = operand.filler();
334 assert(filler.has_tag());
// The data chef is chosen by the operand type and the filler tag
336 auto args = ranged_arguments(filler.arg().begin(), filler.arg().end());
337 auto chef = data_chef_registry(operand.type()).lookup(filler.tag()).create(args);
339 assert(chef != nullptr);
// Element count comes from the shape; when that is not positive, the number of filler arguments is used
342 int32_t count = (element_count(dims) > 0) ? element_count(dims) : filler.arg_size();
343 auto data_vec = chef->generate(count);
// make_sparse: convert the generated dense data to a sparse representation
345 if (operand.has_make_sparse() && operand.make_sparse())
347 assert(not operand.has_sparsity());
348 assert(operand.has_shape());
350 const int32_t dims_count = dims.size();
351 std::vector<int> traversal_order_vec;
352 std::vector<sparsity::TfLiteDimensionType> format_vec;
// Dimensions are traversed in their natural order 0 .. dims_count-1
353 for (int32_t o = 0; o < dims_count; ++o)
354 traversal_order_vec.push_back(o);
// Every dimension is dense except the last one, which is sparse CSR
355 for (int32_t o = 0; o < dims_count - 1; ++o)
356 format_vec.push_back(sparsity::kTfLiteDimDense);
357 format_vec.push_back(sparsity::kTfLiteDimSparseCSR);
359 if (operand.type() == tflchef::FLOAT32)
361 ::sparsity::FormatConverter<float> converter(dims, traversal_order_vec, format_vec);
362 converter.DenseToSparse(reinterpret_cast<const float *>(data_vec.data()));
363 const auto &sparse_data = converter.GetData();
// Copy the converted values byte by byte into a uint8 vector for the Buffer
365 std::vector<uint8_t> sparse_uint8;
366 for (int c = 0; c < sparse_data.size(); ++c)
368 const float value = sparse_data.at(c);
369 const uint8_t *arr = reinterpret_cast<const uint8_t *>(&value);
370 for (uint32_t b = 0; b < sizeof(float); ++b)
372 sparse_uint8.emplace_back(arr[b]);
375 auto data = flatbuffer_builder->CreateVector(sparse_uint8);
378 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
379 buffer_builder.add_data(data);
380 auto buffer = buffer_builder.Finish();
382 // Update Buffer Index & Vector
383 buffer_index = buffer_vec.size();
384 buffer_vec.emplace_back(buffer);
386 // save SparsityParameters
387 auto traversal_order = flatbuffer_builder->CreateVector(traversal_order_vec);
// The block map is left empty
390 std::vector<int> block_map_vec{};
391 auto block_map = flatbuffer_builder->CreateVector(block_map_vec);
393 // Create dimension metadata
394 const auto &dim_metadata_src = converter.GetDimMetadata();
395 auto dim_metadata_vec =
396 make_dim_metadata_vec(flatbuffer_builder.get(), dims_count, traversal_order_vec,
397 format_vec, dim_metadata_src);
398 auto dim_metadata = flatbuffer_builder->CreateVector(dim_metadata_vec);
399 sparsity_index = tflite::CreateSparsityParameters(*flatbuffer_builder, traversal_order,
400 block_map, dim_metadata);
// FLOAT16 follows the same steps as FLOAT32, handling the values as uint16_t
402 else if (operand.type() == tflchef::FLOAT16)
404 ::sparsity::FormatConverter<uint16_t> converter(dims, traversal_order_vec, format_vec);
405 converter.DenseToSparse(reinterpret_cast<const uint16_t *>(data_vec.data()));
406 const auto &sparse_data = converter.GetData();
408 std::vector<uint8_t> sparse_uint8;
409 for (int c = 0; c < sparse_data.size(); ++c)
411 const uint16_t value = sparse_data.at(c);
412 const uint8_t *arr = reinterpret_cast<const uint8_t *>(&value);
413 for (uint32_t b = 0; b < sizeof(uint16_t); ++b)
415 sparse_uint8.emplace_back(arr[b]);
418 auto data = flatbuffer_builder->CreateVector(sparse_uint8);
421 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
422 buffer_builder.add_data(data);
423 auto buffer = buffer_builder.Finish();
425 // Update Buffer Index & Vector
426 buffer_index = buffer_vec.size();
427 buffer_vec.emplace_back(buffer);
429 // save SparsityParameters
430 auto traversal_order = flatbuffer_builder->CreateVector(traversal_order_vec);
433 std::vector<int> block_map_vec{};
434 auto block_map = flatbuffer_builder->CreateVector(block_map_vec);
436 // Create dimension metadata
437 const auto &dim_metadata_src = converter.GetDimMetadata();
438 auto dim_metadata_vec =
439 make_dim_metadata_vec(flatbuffer_builder.get(), dims_count, traversal_order_vec,
440 format_vec, dim_metadata_src);
441 auto dim_metadata = flatbuffer_builder->CreateVector(dim_metadata_vec);
442 sparsity_index = tflite::CreateSparsityParameters(*flatbuffer_builder, traversal_order,
443 block_map, dim_metadata);
// Any other operand type is not supported for make_sparse
447 throw std::runtime_error{"NYI: unsupported operand type"};
// Generated data is stored in a Buffer as-is
// (presumably the branch taken when make_sparse is not requested - confirm)
452 auto data = flatbuffer_builder->CreateVector(data_vec);
455 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
456 buffer_builder.add_data(data);
457 auto buffer = buffer_builder.Finish();
459 // Update Buffer Index & Vector
460 buffer_index = buffer_vec.size();
461 buffer_vec.emplace_back(buffer);
466 // if this is input or output, assign to that buffer_index
// Input buffers start at buffer_start, in the order of the graph inputs
468 for (auto it = input_names.begin(); it != input_names.end(); ++it, ++idx)
470 if (*it == operand.name())
472 buffer_index = buffer_start + idx;
// Not an input: try the graph outputs, whose buffers follow the input buffers
476 if (buffer_index == 0)
479 for (auto it = output_names.begin(); it != output_names.end(); ++it, ++idx)
481 if (*it == operand.name())
483 buffer_index = buffer_start + size_input + idx;
488 if (buffer_index == 0)
490 // we couldn't find the buffer; create an empty buffer for this tensor
491 buffer_index = buffer_vec.size();
493 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
494 buffer_vec.emplace_back(buffer_builder.Finish());
497 assert(buffer_index != 0);
499 flatbuffers::Offset<tflite::QuantizationParameters> quant_index;
501 // Create QuantizationParameters if quant is specified
502 if (operand.has_quant())
504 const auto &quant = operand.quant();
506 // Create each parameters
507 // NOTE if some parameters are not given, those will be set to default value
508 std::vector<float> quant_max_vec(quant.max_size());
509 std::vector<float> quant_min_vec(quant.min_size());
510 std::vector<float> quant_scale_vec(quant.scale_size());
511 std::vector<int64_t> quant_zero_point_vec(quant.zero_point_size());
// Copy the repeated recipe fields into the vectors
513 for (uint32_t i = 0; i < quant.max_size(); ++i)
514 quant_max_vec.at(i) = quant.max(i);
515 for (uint32_t i = 0; i < quant.min_size(); ++i)
516 quant_min_vec.at(i) = quant.min(i);
517 for (uint32_t i = 0; i < quant.scale_size(); ++i)
518 quant_scale_vec.at(i) = quant.scale(i);
519 for (uint32_t i = 0; i < quant.zero_point_size(); ++i)
520 quant_zero_point_vec.at(i) = quant.zero_point(i);
// The vectors are created before the table builder below is constructed
522 auto quant_max = flatbuffer_builder->CreateVector(quant_max_vec);
523 auto quant_min = flatbuffer_builder->CreateVector(quant_min_vec);
524 auto quant_scale = flatbuffer_builder->CreateVector(quant_scale_vec);
525 auto quant_zero_point = flatbuffer_builder->CreateVector(quant_zero_point_vec);
527 // Create QuantizationParameters
528 tflite::QuantizationParametersBuilder quant_builder{*flatbuffer_builder};
529 quant_builder.add_max(quant_max);
530 quant_builder.add_min(quant_min);
531 quant_builder.add_scale(quant_scale);
532 quant_builder.add_zero_point(quant_zero_point);
533 quant_builder.add_quantized_dimension(quant.quantized_dimension());
535 // Update QuantizationParameters Index
536 quant_index = quant_builder.Finish();
// Sparsity given explicitly in the recipe
539 if (operand.has_sparsity())
541 const auto &sparsity = operand.sparsity();
543 // Create traversal order
544 std::vector<int> traversal_order_vec{sparsity.traversal_order().dim().begin(),
545 sparsity.traversal_order().dim().end()};
546 auto traversal_order = flatbuffer_builder->CreateVector(traversal_order_vec);
// Create block map
549 std::vector<int> block_map_vec{sparsity.block_map().dim().begin(),
550 sparsity.block_map().dim().end()};
551 auto block_map = flatbuffer_builder->CreateVector(block_map_vec);
553 // Create dimension metadata
554 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> dim_metadata_vec;
555 auto recipe_dim_metadata = sparsity.dim_metadata();
556 for (const auto &dm : recipe_dim_metadata)
558 // Create array segments
559 auto tflite_array_segments =
560 as_tflite_sparse_index_vec(*flatbuffer_builder, dm.array_segments());
562 // Create array indices
563 auto tflite_array_indices =
564 as_tflite_sparse_index_vec(*flatbuffer_builder, dm.array_indices());
566 auto tflite_dim_metadata_builder = tflite::DimensionMetadataBuilder{*flatbuffer_builder};
567 tflite_dim_metadata_builder.add_format(as_tflite_dimensiontype(dm.format()));
568 tflite_dim_metadata_builder.add_dense_size(dm.dense_size());
569 tflite_dim_metadata_builder.add_array_segments(tflite_array_segments);
570 tflite_dim_metadata_builder.add_array_segments_type(
571 as_tflite_sparse_idx_vec_type(dm.array_segments().type()));
572 tflite_dim_metadata_builder.add_array_indices(tflite_array_indices);
573 tflite_dim_metadata_builder.add_array_indices_type(
574 as_tflite_sparse_idx_vec_type(dm.array_indices().type()));
575 auto tflite_dim_metadata = tflite_dim_metadata_builder.Finish();
576 dim_metadata_vec.emplace_back(tflite_dim_metadata);
578 auto dim_metadata = flatbuffer_builder->CreateVector(dim_metadata_vec);
580 sparsity_index = tflite::CreateSparsityParameters(*flatbuffer_builder, traversal_order,
581 block_map, dim_metadata);
584 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
585 if (operand.has_shape_signature())
587 auto signature = as_dims(operand.shape_signature());
588 shape_signature = flatbuffer_builder->CreateVector(signature);
// Assemble the Tensor from the pieces created above
592 tflite::TensorBuilder tensor_builder{*flatbuffer_builder};
594 tensor_builder.add_shape(shape);
595 tensor_builder.add_type(as_tflite_tensortype(operand.type()));
596 tensor_builder.add_buffer(buffer_index);
597 tensor_builder.add_name(name);
598 tensor_builder.add_is_variable(operand.is_variable());
599 if (operand.has_quant())
600 tensor_builder.add_quantization(quant_index);
601 tensor_builder.add_sparsity(sparsity_index);
602 if (operand.has_shape_signature())
603 tensor_builder.add_shape_signature(shape_signature);
606 tensor_vec.emplace_back(tensor_builder.Finish());
608 // Update Tensor Name -> Tensor Index Map
// The index is the number of symbols registered so far
609 int32_t tensor_index = symbol_table.size();
610 const auto &tensor_name = operand.name();
612 INFO(l) << "Symbol [" << tensor_name << "] = Tensor " << tensor_index << std::endl;
614 symbol_table[tensor_name] = tensor_index;
// Create one Operator per operation
618 for (const auto &operation : graph.operation())
620 assert(operation.has_type());
622 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Create 'inputs': tensor names are translated to tensor indices through lookup
625 std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
626 auto inputs = flatbuffer_builder->CreateVector(input_vec);
// Create 'outputs' in the same way
629 std::vector<int32_t> output_vec = as_dataset(operation.output()).map(lookup).vectorize();
630 auto outputs = flatbuffer_builder->CreateVector(output_vec);
// Create builtin options
633 auto options = op_chef->value(*flatbuffer_builder);
635 // Create Custom option
636 auto circle_custom_options = op_chef->custom_value(*flatbuffer_builder);
639 tflite::OperatorBuilder op_builder{*flatbuffer_builder};
641 // Note that opcode_index is an index into the operator_codes vector.
642 // operator_codes consists of builtin_code and custom_code, which is inserted sequentially.
643 uint32_t opcode_index = 0;
644 auto op_it = builtin_code_map.find(op_chef->code());
// Builtin operator: the index is its position in the ordered builtin_code_map
646 if (op_it != builtin_code_map.end())
648 opcode_index = std::distance(builtin_code_map.begin(), op_it);
// Otherwise: the index is the number of builtin codes plus the position in custom_code_vec
653 auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type());
654 assert(op_it != custom_code_vec.end());
655 opcode_index = builtin_code_map.size();
656 opcode_index += std::distance(custom_code_vec.begin(), op_it);
659 op_builder.add_opcode_index(opcode_index);
660 op_builder.add_inputs(inputs);
661 op_builder.add_outputs(outputs);
662 op_builder.add_builtin_options_type(op_chef->type());
663 op_builder.add_builtin_options(options);
664 op_builder.add_custom_options(circle_custom_options);
665 op_builder.add_custom_options_format(tflite::CustomOptionsFormat_FLEXBUFFERS);
667 operator_vec.emplace_back(op_builder.Finish());
670 // Create network input/output vector
671 std::vector<int32_t> input_vec = as_dataset(graph.input()).map(lookup).vectorize();
672 std::vector<int32_t> output_vec = as_dataset(graph.output()).map(lookup).vectorize();
674 // Create "SubGraph" arguments
675 auto tensors = flatbuffer_builder->CreateVector(tensor_vec);
676 auto inputs = flatbuffer_builder->CreateVector(input_vec);
677 auto outputs = flatbuffer_builder->CreateVector(output_vec);
678 auto operators = flatbuffer_builder->CreateVector(operator_vec);
679 auto name = flatbuffer_builder->CreateString(graph_name);
681 tflite::SubGraphBuilder subgraph_builder{*flatbuffer_builder};
683 subgraph_builder.add_tensors(tensors);
684 subgraph_builder.add_inputs(inputs);
685 subgraph_builder.add_outputs(outputs);
686 subgraph_builder.add_operators(operators);
687 subgraph_builder.add_name(name);
// Publish the finished SubGraph to the shared vector
689 subgraph_vec.emplace_back(subgraph_builder.Finish());
700 * @brief Generate a (in-memory) TensorFlow Lite model from a given model recipe
702 GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
704 // Initialize Op Chef Registry
705 #define OP_CHEF(NAME, FACTORY_CLASS) \
706 op_chef_registry().add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
707 #include "OpChef.def"
710 // Initialize Data Chef Registry
711 #define DATA_CHEF(TYPE, NAME, FACTORY_CLASS) \
712 data_chef_registry(::tflchef::TYPE) \
713 .add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
714 #include "DataChef.def"
718 // Create FlatBufferBuilder
720 auto flatbuffer_builder =
721 std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));
724 std::vector<flatbuffers::Offset<::tflite::Buffer>> buffer_vec;
727 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
729 // SignatureDef-related
730 std::vector<flatbuffers::Offset<::tflite::SignatureDef>> signdef_vec;
733 std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
735 // Create OperatorCode with Builtin Operator
736 auto builtin_code_map = gather_builtincode_map(model_recipe);
737 for (auto const &opcode : builtin_code_map)
739 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
740 // 127 is BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES
741 // This is the way to handle deprecated builtin code
743 // https://github.com/tensorflow/tensorflow/blob/a0afe8f9218be5eb9ed5dffc2dff652996da8c28/tensorflow/lite/schema/schema.fbs#L1061-L1077
744 if (opcode.first < 127)
746 code_builder.add_deprecated_builtin_code(opcode.first);
750 code_builder.add_deprecated_builtin_code(
751 ::tflite::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES);
753 code_builder.add_version(opcode.second);
754 code_builder.add_builtin_code(opcode.first);
755 auto code = code_builder.Finish();
756 // Update OperatorCode vector
757 code_vec.emplace_back(code);
760 // Create OperatorCode with Custom Operator
761 std::set<std::string> custom_code_set = gather_customcode_set(model_recipe);
762 std::vector<std::string> custom_code_vec{custom_code_set.begin(), custom_code_set.end()};
764 for (auto opcode : custom_code_vec)
766 auto custom_code = flatbuffer_builder->CreateString(opcode);
767 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
768 code_builder.add_deprecated_builtin_code(tflite::BuiltinOperator_CUSTOM);
769 code_builder.add_custom_code(custom_code);
770 code_builder.add_builtin_code(tflite::BuiltinOperator_CUSTOM);
771 auto code = code_builder.Finish();
772 // Update OperatorCode vector
773 code_vec.emplace_back(code);
776 // Create an Empty Buffer
778 // Buffer 0 SHOULD be an empty buffer in TensorFlow Lite model file
779 // (Please refer to the comment for Tensor.buffer field in schema)
781 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
782 buffer_vec.emplace_back(buffer_builder.Finish());
785 // symbol_tables stores symbol_table of each sub graph
786 // this is used to find tensor ID(index) with tensor name
787 std::vector<std::map<std::string, int32_t>> symbol_tables;
792 CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
793 builtin_code_map, custom_code_vec, "main"};
795 auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
796 symbol_tables.push_back(table);
799 // Create subgraphs if exist
801 for (int g = 0; g < model_recipe.graph_size(); ++g)
803 const auto &graph = model_recipe.graph(g);
805 std::ostringstream stringStream;
806 stringStream << "sub_" << (g + 1);
808 CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
809 builtin_code_map, custom_code_vec, stringStream.str()};
811 auto table = cook_graph<::tflchef::Graph>(graph, cp);
812 symbol_tables.push_back(table);
815 // Create Signature-Def
817 for (int s = 0; s < model_recipe.signature_def_size(); ++s)
820 const auto &rec_signature_def = model_recipe.signature_def(s);
822 std::vector<flatbuffers::Offset<::tflite::TensorMap>> tensormap_inputs;
823 std::vector<flatbuffers::Offset<::tflite::TensorMap>> tensormap_outputs;
825 // which subgraph index to cook
826 auto subgraph_index = 0;
827 if (rec_signature_def.has_subgraph_index())
829 subgraph_index = rec_signature_def.subgraph_index();
831 assert(subgraph_index < symbol_tables.size());
832 auto &symbol_table = symbol_tables[subgraph_index];
835 for (int si = 0; si < rec_signature_def.inputs_size(); ++si)
837 // recipe for input TensorMap
838 auto rec_tm_input = rec_signature_def.inputs(si);
839 auto name = flatbuffer_builder->CreateString(rec_tm_input.name());
840 uint32_t tensor_index = 0;
841 // either tensor or tensor_index should exist
842 assert(rec_tm_input.has_tensor() || rec_tm_input.has_tensor_index());
843 if (rec_tm_input.has_tensor())
845 // we can get tensor_index from symbol_table
846 auto tensor = rec_tm_input.tensor();
847 tensor_index = symbol_table[tensor];
851 // or we can use tensor_index itself
852 tensor_index = rec_tm_input.tensor_index();
855 ::tflite::TensorMapBuilder tensormap_builder{*flatbuffer_builder};
856 tensormap_builder.add_name(name);
857 tensormap_builder.add_tensor_index(tensor_index);
858 tensormap_inputs.push_back(tensormap_builder.Finish());
860 // cook for outputs, same as inputs
861 for (int so = 0; so < rec_signature_def.outputs_size(); ++so)
863 auto rec_tm_output = rec_signature_def.outputs(so);
864 auto name = flatbuffer_builder->CreateString(rec_tm_output.name());
865 uint32_t tensor_index = 0;
866 assert(rec_tm_output.has_tensor() || rec_tm_output.has_tensor_index());
867 if (rec_tm_output.has_tensor())
869 auto tensor = rec_tm_output.tensor();
870 tensor_index = symbol_table[tensor];
874 tensor_index = rec_tm_output.tensor_index();
877 ::tflite::TensorMapBuilder tensormap_builder{*flatbuffer_builder};
878 tensormap_builder.add_name(name);
879 tensormap_builder.add_tensor_index(tensor_index);
880 tensormap_outputs.push_back(tensormap_builder.Finish());
883 auto inputs = flatbuffer_builder->CreateVector(tensormap_inputs);
884 auto outputs = flatbuffer_builder->CreateVector(tensormap_outputs);
885 auto signature_key = flatbuffer_builder->CreateString(rec_signature_def.signature_key());
886 // TODO add validation for signature_key
888 ::tflite::SignatureDefBuilder signature_def_builder{*flatbuffer_builder};
889 signature_def_builder.add_inputs(inputs);
890 signature_def_builder.add_outputs(outputs);
891 signature_def_builder.add_signature_key(signature_key);
892 signature_def_builder.add_subgraph_index(rec_signature_def.subgraph_index());
894 signdef_vec.emplace_back(signature_def_builder.Finish());
897 // Create "Model" arguments
898 auto buffers = flatbuffer_builder->CreateVector(buffer_vec);
899 auto signdefs = flatbuffer_builder->CreateVector(signdef_vec);
900 auto operator_codes = flatbuffer_builder->CreateVector(code_vec);
901 auto subgraphs = flatbuffer_builder->CreateVector(subgraph_vec);
902 auto description = flatbuffer_builder->CreateString("Generated by tflchef");
905 tflite::ModelBuilder model_builder{*flatbuffer_builder};
907 model_builder.add_version(3);
908 model_builder.add_operator_codes(operator_codes);
909 model_builder.add_signature_defs(signdefs);
910 model_builder.add_subgraphs(subgraphs);
911 model_builder.add_description(description);
912 model_builder.add_buffers(buffers);
914 auto model = model_builder.Finish();
917 ::tflite::FinishModelBuffer(*flatbuffer_builder, model);
919 // Return "GenerateModel"
920 return GeneratedModel{
921 std::unique_ptr<GeneratedModelImpl>(new GeneratedModelImpl(std::move(flatbuffer_builder)))};
924 } // namespace tflchef