2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "tflchef/ModelChef.h"
18 #include <souschef/RangedArguments.h>
19 #include <souschef/Registry.h>
23 #include <souschef/DataChefs.h>
28 #include <souschef/Dataset.h>
29 #include <souschef/Dims.h>
45 using namespace souschef;
// Concrete tflchef::GeneratedModel::Impl that owns the FlatBufferBuilder holding a
// generated model and exposes that builder's buffer pointer and size.
// NOTE(review): braces and access specifiers are not visible in this excerpt.
50 class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl
// Takes ownership of the builder by moving it into _builder.
53 GeneratedModelImpl(std::unique_ptr<flatbuffers::FlatBufferBuilder> &&builder)
54 : _builder{std::move(builder)}
// Returns _builder->GetBufferPointer() reinterpreted as a const char pointer.
60 const char *base(void) const override
62 // Return the base address of generated flatbuffer model
63 return reinterpret_cast<const char *>(_builder->GetBufferPointer());
// Returns _builder->GetSize().
67 size_t size(void) const override
69 // Return the size of generated flatbuffer model
70 return _builder->GetSize();
// Builder received by the constructor; base() and size() read from it.
74 std::unique_ptr<flatbuffers::FlatBufferBuilder> _builder;
// Registry of DataChefFactory objects. data_chef_registry() keeps separate instances of
// this type, and entries are added and looked up by name (see cook() and cook_graph()).
// NOTE(review): the struct body is not visible in this excerpt.
82 struct DataChefRegistry final : public Registry<DataChefFactory>
// Returns a reference to a DataChefRegistry chosen by tensor type.
//
// Five registries are kept as function-local statics, so each is created once and
// shared by every caller.
// NOTE(review): only the FLOAT32 case label is visible in this excerpt; the switch
// header, the other case labels and the return statements are not shown. Presumably
// each type maps to the same-named registry (fp32 for FLOAT32) - confirm.
86 DataChefRegistry &data_chef_registry(const tflchef::TensorType &type)
88 static DataChefRegistry s32;
89 static DataChefRegistry s64;
90 static DataChefRegistry fp32;
91 static DataChefRegistry u8;
92 static DataChefRegistry boolean;
100 case tflchef::FLOAT32:
// Throws std::runtime_error; presumably reached only when no case returned.
110 throw std::runtime_error{"Unknown tensor type"};
// Registry of OpChefFactory objects; entries are added by name in cook() and looked up
// by operation type wherever an op chef is created.
// NOTE(review): the struct body is not visible in this excerpt.
113 struct OpChefRegistry final : public Registry<OpChefFactory>
// Accessor for the OpChefRegistry instance held as a function-local static.
// NOTE(review): the return statement is not visible in this excerpt; presumably it
// returns `registry` - confirm.
117 OpChefRegistry &op_chef_registry(void)
119 static OpChefRegistry registry;
123 /// @brief This will prepare a map of unique builtin codes in the model recipe
///
/// Scans the operations of the top-level recipe and of every graph in the recipe,
/// creates the op chef registered for each operation type, and records for each
/// operator code the highest operation.version() seen.
///
/// @param model_recipe Recipe whose main and sub-graph operations are scanned
/// @note The declared return type is a map from BuiltinOperator to version; the return
///       statement itself is not visible in this excerpt.
124 std::map<tflite::BuiltinOperator, int32_t>
125 gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
127 // Key and value of the map are BuiltinOperator and operator version
128 std::map<tflite::BuiltinOperator, int32_t> builtin_map;
// Pass 1: operations listed directly in the model recipe
130 for (const auto &operation : model_recipe.operation())
132 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// NOTE(review): the statement guarded by this check is not visible in this excerpt;
// presumably custom operators are skipped here - confirm.
133 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
136 // Various operation version is unified as the highest version among them
// Store the version when the code is new or this operation's version is higher
137 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
138 builtin_map[op_chef->code()] < operation.version())
139 builtin_map[op_chef->code()] = operation.version();
// Pass 2: the same bookkeeping for operations of each graph in the recipe
142 // Add ops used in Graphs(subgraphs)
143 for (int g = 0; g < model_recipe.graph_size(); ++g)
145 const auto &graph = model_recipe.graph(g);
146 for (const auto &operation : graph.operation())
148 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// NOTE(review): guarded statement not visible, as in the first pass.
149 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
152 // Various operation version is unified as the highest version among them
153 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
154 builtin_map[op_chef->code()] < operation.version())
155 builtin_map[op_chef->code()] = operation.version();
162 /// @brief This will prepare a set of unique custom codes in the mode recipe
163 std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_recipe)
165 std::set<std::string> customcode_set;
166 for (const auto &operation : model_recipe.operation())
168 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
169 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
170 customcode_set.insert(operation.type());
173 // Add ops used in Graphs(subgraphs)
174 for (int g = 0; g < model_recipe.graph_size(); ++g)
176 const auto &graph = model_recipe.graph(g);
177 for (const auto &operation : graph.operation())
179 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
180 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
181 customcode_set.insert(operation.type());
185 return customcode_set;
// Reference members bundling the state shared across cook_graph() calls; cook()
// brace-initializes a CookParams from its own locals for each graph it cooks.
// NOTE(review): the struct header is not visible in this excerpt, and cook() passes a
// sixth initializer (a name string read by cook_graph() as cp.noname) whose member
// declaration is also not shown.
// Buffer offsets of the whole model; cook_graph() appends to it
195 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec;
// Operator code offsets created in cook()
196 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec;
// cook_graph() appends one SubGraph offset per call
197 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec;
// Builder that every flatbuffer object of the model is created with
198 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder;
// Operator code -> version map; cook_graph() derives opcode indices from it
199 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map;
// Cooks one graph into a tflite SubGraph.
//
// T is ::tflchef::ModelRecipe for the main graph and ::tflchef::Graph for a sub-graph
// (see the two call sites in cook()). For the given graph this creates buffers, tensors
// and operators with cp.flatbuffer_builder, appends the new buffers to cp.buffer_vec,
// and finally appends one SubGraph offset to cp.subgraph_vec.
//
// NOTE(review): this excerpt omits several lines of the function (braces, the logger
// declaration used by INFO(l), the declaration of `idx`, and some branch conditions and
// loop exits), so the comments below describe only what the visible lines do.
203 template <typename T> void cook_graph(const T &graph, CookParams &cp)
// Local aliases for the shared state carried by cp
207 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = cp.buffer_vec;
// code_vec is bound here but not otherwise used in the visible lines of this function
208 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = cp.code_vec;
209 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = cp.subgraph_vec;
210 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = cp.flatbuffer_builder;
211 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = cp.builtin_code_map;
// Tensor offsets of this graph only
214 std::vector<flatbuffers::Offset<::tflite::Tensor>> tensor_vec;
// Operator offsets of this graph only
217 std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;
219 // default name for graph
// cp.noname is the fallback; a name set on the graph itself takes precedence
220 std::string graph_name = cp.noname;
221 if (graph.has_name())
222 graph_name = graph.name();
224 // Tensor Name -> Tensor ID mapping (per Graph)
225 std::map<std::string, int32_t> symbol_table;
// Maps a tensor name to its index in this graph. A known name yields its index; the
// condition guarding the `return -1` branch is not visible in this excerpt; otherwise
// a std::runtime_error naming the graph is thrown.
227 auto lookup = [&symbol_table, &graph_name](const std::string &name) {
228 if (symbol_table.find(name) != symbol_table.end())
229 return symbol_table.at(name);
231 return -1; // -1 in TFLite means that optional input tensor is empty.
234 std::string msg = "tflchef : input not found in " + graph_name + " graph";
235 throw std::runtime_error(msg.c_str());
// Index in buffer_vec of the first buffer created for this graph; the input and output
// buffer indices below are computed relative to it
239 int32_t buffer_start = buffer_vec.size();
240 int32_t buffer_index = 0;
242 // Create buffer(s) 1~n(I) for input(s)
// One empty buffer per graph input
243 const auto size_input = graph.input_size();
244 for (int ci = 0; ci < size_input; ++ci)
246 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
247 buffer_vec.emplace_back(buffer_builder.Finish());
249 // Create buffer(s) n(I)+1~n(I)+n(O) for output(s)
// One empty buffer per graph output, placed right after the input buffers
250 const auto size_output = graph.output_size();
251 for (int co = 0; co < size_output; ++co)
253 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
254 buffer_vec.emplace_back(buffer_builder.Finish());
// Input/output tensor names, compared against operand names below
257 auto input_names = as_dataset(graph.input()).vectorize();
258 auto output_names = as_dataset(graph.output()).vectorize();
// Create one tflite Tensor per operand
260 for (const auto &operand : graph.operand())
262 assert(operand.has_name());
264 assert(operand.has_type());
// shape stays a default-constructed offset and dims stays empty when no shape is given
266 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
267 std::vector<int32_t> dims;
268 if (operand.has_shape())
270 dims = as_dims(operand.shape());
271 shape = flatbuffer_builder->CreateVector(dims);
274 auto name = flatbuffer_builder->CreateString(operand.name());
278 // Create Buffer if filler is specified
279 if (operand.has_filler())
281 const auto &filler = operand.filler();
283 assert(filler.has_tag());
// Pick the data chef registered under the filler tag for this operand's tensor type
285 auto args = ranged_arguments(filler.arg().begin(), filler.arg().end());
286 auto chef = data_chef_registry(operand.type()).lookup(filler.tag()).create(args);
288 assert(chef != nullptr);
// Use the element count of the shape when it is positive, else the number of filler
// arguments
291 int32_t count = (element_count(dims) > 0) ? element_count(dims) : filler.arg_size();
292 auto data_vec = chef->generate(count);
293 auto data = flatbuffer_builder->CreateVector(data_vec);
296 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
297 buffer_builder.add_data(data);
298 auto buffer = buffer_builder.Finish();
300 // Update Buffer Index & Vector
301 buffer_index = buffer_vec.size();
302 buffer_vec.emplace_back(buffer);
// NOTE(review): the branch structure around the lines below and the declaration/reset
// of `idx` are not visible in this excerpt; presumably this is the no-filler path.
306 // if this is input or output, assign to that buffer_index
// A name match among the graph inputs selects the buffer created for that input
308 for (auto it = input_names.begin(); it != input_names.end(); ++it, ++idx)
310 if (*it == operand.name())
312 buffer_index = buffer_start + idx;
// Outputs are searched only while buffer_index is still 0
316 if (buffer_index == 0)
319 for (auto it = output_names.begin(); it != output_names.end(); ++it, ++idx)
321 if (*it == operand.name())
// Output buffers follow the size_input input buffers
323 buffer_index = buffer_start + size_input + idx;
328 if (buffer_index == 0)
330 // we couldn't find the buffer; create an empty buffer for this tensor
331 buffer_index = buffer_vec.size();
333 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
334 buffer_vec.emplace_back(buffer_builder.Finish());
337 assert(buffer_index != 0);
// Stays a default-constructed offset unless the operand specifies quantization
339 flatbuffers::Offset<tflite::QuantizationParameters> quant_index;
341 // Create QuantizationParameters if quant is specified
342 if (operand.has_quant())
344 const auto &quant = operand.quant();
346 // Create each parameters
347 // NOTE if some parameters are not given, those will be set to default value
348 std::vector<float> quant_max_vec(quant.max_size());
349 std::vector<float> quant_min_vec(quant.min_size());
350 std::vector<float> quant_scale_vec(quant.scale_size());
351 std::vector<int64_t> quant_zero_point_vec(quant.zero_point_size());
// Copy each repeated field of the recipe into the vector of matching size
353 for (uint32_t i = 0; i < quant.max_size(); ++i)
354 quant_max_vec.at(i) = quant.max(i);
355 for (uint32_t i = 0; i < quant.min_size(); ++i)
356 quant_min_vec.at(i) = quant.min(i);
357 for (uint32_t i = 0; i < quant.scale_size(); ++i)
358 quant_scale_vec.at(i) = quant.scale(i);
359 for (uint32_t i = 0; i < quant.zero_point_size(); ++i)
360 quant_zero_point_vec.at(i) = quant.zero_point(i);
// The flatbuffer vectors are created before quant_builder is constructed
362 auto quant_max = flatbuffer_builder->CreateVector(quant_max_vec);
363 auto quant_min = flatbuffer_builder->CreateVector(quant_min_vec);
364 auto quant_scale = flatbuffer_builder->CreateVector(quant_scale_vec);
365 auto quant_zero_point = flatbuffer_builder->CreateVector(quant_zero_point_vec);
367 // Create QuantizationParameters
368 tflite::QuantizationParametersBuilder quant_builder{*flatbuffer_builder};
369 quant_builder.add_max(quant_max);
370 quant_builder.add_min(quant_min);
371 quant_builder.add_scale(quant_scale);
372 quant_builder.add_zero_point(quant_zero_point);
373 quant_builder.add_quantized_dimension(quant.quantized_dimension());
375 // Update QuantizationParameters Index
376 quant_index = quant_builder.Finish();
// Assemble the Tensor from the pieces prepared above
380 tflite::TensorBuilder tensor_builder{*flatbuffer_builder};
382 tensor_builder.add_shape(shape);
383 tensor_builder.add_type(as_tflite_tensortype(operand.type()));
384 tensor_builder.add_buffer(buffer_index);
385 tensor_builder.add_name(name);
386 if (operand.has_quant())
387 tensor_builder.add_quantization(quant_index);
390 tensor_vec.emplace_back(tensor_builder.Finish());
392 // Update Tensor Name -> Tensor Index Map
// The index is the current number of symbol_table entries
393 int32_t tensor_index = symbol_table.size();
394 const auto &tensor_name = operand.name();
396 INFO(l) << "Symbol [" << tensor_name << "] = Tensor " << tensor_index << std::endl;
398 symbol_table[tensor_name] = tensor_index;
// Create one tflite Operator per operation
402 for (const auto &operation : graph.operation())
404 assert(operation.has_type());
406 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Resolve input and output tensor names to tensor indices through lookup
409 std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
410 auto inputs = flatbuffer_builder->CreateVector(input_vec);
413 std::vector<int32_t> output_vec = as_dataset(operation.output()).map(lookup).vectorize();
414 auto outputs = flatbuffer_builder->CreateVector(output_vec);
// Builtin options produced by the op chef
417 auto options = op_chef->value(*flatbuffer_builder);
419 // Create Custom option
420 auto circle_custom_options = op_chef->custom_value(*flatbuffer_builder);
423 tflite::OperatorBuilder op_builder{*flatbuffer_builder};
425 // Get operator code index from builtin_code_map with assumption, order of
426 // builtin_code_map is same as that of code_vec
// NOTE(review): the index is the position of op_chef->code() among the map keys, so
// every custom operator resolves to the position of BuiltinOperator_CUSTOM whatever its
// custom type name is, while cook() appends one OperatorCode per custom name after all
// builtin codes. Confirm the assumption above holds for recipes with custom operators.
427 auto op_it = builtin_code_map.find(op_chef->code());
428 assert(op_it != builtin_code_map.end());
429 uint32_t opcode_index = std::distance(builtin_code_map.begin(), op_it);
431 op_builder.add_opcode_index(opcode_index);
432 op_builder.add_inputs(inputs);
433 op_builder.add_outputs(outputs);
434 op_builder.add_builtin_options_type(op_chef->type());
435 op_builder.add_builtin_options(options);
436 op_builder.add_custom_options(circle_custom_options);
437 op_builder.add_custom_options_format(tflite::CustomOptionsFormat_FLEXBUFFERS);
439 operator_vec.emplace_back(op_builder.Finish());
442 // Create network input/output vector
// Graph-level inputs and outputs, resolved to tensor indices through lookup
443 std::vector<int32_t> input_vec = as_dataset(graph.input()).map(lookup).vectorize();
444 std::vector<int32_t> output_vec = as_dataset(graph.output()).map(lookup).vectorize();
446 // Create "SubGraph" arguments
447 auto tensors = flatbuffer_builder->CreateVector(tensor_vec);
448 auto inputs = flatbuffer_builder->CreateVector(input_vec);
449 auto outputs = flatbuffer_builder->CreateVector(output_vec);
450 auto operators = flatbuffer_builder->CreateVector(operator_vec);
451 auto name = flatbuffer_builder->CreateString(graph_name);
453 tflite::SubGraphBuilder subgraph_builder{*flatbuffer_builder};
455 subgraph_builder.add_tensors(tensors);
456 subgraph_builder.add_inputs(inputs);
457 subgraph_builder.add_outputs(outputs);
458 subgraph_builder.add_operators(operators);
459 subgraph_builder.add_name(name);
// Publish the finished SubGraph to the caller through cp.subgraph_vec
461 subgraph_vec.emplace_back(subgraph_builder.Finish());
470 * @brief Generate a (in-memory) TensorFlow Lite model from a given model recipe
// Builds a tflite flatbuffer model from model_recipe and returns it wrapped in a
// GeneratedModel.
//
// Steps visible here: register op chefs and data chefs, create operator codes for
// builtin and custom operators, create buffer 0, cook the main graph and every graph in
// model_recipe.graph(), then assemble and finish the Model.
// NOTE(review): braces and the lines following each macro include are not visible in
// this excerpt.
472 GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
474 // Initialize Op Chef Registry
// Each OP_CHEF entry of OpChef.def adds a newly created factory under its name. This
// runs on every call of cook(); how Registry::add treats a repeated name is not visible
// here - confirm.
475 #define OP_CHEF(NAME, FACTORY_CLASS) \
476 op_chef_registry().add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
477 #include "OpChef.def"
480 // Initialize Data Chef Registry
// Each DATA_CHEF entry adds a factory to the registry selected by its tensor type
481 #define DATA_CHEF(TYPE, NAME, FACTORY_CLASS) \
482 data_chef_registry(::tflchef::TYPE) \
483 .add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
484 #include <souschef/DataChef.def>
488 // Create FlatBufferBuilder
// The builder is constructed with the argument 1024
490 auto flatbuffer_builder =
491 std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));
// Model-wide vectors filled below and by cook_graph()
494 std::vector<flatbuffers::Offset<::tflite::Buffer>> buffer_vec;
497 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
500 std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
502 // Create OperatorCode with Builtin Operator
// One OperatorCode per map entry, carrying the code and its version
503 auto builtin_code_map = gather_builtincode_map(model_recipe);
504 for (auto const &opcode : builtin_code_map)
506 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
507 code_builder.add_builtin_code(opcode.first);
508 code_builder.add_version(opcode.second);
509 auto code = code_builder.Finish();
510 // Update OperatorCode vector
511 code_vec.emplace_back(code);
514 // Create OperatorCode with Custom Operator
515 std::set<std::string> custom_code_set = gather_customcode_set(model_recipe);
// When any custom operator exists, make sure BuiltinOperator_CUSTOM is a key of the map
// (with value 1) so that the opcode lookup in cook_graph() can find it
516 if (custom_code_set.size() &&
517 builtin_code_map.find(tflite::BuiltinOperator_CUSTOM) == builtin_code_map.end())
518 builtin_code_map[tflite::BuiltinOperator_CUSTOM] = 1;
// One OperatorCode per distinct custom name, appended after the builtin codes
520 for (auto opcode : custom_code_set)
522 auto custom_code = flatbuffer_builder->CreateString(opcode);
523 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
524 code_builder.add_builtin_code(tflite::BuiltinOperator_CUSTOM);
525 code_builder.add_custom_code(custom_code);
526 auto code = code_builder.Finish();
527 // Update OperatorCode vector
528 code_vec.emplace_back(code);
531 // Create an Empty Buffer
533 // Buffer 0 SHOULD be an empty buffer in TensorFlow Lite model file
534 // (Please refer to the comment for Tensor.buffer field in schema)
536 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
537 buffer_vec.emplace_back(buffer_builder.Finish());
// Cook the main graph; "main" is the name used when the recipe does not set one
543 CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder, builtin_code_map, "main"};
545 cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
548 // Create subgraphs if exist
550 for (int g = 0; g < model_recipe.graph_size(); ++g)
552 const auto &graph = model_recipe.graph(g);
// Fallback name "sub_<g + 1>" for a graph that does not set its own name
554 std::ostringstream stringStream;
555 stringStream << "sub_" << (g + 1);
557 CookParams cp{buffer_vec, code_vec, subgraph_vec,
558 flatbuffer_builder, builtin_code_map, stringStream.str()};
560 cook_graph<::tflchef::Graph>(graph, cp);
563 // Create "Model" arguments
564 auto buffers = flatbuffer_builder->CreateVector(buffer_vec);
565 auto operator_codes = flatbuffer_builder->CreateVector(code_vec);
566 auto subgraphs = flatbuffer_builder->CreateVector(subgraph_vec);
567 auto description = flatbuffer_builder->CreateString("Generated by tflchef");
// The strings and vectors above are created before model_builder is constructed
570 tflite::ModelBuilder model_builder{*flatbuffer_builder};
572 model_builder.add_version(3);
573 model_builder.add_operator_codes(operator_codes);
574 model_builder.add_subgraphs(subgraphs);
575 model_builder.add_description(description);
576 model_builder.add_buffers(buffers);
578 auto model = model_builder.Finish();
581 ::tflite::FinishModelBuffer(*flatbuffer_builder, model);
583 // Return "GeneratedModel"
// Ownership of the builder moves into the GeneratedModelImpl wrapped by the result
584 return GeneratedModel{
585 std::unique_ptr<GeneratedModelImpl>(new GeneratedModelImpl(std::move(flatbuffer_builder)))};
588 } // namespace tflchef