From de6e9b186a8c9de2989276a501d6fc8c2f5f254d Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Tue, 13 Apr 2021 11:05:39 +0900 Subject: [PATCH] [Tf/Skeleton] Add basic meta data This patch adds basic meta data to tflite file for startup **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/compiler/tflite_interpreter.cpp | 51 ++++++++++++++++++++++++- test/unittest/compiler/unittest_interpreter.cpp | 6 +++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/nntrainer/compiler/tflite_interpreter.cpp b/nntrainer/compiler/tflite_interpreter.cpp index 690eb33..d9b220c 100644 --- a/nntrainer/compiler/tflite_interpreter.cpp +++ b/nntrainer/compiler/tflite_interpreter.cpp @@ -11,14 +11,63 @@ */ #include +#include +#include + #include +#include + +static constexpr const char *FUNC_TAG = "[TFLITE INTERPRETER] "; + +namespace { +/** + * @brief after finishing building, call this to safe to a file + * + * @param builder flatbuffer builder + * @param out out + */ +void builder2file(const flatbuffers::FlatBufferBuilder &builder, + const std::string &out) { + uint8_t *buf = builder.GetBufferPointer(); + size_t size = builder.GetSize(); + flatbuffers::Verifier v(buf, size); + + NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument) + << FUNC_TAG << "Verifying serialized model failed"; + + std::ofstream os(out, std::ios_base::binary); + NNTR_THROW_IF(!os.good(), std::invalid_argument) + << FUNC_TAG << "failed to open, reason: " << strerror(errno); + os.write((char *)builder.GetBufferPointer(), builder.GetSize()); + os.close(); +} +} // namespace + namespace nntrainer { void TfliteInterpreter::serialize( std::shared_ptr representation, const std::string &out) { - /** NYI!! */ + /// @todo check if graph is finalized + flatbuffers::FlatBufferBuilder fb_builder; + + /// @todo parse subgraph + /// 1. in&outs&weights + /// 2. buffer + /// 2. ops + /// 3. op_codes + + auto desc = fb_builder.CreateString("This file is generated from NNTrainer"); + + tflite::ModelBuilder model_builder(fb_builder); + model_builder.add_version(3); + model_builder.add_description(desc); + auto model = model_builder.Finish(); + + fb_builder.Finish(model, tflite::ModelIdentifier()); + + builder2file(fb_builder, out); } std::shared_ptr diff --git a/test/unittest/compiler/unittest_interpreter.cpp b/test/unittest/compiler/unittest_interpreter.cpp index 8bd40cb..9ad00ed 100644 --- a/test/unittest/compiler/unittest_interpreter.cpp +++ b/test/unittest/compiler/unittest_interpreter.cpp @@ -161,6 +161,12 @@ auto fc0 = LayerReprentation("fully_connected", auto flatten = LayerReprentation("flatten", {"name=flat"}); +#ifdef ENABLE_TFLITE_INTERPRETER +TEST(flatbuffer, playground) { + nntrainer::TfliteInterpreter interpreter; + interpreter.serialize(nullptr, "test.tflite"); +} +#endif /** * @brief make ini test case from given parameter */ -- 2.7.4