[tflchef] Implement generate_recipe (#2360)
authorSaeHie Park/Motion Control Lab(SR)/Principal Engineer/Samsung Electronics <saehie.park@samsung.com>
Wed, 21 Nov 2018 10:31:57 +0000 (19:31 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 21 Nov 2018 10:31:57 +0000 (19:31 +0900)
* [tflchef] Implement generate_recipe

This implements part of generate_recipe() for operators, operands, and network inputs/outputs.

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* remove unused

* reorder a line

* fix typo

* need to check quantization

* fix typo
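
For reference, a minimal sketch of how generate_recipe() could be driven
(hypothetical harness: the file-reading code is an assumption, tflite::GetModel()
is the FlatBuffers-generated root accessor, and ModelRecipe is the tflchef
protobuf message, serialized here with protobuf TextFormat):

    #include <tflchef/RecipeChef.h>
    #include "schema_generated.h" // FlatBuffers-generated TFLite schema (exact include path is an assumption)

    #include <google/protobuf/text_format.h>

    #include <fstream>
    #include <iostream>
    #include <iterator>
    #include <string>
    #include <vector>

    int main(int argc, char **argv)
    {
      if (argc < 2)
        return 1;

      // read the .tflite flatbuffer into memory
      std::ifstream fs(argv[1], std::ios::binary);
      std::vector<char> data((std::istreambuf_iterator<char>(fs)),
                             std::istreambuf_iterator<char>());

      // tflite::GetModel comes from the generated TFLite schema header
      const tflite::Model *model = tflite::GetModel(data.data());

      // build the recipe and dump it as protobuf text
      auto recipe = tflchef::generate_recipe(model);

      std::string text;
      google::protobuf::TextFormat::PrintToString(*recipe, &text);
      std::cout << text;
      return 0;
    }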

contrib/tflchef/tflite/src/RecipeChef.cpp

index ff4537e..8c5dd02 100644 (file)
@@ -16,6 +16,8 @@
 
 #include <tflchef/RecipeChef.h>
 
+#include "Convert.h"
+#include "TFliteImport.h"
 #include "TFliteOpChef.h"
 #include "TFliteOpChefs.h"
 #include "TFliteOpRegistry.h"
 namespace tflchef
 {
 
+/**
+ * @brief Build a ModelRecipe from a tflite::Model.
+ *        First, scan all operators to check operand filler options,
+ *        then translate all operands and operators,
+ *        and finally set the network inputs and outputs.
+ */
 std::unique_ptr<ModelRecipe> generate_recipe(const tflite::Model *model)
 {
   std::unique_ptr<ModelRecipe> model_recipe{new ModelRecipe()};
-  // TODO fill this
+
+  TFliteImport tflite_import(model);
+
+  assert(tflite_import.num_subgraph() == 1);
+  tflite_import.select_sub_graph(0);
+
+  auto tensors = tflite_import.tensors();
+  auto buffers = tflite_import.buffers();
+  auto operators = tflite_import.operators();
+
+  // scan all operators to check operand filler options
+  for (uint32_t i = 0; i < operators->Length(); ++i)
+  {
+    const auto *op = operators->Get(i);
+    tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
+
+    // TODO add handler for builtincode
+  }
+
+  // add all operands (tensors)
+  for (uint32_t i = 0; i < tensors->Length(); ++i)
+  {
+    auto tensor = tensors->Get(i);
+    // TODO support quantization
+    assert(tensor->quantization() == nullptr);
+
+    // check that the tensor refers to a valid buffer
+    if (tensor->buffer() >= buffers->size())
+      throw std::runtime_error{"invalid buffer index"};
+
+    ::tflchef::Operand *operand = model_recipe->add_operand();
+
+    operand->set_name(tensor_name(tensor));
+    operand->set_type(as_tflchef_type(tensor->type()));
+
+    std::vector<uint32_t> dims = FlatBufferIntArrayToVector(tensor->shape());
+    ::tflchef::TensorShape *shape = operand->mutable_shape();
+    for (auto dim : dims)
+    {
+      shape->add_dim(dim);
+    }
+
+    // filler for weights, bias and so on
+    if (tflite_import.get_tensor_filler(i))
+    {
+      tflchef::TensorFiller *filler = operand->mutable_filler();
+      // Note: it is OK to use random weights for functionality validation
+      filler->set_tag("gaussian");
+      filler->add_arg("0.0"); // average
+      filler->add_arg("0.1"); // standard deviation
+    }
+  }
+
+  // add all operators
+  for (uint32_t i = 0; i < operators->Length(); ++i)
+  {
+    const auto *op = operators->Get(i);
+    tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
+
+    // TODO add handler for builtincode
+  }
+
+  // network inputs/outputs
+  const std::vector<uint32_t> &inputs = tflite_import.inputs();
+  const std::vector<uint32_t> &outputs = tflite_import.outputs();
+
+  for (const auto input : inputs)
+  {
+    auto tensor = tensors->Get(input);
+    std::string name = tensor_name(tensor);
+
+    model_recipe->add_input(name);
+  }
+  for (const auto output : outputs)
+  {
+    auto tensor = tensors->Get(output);
+    std::string name = tensor_name(tensor);
+
+    model_recipe->add_output(name);
+  }
 
   return std::move(model_recipe);
 }
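
As a rough illustration (tensor names and shape are hypothetical), a FLOAT32
weight tensor handled by the operand loop above would serialize to recipe text
along these lines:

    operand {
      name: "ker"
      type: FLOAT32
      shape { dim: 1 dim: 3 dim: 3 dim: 2 }
      filler {
        tag: "gaussian"
        arg: "0.0"  # average
        arg: "0.1"  # standard deviation
      }
    }
    input: "ifm"
    output: "ofm"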