From cd6367d4a4bc24ae4bf439b7085f0036c63db38b Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Wed, 21 Nov 2018 19:31:57 +0900 Subject: [PATCH] [tflchef] Implement generate_recipe (#2360) * [tflchef] Implement generate_recipe This implements partial of generate_recipe() for operators, operands and inputs,outputs. Signed-off-by: SaeHie Park * remove unused * reorder a line * fix typo * need check quantization * fix typo --- contrib/tflchef/tflite/src/RecipeChef.cpp | 89 ++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/contrib/tflchef/tflite/src/RecipeChef.cpp b/contrib/tflchef/tflite/src/RecipeChef.cpp index ff4537e..8c5dd02 100644 --- a/contrib/tflchef/tflite/src/RecipeChef.cpp +++ b/contrib/tflchef/tflite/src/RecipeChef.cpp @@ -16,6 +16,8 @@ #include +#include "Convert.h" +#include "TFliteImport.h" #include "TFliteOpChef.h" #include "TFliteOpChefs.h" #include "TFliteOpRegistry.h" @@ -25,10 +27,95 @@ namespace tflchef { +/** + * @brief This will build ModelRecipe from tflite::Model + * First to check operand filler options by scanning all operators, + * then translate all operands and operators. + * Last will set network inputs and outputs. + */ std::unique_ptr generate_recipe(const tflite::Model *model) { std::unique_ptr model_recipe{new ModelRecipe()}; - // TODO fill this + + TFliteImport tflite_import(model); + + assert(tflite_import.num_subgraph() == 1); + tflite_import.select_sub_graph(0); + + auto tensors = tflite_import.tensors(); + auto buffers = tflite_import.buffers(); + auto operators = tflite_import.operators(); + + // operand fillers for adding all operators + for (uint32_t i = 0; i < operators->Length(); ++i) + { + const auto *op = operators->Get(i); + tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op); + + // TODO add handler for builtincode + } + + // add all operands(tensors) + for (uint32_t i = 0; i < tensors->Length(); ++i) + { + auto tensor = tensors->Get(i); + // TODO support quantization + assert(tensor->quantization() == nullptr); + + // check buffer + if (tensor->buffer() >= buffers->size()) + throw std::runtime_error{"file load failed"}; + + ::tflchef::Operand *operand = model_recipe->add_operand(); + + operand->set_name(tensor_name(tensor)); + operand->set_type(as_tflchef_type(tensor->type())); + + std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); + ::tflchef::TensorShape *shape = operand->mutable_shape(); + for (auto dim : dims) + { + shape->add_dim(dim); + } + + // filler for weights, bias and so on + if (tflite_import.get_tensor_filler(i)) + { + tflchef::TensorFiller *filler = operand->mutable_filler(); + // Note: it is OK to use random weights for functionality validation + filler->set_tag("gaussian"); + filler->add_arg("0.0"); // average + filler->add_arg("0.1"); // standard deviation + } + } + + // add all operators + for (uint32_t i = 0; i < operators->Length(); ++i) + { + const auto *op = operators->Get(i); + tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op); + + // TODO add handler for builtincode + } + + // network inputs/outputs + const std::vector &inputs = tflite_import.inputs(); + const std::vector &outputs = tflite_import.outputs(); + + for (const auto input : inputs) + { + auto tensor = tensors->Get(input); + std::string name = tensor_name(tensor); + + model_recipe->add_input(name); + } + for (const auto output : outputs) + { + auto tensor = tensors->Get(output); + std::string name = tensor_name(tensor); + + model_recipe->add_output(name); + } return std::move(model_recipe); } -- 2.7.4