#include "Dataset.h"
+#include <iterator>
#include <map>
#include <string>
#include <vector>
return registry;
}
+// @brief This will prepare a set of unique operator codes in the mode recipe
+std::set<tflite::BuiltinOperator> gather_opcode_set(const ::tflchef::ModelRecipe &model_recipe)
+{
+ std::set<tflite::BuiltinOperator> opcode_set;
+ for (const auto &operation : model_recipe.operation())
+ {
+ auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
+ opcode_set.insert(op_chef->code());
+ }
+ return opcode_set;
+}
+
} // namespace
namespace tflchef
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;
+ // Create OperatorCode
+ std::set<tflite::BuiltinOperator> opcode_set = gather_opcode_set(model_recipe);
+ for (auto opcode : opcode_set)
+ {
+ tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
+ code_builder.add_builtin_code(opcode);
+ auto code = code_builder.Finish();
+ // Update OperatorCode vector
+ code_vec.emplace_back(code);
+ }
+
// Create an Empty Buffer
//
// Buffer 0 SHOULD be an empty buffer in TensorFlow Lite model file
symbol_table[tensor_name] = tensor_index;
}
- // Create Operator & OperatorCode
+ // Create Operator
for (const auto &operation : model_recipe.operation())
{
assert(operation.has_type());
std::vector<int32_t> output_vec = as_dataset(operation.output()).map(lookup).vectorize();
auto outputs = flatbuffer_builder->CreateVector(output_vec);
- // Create OperatorCode
- tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
- code_builder.add_builtin_code(op_chef->code());
- auto code = code_builder.Finish();
-
- // Update OperatorCode vector
- uint32_t opcode_index = code_vec.size();
- code_vec.emplace_back(code);
-
// Create Option
auto options = op_chef->value(*flatbuffer_builder);
// Create Operator
tflite::OperatorBuilder op_builder{*flatbuffer_builder};
+ // Get operator code index from opcode_set with assumption, order of
+ // opcode_set is same as that of code_vec
+ auto op_it = opcode_set.find(op_chef->code());
+ assert(op_it != opcode_set.end());
+ uint32_t opcode_index = std::distance(opcode_set.begin(), op_it);
+
op_builder.add_opcode_index(opcode_index);
op_builder.add_inputs(inputs);
op_builder.add_outputs(outputs);