Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / tflchef / core / src / ModelChef.cpp
index 7028bd9..ada5ff5 100644 (file)
@@ -207,7 +207,7 @@ struct CookParams
   std::string noname;
 };
 
-template <typename T> void cook_graph(const T &graph, CookParams &cp)
+template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, CookParams &cp)
 {
   LOGGER(l);
 
@@ -537,6 +537,8 @@ template <typename T> void cook_graph(const T &graph, CookParams &cp)
   subgraph_builder.add_name(name);
 
   subgraph_vec.emplace_back(subgraph_builder.Finish());
+
+  return symbol_table;
 }
 
 } // namespace
@@ -574,6 +576,9 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
   // Operation-related
   std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
 
+  // SignatureDef-related
+  std::vector<flatbuffers::Offset<::tflite::SignatureDef>> signdef_vec;
+
   // Graphs-related
   std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
 
@@ -617,13 +622,18 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
     buffer_vec.emplace_back(buffer_builder.Finish());
   }
 
+  // symbol_tables stores symbol_table of each sub graph
+  // this is used to find tensor ID(index) with tensor name
+  std::vector<std::map<std::string, int32_t>> symbol_tables;
+
   //
   // Create Main graph
   //
   CookParams cp{buffer_vec,       code_vec,        subgraph_vec, flatbuffer_builder,
                 builtin_code_map, custom_code_vec, "main"};
 
-  cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
+  auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
+  symbol_tables.push_back(table);
 
   //
   // Create subgraphs if exist
@@ -638,11 +648,97 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
     CookParams cp{buffer_vec,       code_vec,        subgraph_vec,      flatbuffer_builder,
                   builtin_code_map, custom_code_vec, stringStream.str()};
 
-    cook_graph<::tflchef::Graph>(graph, cp);
+    auto table = cook_graph<::tflchef::Graph>(graph, cp);
+    symbol_tables.push_back(table);
+  }
+
+  // Create Signature-Def
+  //
+  for (int s = 0; s < model_recipe.signature_def_size(); ++s)
+  {
+    // load from recipe
+    const auto &rec_signature_def = model_recipe.signature_def(s);
+
+    std::vector<flatbuffers::Offset<::tflite::TensorMap>> tensormap_inputs;
+    std::vector<flatbuffers::Offset<::tflite::TensorMap>> tensormap_outputs;
+
+    // which subgraph index to cook
+    auto subgraph_index = 0;
+    if (rec_signature_def.has_subgraph_index())
+    {
+      subgraph_index = rec_signature_def.subgraph_index();
+    }
+    assert(subgraph_index < symbol_tables.size());
+    auto &symbol_table = symbol_tables[subgraph_index];
+
+    // cook for inputs
+    for (int si = 0; si < rec_signature_def.inputs_size(); ++si)
+    {
+      // recipe for input TensorMap
+      auto rec_tm_input = rec_signature_def.inputs(si);
+      auto name = flatbuffer_builder->CreateString(rec_tm_input.name());
+      uint32_t tensor_index = 0;
+      // either tensor or tensor_index should exist
+      assert(rec_tm_input.has_tensor() || rec_tm_input.has_tensor_index());
+      if (rec_tm_input.has_tensor())
+      {
+        // we can get tensor_index from symbol_table
+        auto tensor = rec_tm_input.tensor();
+        tensor_index = symbol_table[tensor];
+      }
+      else
+      {
+        // or we can use tensor_index itself
+        tensor_index = rec_tm_input.tensor_index();
+      }
+
+      ::tflite::TensorMapBuilder tensormap_builder{*flatbuffer_builder};
+      tensormap_builder.add_name(name);
+      tensormap_builder.add_tensor_index(tensor_index);
+      tensormap_inputs.push_back(tensormap_builder.Finish());
+    }
+    // cook for outputs, same as inputs
+    for (int so = 0; so < rec_signature_def.outputs_size(); ++so)
+    {
+      auto rec_tm_output = rec_signature_def.outputs(so);
+      auto name = flatbuffer_builder->CreateString(rec_tm_output.name());
+      uint32_t tensor_index = 0;
+      assert(rec_tm_output.has_tensor() || rec_tm_output.has_tensor_index());
+      if (rec_tm_output.has_tensor())
+      {
+        auto tensor = rec_tm_output.tensor();
+        tensor_index = symbol_table[tensor];
+      }
+      else
+      {
+        tensor_index = rec_tm_output.tensor_index();
+      }
+
+      ::tflite::TensorMapBuilder tensormap_builder{*flatbuffer_builder};
+      tensormap_builder.add_name(name);
+      tensormap_builder.add_tensor_index(tensor_index);
+      tensormap_outputs.push_back(tensormap_builder.Finish());
+    }
+
+    auto inputs = flatbuffer_builder->CreateVector(tensormap_inputs);
+    auto outputs = flatbuffer_builder->CreateVector(tensormap_outputs);
+    auto method_name = flatbuffer_builder->CreateString(rec_signature_def.method_name());
+    auto key = flatbuffer_builder->CreateString(rec_signature_def.key());
+    // TODO add validation for method_name and key
+
+    ::tflite::SignatureDefBuilder signature_def_builder{*flatbuffer_builder};
+    signature_def_builder.add_inputs(inputs);
+    signature_def_builder.add_outputs(outputs);
+    signature_def_builder.add_method_name(method_name);
+    signature_def_builder.add_key(key);
+    signature_def_builder.add_subgraph_index(rec_signature_def.subgraph_index());
+
+    signdef_vec.emplace_back(signature_def_builder.Finish());
   }
 
   // Create "Model" arguments
   auto buffers = flatbuffer_builder->CreateVector(buffer_vec);
+  auto signdefs = flatbuffer_builder->CreateVector(signdef_vec);
   auto operator_codes = flatbuffer_builder->CreateVector(code_vec);
   auto subgraphs = flatbuffer_builder->CreateVector(subgraph_vec);
   auto description = flatbuffer_builder->CreateString("Generated by tflchef");
@@ -652,6 +748,7 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
 
   model_builder.add_version(3);
   model_builder.add_operator_codes(operator_codes);
+  model_builder.add_signature_defs(signdefs);
   model_builder.add_subgraphs(subgraphs);
   model_builder.add_description(description);
   model_builder.add_buffers(buffers);