From 32a6779c7aeb28bc9a62b0f91906ebff10b10a72 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 11 Apr 2019 13:01:59 +0900 Subject: [PATCH] [moco] Parse Tensorflow pb/pbtxt graphdef (#3208) * [moco] Parse Tensorflow pb/pbtxt graphdef This will add codes to parse TensorFlow pb/pbtxt graphdef Signed-off-by: SaeHie Park * load with type * file type as enum * remove unused include --- contrib/moco/lib/frontend/tf/CMakeLists.txt | 1 + .../lib/frontend/tf/include/moco/tf/Frontend.h | 9 +++- contrib/moco/lib/frontend/tf/src/Frontend.cpp | 63 ++++++++++++++++++++-- contrib/moco/lib/frontend/tf/src/Frontend.test.cpp | 9 ++++ contrib/moco/requires.cmake | 1 + 5 files changed, 79 insertions(+), 4 deletions(-) diff --git a/contrib/moco/lib/frontend/tf/CMakeLists.txt b/contrib/moco/lib/frontend/tf/CMakeLists.txt index 33d61ac..76b3cdc 100644 --- a/contrib/moco/lib/frontend/tf/CMakeLists.txt +++ b/contrib/moco/lib/frontend/tf/CMakeLists.txt @@ -22,6 +22,7 @@ target_include_directories(moco_tf_frontend PUBLIC include) target_link_libraries(moco_tf_frontend PUBLIC moco_tf_proto) target_link_libraries(moco_tf_frontend PUBLIC loco) target_link_libraries(moco_tf_frontend PRIVATE stdex) +target_link_libraries(moco_tf_frontend PRIVATE cwrap) nncc_find_package(GTest QUIET) diff --git a/contrib/moco/lib/frontend/tf/include/moco/tf/Frontend.h b/contrib/moco/lib/frontend/tf/include/moco/tf/Frontend.h index c57c300..aa6f140 100644 --- a/contrib/moco/lib/frontend/tf/include/moco/tf/Frontend.h +++ b/contrib/moco/lib/frontend/tf/include/moco/tf/Frontend.h @@ -25,10 +25,17 @@ namespace tf class Frontend { public: + enum class FileType + { + Text, + Binary, + }; + +public: Frontend(); public: - void load(const char *) const; + void load(const char *, FileType) const; }; } // namespace tf diff --git a/contrib/moco/lib/frontend/tf/src/Frontend.cpp b/contrib/moco/lib/frontend/tf/src/Frontend.cpp index 99ed566..cd96cb2 100644 --- a/contrib/moco/lib/frontend/tf/src/Frontend.cpp +++ b/contrib/moco/lib/frontend/tf/src/Frontend.cpp @@ -16,8 +16,62 @@ #include +#include + +#include + +#include +#include +#include + +#include #include +#include +#include + +namespace +{ + +bool load_text(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def) +{ + google::protobuf::io::FileInputStream fis(fildes.get()); + + return google::protobuf::TextFormat::Parse(&fis, &graph_def); +} + +bool load_binary(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def) +{ + google::protobuf::io::FileInputStream fis(fildes.get()); + google::protobuf::io::CodedInputStream cis(&fis); + + return graph_def.ParseFromCodedStream(&cis); +} + +void load_tf(const std::string &path, moco::tf::Frontend::FileType type, + tensorflow::GraphDef &graph_def) +{ + cwrap::Fildes fildes{open(path.c_str(), O_RDONLY)}; + + if (fildes.get() < 0) + { + std::ostringstream ostr; + ostr << "Error: " << path << " not found" << std::endl; + throw std::runtime_error{ostr.str()}; + } + + bool result = (type == moco::tf::Frontend::FileType::Text) ? load_text(fildes, graph_def) + : load_binary(fildes, graph_def); + if (!result) + { + std::ostringstream ostr; + ostr << "Error: Failed to parse prototxt " << path << "\n"; + throw std::runtime_error{ostr.str()}; + } +} + +} // namespace + namespace moco { namespace tf @@ -28,10 +82,13 @@ Frontend::Frontend() // DO NOTHING } -void Frontend::load(const char *) const +void Frontend::load(const char *modelfile, FileType type) const { - // TODO implement this - throw std::runtime_error{"NYI"}; + tensorflow::GraphDef tf_graph_def; + + load_tf(modelfile, type, tf_graph_def); + + // TODO convert tf_graph_def } } // namespace tf diff --git a/contrib/moco/lib/frontend/tf/src/Frontend.test.cpp b/contrib/moco/lib/frontend/tf/src/Frontend.test.cpp index 65efc95..ac5db20 100644 --- a/contrib/moco/lib/frontend/tf/src/Frontend.test.cpp +++ b/contrib/moco/lib/frontend/tf/src/Frontend.test.cpp @@ -19,3 +19,12 @@ #include TEST(MocoTensotFlowFrontendTest, Dummy) { moco::tf::Frontend frontend; } + +TEST(TensorFlowFrontend, load_model) +{ + moco::tf::Frontend frontend; + + // TODO fix not to use "../../.." + frontend.load("../../../test/tf/Placeholder_000.pbtxt", moco::tf::Frontend::FileType::Text); + frontend.load("../../../test/tf/Placeholder_000.pb", moco::tf::Frontend::FileType::Binary); +} diff --git a/contrib/moco/requires.cmake b/contrib/moco/requires.cmake index c61de93..2740cc8 100644 --- a/contrib/moco/requires.cmake +++ b/contrib/moco/requires.cmake @@ -1 +1,2 @@ require("tfkit") +require("cwrap") -- 2.7.4