std::string noname;
};
-template <typename T> void cook_graph(const T &graph, CookParams &cp)
+template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, CookParams &cp)
{
LOGGER(l);
subgraph_builder.add_name(name);
subgraph_vec.emplace_back(subgraph_builder.Finish());
+
+ return symbol_table;
}
} // namespace
// Operation-related
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
+ // SignatureDef-related
+ std::vector<flatbuffers::Offset<::tflite::SignatureDef>> signdef_vec;
+
// Graphs-related
std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
buffer_vec.emplace_back(buffer_builder.Finish());
}
+ // symbol_tables stores symbol_table of each sub graph
+ // this is used to find tensor ID(index) with tensor name
+ std::vector<std::map<std::string, int32_t>> symbol_tables;
+
//
// Create Main graph
//
CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
builtin_code_map, custom_code_vec, "main"};
- cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
+ auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
+ symbol_tables.push_back(table);
//
// Create subgraphs if exist
CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
builtin_code_map, custom_code_vec, stringStream.str()};
- cook_graph<::tflchef::Graph>(graph, cp);
+ auto table = cook_graph<::tflchef::Graph>(graph, cp);
+ symbol_tables.push_back(table);
+ }
+
+ // Create Signature-Def
+ //
+ for (int s = 0; s < model_recipe.signature_def_size(); ++s)
+ {
+ // load from recipe
+ const auto &rec_signature_def = model_recipe.signature_def(s);
+
+ std::vector<flatbuffers::Offset<::tflite::TensorMap>> tensormap_inputs;
+ std::vector<flatbuffers::Offset<::tflite::TensorMap>> tensormap_outputs;
+
+ // which subgraph index to cook
+ auto subgraph_index = 0;
+ if (rec_signature_def.has_subgraph_index())
+ {
+ subgraph_index = rec_signature_def.subgraph_index();
+ }
+ assert(subgraph_index < symbol_tables.size());
+ auto &symbol_table = symbol_tables[subgraph_index];
+
+ // cook for inputs
+ for (int si = 0; si < rec_signature_def.inputs_size(); ++si)
+ {
+ // recipe for input TensorMap
+ auto rec_tm_input = rec_signature_def.inputs(si);
+ auto name = flatbuffer_builder->CreateString(rec_tm_input.name());
+ uint32_t tensor_index = 0;
+ // either tensor or tensor_index should exist
+ assert(rec_tm_input.has_tensor() || rec_tm_input.has_tensor_index());
+ if (rec_tm_input.has_tensor())
+ {
+ // we can get tensor_index from symbol_table
+ auto tensor = rec_tm_input.tensor();
+ tensor_index = symbol_table[tensor];
+ }
+ else
+ {
+ // or we can use tensor_index itself
+ tensor_index = rec_tm_input.tensor_index();
+ }
+
+ ::tflite::TensorMapBuilder tensormap_builder{*flatbuffer_builder};
+ tensormap_builder.add_name(name);
+ tensormap_builder.add_tensor_index(tensor_index);
+ tensormap_inputs.push_back(tensormap_builder.Finish());
+ }
+ // cook for outputs, same as inputs
+ for (int so = 0; so < rec_signature_def.outputs_size(); ++so)
+ {
+ auto rec_tm_output = rec_signature_def.outputs(so);
+ auto name = flatbuffer_builder->CreateString(rec_tm_output.name());
+ uint32_t tensor_index = 0;
+ assert(rec_tm_output.has_tensor() || rec_tm_output.has_tensor_index());
+ if (rec_tm_output.has_tensor())
+ {
+ auto tensor = rec_tm_output.tensor();
+ tensor_index = symbol_table[tensor];
+ }
+ else
+ {
+ tensor_index = rec_tm_output.tensor_index();
+ }
+
+ ::tflite::TensorMapBuilder tensormap_builder{*flatbuffer_builder};
+ tensormap_builder.add_name(name);
+ tensormap_builder.add_tensor_index(tensor_index);
+ tensormap_outputs.push_back(tensormap_builder.Finish());
+ }
+
+ auto inputs = flatbuffer_builder->CreateVector(tensormap_inputs);
+ auto outputs = flatbuffer_builder->CreateVector(tensormap_outputs);
+ auto method_name = flatbuffer_builder->CreateString(rec_signature_def.method_name());
+ auto key = flatbuffer_builder->CreateString(rec_signature_def.key());
+ // TODO add validation for method_name and key
+
+ ::tflite::SignatureDefBuilder signature_def_builder{*flatbuffer_builder};
+ signature_def_builder.add_inputs(inputs);
+ signature_def_builder.add_outputs(outputs);
+ signature_def_builder.add_method_name(method_name);
+ signature_def_builder.add_key(key);
+ signature_def_builder.add_subgraph_index(rec_signature_def.subgraph_index());
+
+ signdef_vec.emplace_back(signature_def_builder.Finish());
}
// Create "Model" arguments
auto buffers = flatbuffer_builder->CreateVector(buffer_vec);
+ auto signdefs = flatbuffer_builder->CreateVector(signdef_vec);
auto operator_codes = flatbuffer_builder->CreateVector(code_vec);
auto subgraphs = flatbuffer_builder->CreateVector(subgraph_vec);
auto description = flatbuffer_builder->CreateString("Generated by tflchef");
model_builder.add_version(3);
model_builder.add_operator_codes(operator_codes);
+ model_builder.add_signature_defs(signdefs);
model_builder.add_subgraphs(subgraphs);
model_builder.add_description(description);
model_builder.add_buffers(buffers);