2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "tflchef/ModelChef.h"
18 #include <souschef/RangedArguments.h>
19 #include <souschef/Registry.h>
23 #include <souschef/DataChefs.h>
28 #include <souschef/Dataset.h>
47 using namespace souschef;
49 template <typename T> std::vector<T> as_vector(const ::google::protobuf::RepeatedPtrField<T> &field)
52 for (const auto &elem : field)
54 res.emplace_back(elem);
59 template <typename T> Dataset<T> as_dataset(const ::google::protobuf::RepeatedPtrField<T> &field)
61 return Dataset<T>(as_vector<T>(field));
69 template <typename T> using Dims = std::vector<T>;
71 Dims<int32_t> as_dims(const tflchef::TensorShape &shape)
73 std::vector<int32_t> res;
75 for (auto &dim : shape.dim())
77 res.emplace_back(static_cast<int32_t>(dim));
83 int32_t element_count(const Dims<int32_t> &dims)
85 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int32_t>());
93 class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl
96 GeneratedModelImpl(std::unique_ptr<flatbuffers::FlatBufferBuilder> &&builder)
97 : _builder{std::move(builder)}
103 const char *base(void) const override
105 // Return the base address of generated flatbuffer model
106 return reinterpret_cast<const char *>(_builder->GetBufferPointer());
110 size_t size(void) const override
112 // Return the size of generated flatbuffer model
113 return _builder->GetSize();
117 std::unique_ptr<flatbuffers::FlatBufferBuilder> _builder;
// Registry of DataChefFactory objects; cook_graph() looks one up by the operand filler's tag.
125 struct DataChefRegistry final : public Registry<DataChefFactory>
// Return the DataChefRegistry associated with a recipe tensor type.
//
// Each registry is a function-local static, so whatever cook() registers into it stays
// registered for the rest of the process.
// NOTE(review): only the FLOAT32 case of the dispatch on `type` is visible in this view. The
//               pairing of the other types with s32/s64/u8/boolean is assumed from the
//               variable names -- confirm against the full switch.
129 DataChefRegistry &data_chef_registry(const tflchef::TensorType &type)
131 static DataChefRegistry s32;
132 static DataChefRegistry s64;
133 static DataChefRegistry fp32;
134 static DataChefRegistry u8;
135 static DataChefRegistry boolean;
143 case tflchef::FLOAT32:
// Presumably reached when the dispatch finds no registry for `type`.
// The same message is thrown for every unsupported type.
153 throw std::runtime_error{"Unknown tensor type"};
156 struct OpChefRegistry final : public Registry<OpChefFactory>
160 OpChefRegistry &op_chef_registry(void)
162 static OpChefRegistry registry;
166 /// @brief This will prepare a map of unique builtin codes in the model recipe
167 std::map<tflite::BuiltinOperator, int32_t>
168 gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
170 // Key and value of the map are BuiltinOperator and operator version
171 std::map<tflite::BuiltinOperator, int32_t> builtin_map;
173 for (const auto &operation : model_recipe.operation())
175 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
176 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
179 // Various operation version is unified as the highest version among them
180 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
181 builtin_map[op_chef->code()] < operation.version())
182 builtin_map[op_chef->code()] = operation.version();
185 // Add ops used in Graphs(subgraphs)
186 for (int g = 0; g < model_recipe.graph_size(); ++g)
188 const auto &graph = model_recipe.graph(g);
189 for (const auto &operation : graph.operation())
191 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
192 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
195 // Various operation version is unified as the highest version among them
196 if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
197 builtin_map[op_chef->code()] < operation.version())
198 builtin_map[op_chef->code()] = operation.version();
205 /// @brief This will prepare a set of unique custom codes in the mode recipe
206 std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_recipe)
208 std::set<std::string> customcode_set;
209 for (const auto &operation : model_recipe.operation())
211 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
212 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
213 customcode_set.insert(operation.type());
216 // Add ops used in Graphs(subgraphs)
217 for (int g = 0; g < model_recipe.graph_size(); ++g)
219 const auto &graph = model_recipe.graph(g);
220 for (const auto &operation : graph.operation())
222 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
223 if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
224 customcode_set.insert(operation.type());
228 return customcode_set;
// CookParams members (the struct header is not visible in this view).
// All members are references, so CookParams only borrows containers and the builder owned
// by cook(). Those must outlive every CookParams built from them.
// cook_graph() appends Buffers to buffer_vec and one SubGraph to subgraph_vec, and reads
// builtin_code_map to compute opcode indices.
// NOTE(review): cook_graph() also reads cp.noname (default graph name), and cook() initializes
//               CookParams with six values, so at least one member is not visible here.
238 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec;
239 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec;
240 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec;
241 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder;
242 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map;
// Serialize one graph of the recipe into a tflite SubGraph and append it to cp.subgraph_vec.
//
// cook() instantiates this for ::tflchef::ModelRecipe (the main graph, named "main") and for
// ::tflchef::Graph (each subgraph, named "sub_<n>"). T must provide the name/input/output/
// operand/operation accessors used below.
// Side effects: appends Buffers to cp.buffer_vec and adds strings/vectors/tables to
// cp.flatbuffer_builder. The Buffers are one per graph input and output, one per filled
// operand, and an empty one for an operand that got no other buffer.
// Tensor names that `lookup` cannot resolve raise std::runtime_error, except the case that
// returns -1, whose guard is not visible in this view.
246 template <typename T> void cook_graph(const T &graph, CookParams &cp)
250 std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = cp.buffer_vec;
251 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = cp.code_vec;
252 std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = cp.subgraph_vec;
253 std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = cp.flatbuffer_builder;
254 std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = cp.builtin_code_map;
257 std::vector<flatbuffers::Offset<::tflite::Tensor>> tensor_vec;
260 std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;
262 // default name for graph
263 std::string graph_name = cp.noname;
264 if (graph.has_name())
265 graph_name = graph.name();
267 // Tensor Name -> Tensor ID mapping (per Graph)
268 std::map<std::string, int32_t> symbol_table;
// lookup: tensor name -> index of that tensor in this graph's tensor_vec.
// Returns -1 for an "empty" name (per the inline comment below; the exact guard is hidden --
// presumably name == ""). Otherwise an unknown name throws std::runtime_error.
270 auto lookup = [&symbol_table, &graph_name](const std::string &name) {
271 if (symbol_table.find(name) != symbol_table.end())
272 return symbol_table.at(name);
274 return -1; // -1 in TFLite means that optional input tensor is empty.
277 std::string msg = "tflchef : input not found in " + graph_name + " graph";
278 throw std::runtime_error(msg.c_str());
// Buffers for graph inputs/outputs are reserved up front:
//   inputs  -> [buffer_start, buffer_start + size_input)
//   outputs -> [buffer_start + size_input, buffer_start + size_input + size_output)
// buffer_index == 0 is used below as "no buffer assigned yet" (Buffer 0 is the shared empty one).
282 int32_t buffer_start = buffer_vec.size();
283 int32_t buffer_index = 0;
285 // Create buffer(s) 1~n(I) for input(s)
286 const auto size_input = graph.input_size();
287 for (int ci = 0; ci < size_input; ++ci)
289 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
290 buffer_vec.emplace_back(buffer_builder.Finish());
292 // Create buffer(s) n(I)+1~n(I)+n(O) for output(s)
293 const auto size_output = graph.output_size();
294 for (int co = 0; co < size_output; ++co)
296 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
297 buffer_vec.emplace_back(buffer_builder.Finish());
300 auto input_names = as_dataset(graph.input()).vectorize();
301 auto output_names = as_dataset(graph.output()).vectorize();
// One Tensor per operand, in recipe order.
303 for (const auto &operand : graph.operand())
305 assert(operand.has_name());
307 assert(operand.has_type());
309 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
310 std::vector<int32_t> dims;
311 if (operand.has_shape())
313 dims = as_dims(operand.shape());
314 shape = flatbuffer_builder->CreateVector(dims);
317 auto name = flatbuffer_builder->CreateString(operand.name());
// NOTE(review): buffer_index is declared before this loop. The `== 0` checks below only work if
//               it is reset for every operand, which is not visible in this view -- confirm.
321 // Create Buffer if filler is specified
322 if (operand.has_filler())
324 const auto &filler = operand.filler();
326 assert(filler.has_tag());
328 auto args = ranged_arguments(filler.arg().begin(), filler.arg().end());
329 auto chef = data_chef_registry(operand.type()).lookup(filler.tag()).create(args);
331 assert(chef != nullptr);
// NOTE(review): element_count() of an empty dims list is 1 (accumulate starting at 1). So an
//               operand without a shape generates exactly one element, and the arg_size()
//               fallback only applies when the dimension product is <= 0.
//               element_count() is also evaluated twice here.
334 int32_t count = (element_count(dims) > 0) ? element_count(dims) : filler.arg_size();
335 auto data_vec = chef->generate(count);
336 auto data = flatbuffer_builder->CreateVector(data_vec);
339 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
340 buffer_builder.add_data(data);
341 auto buffer = buffer_builder.Finish();
343 // Update Buffer Index & Vector
344 buffer_index = buffer_vec.size();
345 buffer_vec.emplace_back(buffer);
// An operand named in graph.input()/graph.output() reuses the matching reserved buffer
// (`idx` is declared/reset on lines not visible here).
349 // if this is input or output, assign to that buffer_index
351 for (auto it = input_names.begin(); it != input_names.end(); ++it, ++idx)
353 if (*it == operand.name())
355 buffer_index = buffer_start + idx;
359 if (buffer_index == 0)
362 for (auto it = output_names.begin(); it != output_names.end(); ++it, ++idx)
364 if (*it == operand.name())
366 buffer_index = buffer_start + size_input + idx;
371 if (buffer_index == 0)
373 // we couldn't find the buffer; create an empty buffer for this tensor
374 buffer_index = buffer_vec.size();
376 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
377 buffer_vec.emplace_back(buffer_builder.Finish());
380 assert(buffer_index != 0);
382 flatbuffers::Offset<tflite::QuantizationParameters> quant_index;
384 // Create QuantizationParameters if quant is specified
385 if (operand.has_quant())
387 const auto &quant = operand.quant();
389 // Create each parameters
390 // NOTE if some parameters are not given, those will be set to default value
391 std::vector<float> quant_max_vec(quant.max_size());
392 std::vector<float> quant_min_vec(quant.min_size());
393 std::vector<float> quant_scale_vec(quant.scale_size());
394 std::vector<int64_t> quant_zero_point_vec(quant.zero_point_size());
// NOTE(review): these loops use uint32_t indices, while the other size loops in this function
//               use int. If *_size() returns a signed int (as protobuf accessors normally do),
//               this is a signed/unsigned comparison -- consider int.
396 for (uint32_t i = 0; i < quant.max_size(); ++i)
397 quant_max_vec.at(i) = quant.max(i);
398 for (uint32_t i = 0; i < quant.min_size(); ++i)
399 quant_min_vec.at(i) = quant.min(i);
400 for (uint32_t i = 0; i < quant.scale_size(); ++i)
401 quant_scale_vec.at(i) = quant.scale(i);
402 for (uint32_t i = 0; i < quant.zero_point_size(); ++i)
403 quant_zero_point_vec.at(i) = quant.zero_point(i);
405 auto quant_max = flatbuffer_builder->CreateVector(quant_max_vec);
406 auto quant_min = flatbuffer_builder->CreateVector(quant_min_vec);
407 auto quant_scale = flatbuffer_builder->CreateVector(quant_scale_vec);
408 auto quant_zero_point = flatbuffer_builder->CreateVector(quant_zero_point_vec);
410 // Create QuantizationParameters
411 tflite::QuantizationParametersBuilder quant_builder{*flatbuffer_builder};
412 quant_builder.add_max(quant_max);
413 quant_builder.add_min(quant_min);
414 quant_builder.add_scale(quant_scale);
415 quant_builder.add_zero_point(quant_zero_point);
417 // Update QuantizationParameters Index
418 quant_index = quant_builder.Finish();
422 tflite::TensorBuilder tensor_builder{*flatbuffer_builder};
424 tensor_builder.add_shape(shape);
425 tensor_builder.add_type(as_tflite_tensortype(operand.type()));
426 tensor_builder.add_buffer(buffer_index);
427 tensor_builder.add_name(name);
428 if (operand.has_quant())
429 tensor_builder.add_quantization(quant_index);
432 tensor_vec.emplace_back(tensor_builder.Finish());
// The new tensor's index is the number of symbols registered so far, which equals its position
// in tensor_vec only while operand names are unique.
// NOTE(review): a duplicated operand name overwrites its symbol without growing symbol_table,
//               so later tensors would get indices that no longer match tensor_vec.
434 // Update Tensor Name -> Tensor Index Map
435 int32_t tensor_index = symbol_table.size();
436 const auto &tensor_name = operand.name();
// `l` is presumably a logger declared on a line not visible here.
438 INFO(l) << "Symbol [" << tensor_name << "] = Tensor " << tensor_index << std::endl;
440 symbol_table[tensor_name] = tensor_index;
// One Operator per recipe operation; inputs/outputs are resolved through `lookup`, so every
// operand must be declared before it is referenced.
444 for (const auto &operation : graph.operation())
446 assert(operation.has_type());
448 auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
451 std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
452 auto inputs = flatbuffer_builder->CreateVector(input_vec);
455 std::vector<int32_t> output_vec = as_dataset(operation.output()).map(lookup).vectorize();
456 auto outputs = flatbuffer_builder->CreateVector(output_vec);
459 auto options = op_chef->value(*flatbuffer_builder);
// Despite the `circle_` prefix, this is the TFLite Operator.custom_options payload.
461 // Create Custom option
462 auto circle_custom_options = op_chef->custom_value(*flatbuffer_builder);
465 tflite::OperatorBuilder op_builder{*flatbuffer_builder};
// NOTE(review): opcode_index is the position of op_chef->code() inside builtin_code_map (a map,
//               despite the "builtin_code_set" wording below). This is only correct if
//               cp.code_vec was emitted in exactly the map's order with one OperatorCode per key.
//               All custom operators resolve to the single BuiltinOperator_CUSTOM entry, so with
//               more than one distinct custom code they cannot all receive the right index.
467 // Get operator code index from builtin_code_set with assumption, order of
468 // builtin_code_set is same as that of code_vec
469 auto op_it = builtin_code_map.find(op_chef->code());
470 assert(op_it != builtin_code_map.end());
471 uint32_t opcode_index = std::distance(builtin_code_map.begin(), op_it);
473 op_builder.add_opcode_index(opcode_index);
474 op_builder.add_inputs(inputs);
475 op_builder.add_outputs(outputs);
476 op_builder.add_builtin_options_type(op_chef->type());
477 op_builder.add_builtin_options(options);
478 op_builder.add_custom_options(circle_custom_options);
479 op_builder.add_custom_options_format(tflite::CustomOptionsFormat_FLEXBUFFERS);
481 operator_vec.emplace_back(op_builder.Finish());
484 // Create network input/output vector
485 std::vector<int32_t> input_vec = as_dataset(graph.input()).map(lookup).vectorize();
486 std::vector<int32_t> output_vec = as_dataset(graph.output()).map(lookup).vectorize();
488 // Create "SubGraph" arguments
489 auto tensors = flatbuffer_builder->CreateVector(tensor_vec);
490 auto inputs = flatbuffer_builder->CreateVector(input_vec);
491 auto outputs = flatbuffer_builder->CreateVector(output_vec);
492 auto operators = flatbuffer_builder->CreateVector(operator_vec);
493 auto name = flatbuffer_builder->CreateString(graph_name);
495 tflite::SubGraphBuilder subgraph_builder{*flatbuffer_builder};
497 subgraph_builder.add_tensors(tensors);
498 subgraph_builder.add_inputs(inputs);
499 subgraph_builder.add_outputs(outputs);
500 subgraph_builder.add_operators(operators);
501 subgraph_builder.add_name(name);
503 subgraph_vec.emplace_back(subgraph_builder.Finish());
512 * @brief Generate a (in-memory) TensorFlow Lite model from a given model recipe
514 GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
516 // Initialize Op Chef Registry
517 #define OP_CHEF(NAME, FACTORY_CLASS) \
518 op_chef_registry().add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
519 #include "OpChef.def"
522 // Initialize Data Chef Registry
523 #define DATA_CHEF(TYPE, NAME, FACTORY_CLASS) \
524 data_chef_registry(::tflchef::TYPE) \
525 .add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
526 #include <souschef/DataChef.def>
530 // Create FlatBufferBuilder
532 auto flatbuffer_builder =
533 std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));
536 std::vector<flatbuffers::Offset<::tflite::Buffer>> buffer_vec;
539 std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
542 std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
544 // Create OperatorCode with Builtin Operator
545 auto builtin_code_map = gather_builtincode_map(model_recipe);
546 for (auto const &opcode : builtin_code_map)
548 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
549 code_builder.add_builtin_code(opcode.first);
550 code_builder.add_version(opcode.second);
551 auto code = code_builder.Finish();
552 // Update OperatorCode vector
553 code_vec.emplace_back(code);
556 // Create OperatorCode with Custom Operator
557 std::set<std::string> custom_code_set = gather_customcode_set(model_recipe);
558 if (custom_code_set.size() &&
559 builtin_code_map.find(tflite::BuiltinOperator_CUSTOM) == builtin_code_map.end())
560 builtin_code_map[tflite::BuiltinOperator_CUSTOM] = 1;
562 for (auto opcode : custom_code_set)
564 auto custom_code = flatbuffer_builder->CreateString(opcode);
565 tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
566 code_builder.add_builtin_code(tflite::BuiltinOperator_CUSTOM);
567 code_builder.add_custom_code(custom_code);
568 auto code = code_builder.Finish();
569 // Update OperatorCode vector
570 code_vec.emplace_back(code);
573 // Create an Empty Buffer
575 // Buffer 0 SHOULD be an empty buffer in TensorFlow Lite model file
576 // (Please refer to the comment for Tensor.buffer field in schema)
578 tflite::BufferBuilder buffer_builder{*flatbuffer_builder};
579 buffer_vec.emplace_back(buffer_builder.Finish());
585 CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder, builtin_code_map, "main"};
587 cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
590 // Create subgraphs if exist
592 for (int g = 0; g < model_recipe.graph_size(); ++g)
594 const auto &graph = model_recipe.graph(g);
596 std::ostringstream stringStream;
597 stringStream << "sub_" << (g + 1);
599 CookParams cp{buffer_vec, code_vec, subgraph_vec,
600 flatbuffer_builder, builtin_code_map, stringStream.str()};
602 cook_graph<::tflchef::Graph>(graph, cp);
605 // Create "Model" arguments
606 auto buffers = flatbuffer_builder->CreateVector(buffer_vec);
607 auto operator_codes = flatbuffer_builder->CreateVector(code_vec);
608 auto subgraphs = flatbuffer_builder->CreateVector(subgraph_vec);
609 auto description = flatbuffer_builder->CreateString("Generated by tflchef");
612 tflite::ModelBuilder model_builder{*flatbuffer_builder};
614 model_builder.add_version(3);
615 model_builder.add_operator_codes(operator_codes);
616 model_builder.add_subgraphs(subgraphs);
617 model_builder.add_description(description);
618 model_builder.add_buffers(buffers);
620 auto model = model_builder.Finish();
623 ::tflite::FinishModelBuffer(*flatbuffer_builder, model);
625 // Return "GenerateModel"
626 return GeneratedModel{
627 std::unique_ptr<GeneratedModelImpl>(new GeneratedModelImpl(std::move(flatbuffer_builder)))};
630 } // namespace tflchef