[moco] Parse Tensorflow pb/pbtxt graphdef (#3208)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 11 Apr 2019 04:01:59 +0000 (13:01 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 11 Apr 2019 04:01:59 +0000 (13:01 +0900)
* [moco] Parse Tensorflow pb/pbtxt graphdef

This will add codes to parse TensorFlow pb/pbtxt graphdef

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* load with type

* file type as enum

* remove unused include

contrib/moco/lib/frontend/tf/CMakeLists.txt
contrib/moco/lib/frontend/tf/include/moco/tf/Frontend.h
contrib/moco/lib/frontend/tf/src/Frontend.cpp
contrib/moco/lib/frontend/tf/src/Frontend.test.cpp
contrib/moco/requires.cmake

index 33d61ac..76b3cdc 100644 (file)
@@ -22,6 +22,7 @@ target_include_directories(moco_tf_frontend PUBLIC include)
 target_link_libraries(moco_tf_frontend PUBLIC moco_tf_proto)
 target_link_libraries(moco_tf_frontend PUBLIC loco)
 target_link_libraries(moco_tf_frontend PRIVATE stdex)
+target_link_libraries(moco_tf_frontend PRIVATE cwrap)
 
 nncc_find_package(GTest QUIET)
 
index c57c300..aa6f140 100644 (file)
@@ -25,10 +25,17 @@ namespace tf
 class Frontend
 {
 public:
+  enum class FileType
+  {
+    Text,
+    Binary,
+  };
+
+public:
   Frontend();
 
 public:
-  void load(const char *) const;
+  void load(const char *, FileType) const;
 };
 
 } // namespace tf
index 99ed566..cd96cb2 100644 (file)
 
 #include <moco/tf/Frontend.h>
 
+#include <cwrap/Fildes.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <sstream>
 #include <stdexcept>
 
+#include <fcntl.h>
+#include <unistd.h>
+
+namespace
+{
+
+bool load_text(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def)
+{
+  google::protobuf::io::FileInputStream fis(fildes.get());
+
+  return google::protobuf::TextFormat::Parse(&fis, &graph_def);
+}
+
+bool load_binary(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def)
+{
+  google::protobuf::io::FileInputStream fis(fildes.get());
+  google::protobuf::io::CodedInputStream cis(&fis);
+
+  return graph_def.ParseFromCodedStream(&cis);
+}
+
+void load_tf(const std::string &path, moco::tf::Frontend::FileType type,
+             tensorflow::GraphDef &graph_def)
+{
+  cwrap::Fildes fildes{open(path.c_str(), O_RDONLY)};
+
+  if (fildes.get() < 0)
+  {
+    std::ostringstream ostr;
+    ostr << "Error: " << path << " not found" << std::endl;
+    throw std::runtime_error{ostr.str()};
+  }
+
+  bool result = (type == moco::tf::Frontend::FileType::Text) ? load_text(fildes, graph_def)
+                                                             : load_binary(fildes, graph_def);
+  if (!result)
+  {
+    std::ostringstream ostr;
+    ostr << "Error: Failed to parse prototxt " << path << "\n";
+    throw std::runtime_error{ostr.str()};
+  }
+}
+
+} // namespace
+
 namespace moco
 {
 namespace tf
@@ -28,10 +82,13 @@ Frontend::Frontend()
   // DO NOTHING
 }
 
-void Frontend::load(const char *) const
+void Frontend::load(const char *modelfile, FileType type) const
 {
-  // TODO implement this
-  throw std::runtime_error{"NYI"};
+  tensorflow::GraphDef tf_graph_def;
+
+  load_tf(modelfile, type, tf_graph_def);
+
+  // TODO convert tf_graph_def
 }
 
 } // namespace tf
index 65efc95..ac5db20 100644 (file)
 #include <gtest/gtest.h>
 
 TEST(MocoTensotFlowFrontendTest, Dummy) { moco::tf::Frontend frontend; }
+
+TEST(TensorFlowFrontend, load_model)
+{
+  moco::tf::Frontend frontend;
+
+  // TODO fix not to use "../../.."
+  frontend.load("../../../test/tf/Placeholder_000.pbtxt", moco::tf::Frontend::FileType::Text);
+  frontend.load("../../../test/tf/Placeholder_000.pb", moco::tf::Frontend::FileType::Binary);
+}
index c61de93..2740cc8 100644 (file)
@@ -1 +1,2 @@
 require("tfkit")
+require("cwrap")