#include <moco/tf/Frontend.h>
+#include <cwrap/Fildes.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <sstream>
#include <stdexcept>
+#include <fcntl.h>
+#include <unistd.h>
+
+namespace
+{
+
+bool load_text(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def)
+{
+ google::protobuf::io::FileInputStream fis(fildes.get());
+
+ return google::protobuf::TextFormat::Parse(&fis, &graph_def);
+}
+
+bool load_binary(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def)
+{
+ google::protobuf::io::FileInputStream fis(fildes.get());
+ google::protobuf::io::CodedInputStream cis(&fis);
+
+ return graph_def.ParseFromCodedStream(&cis);
+}
+
+void load_tf(const std::string &path, moco::tf::Frontend::FileType type,
+ tensorflow::GraphDef &graph_def)
+{
+ cwrap::Fildes fildes{open(path.c_str(), O_RDONLY)};
+
+ if (fildes.get() < 0)
+ {
+ std::ostringstream ostr;
+ ostr << "Error: " << path << " not found" << std::endl;
+ throw std::runtime_error{ostr.str()};
+ }
+
+ bool result = (type == moco::tf::Frontend::FileType::Text) ? load_text(fildes, graph_def)
+ : load_binary(fildes, graph_def);
+ if (!result)
+ {
+ std::ostringstream ostr;
+ ostr << "Error: Failed to parse prototxt " << path << "\n";
+ throw std::runtime_error{ostr.str()};
+ }
+}
+
+} // namespace
+
namespace moco
{
namespace tf
// DO NOTHING
}
-void Frontend::load(const char *) const
+void Frontend::load(const char *modelfile, FileType type) const
{
- // TODO implement this
- throw std::runtime_error{"NYI"};
+ tensorflow::GraphDef tf_graph_def;
+
+ load_tf(modelfile, type, tf_graph_def);
+
+ // TODO convert tf_graph_def
}
} // namespace tf
#include <gtest/gtest.h>
TEST(MocoTensotFlowFrontendTest, Dummy) { moco::tf::Frontend frontend; }
+
+TEST(TensorFlowFrontend, load_model)
+{
+ moco::tf::Frontend frontend;
+
+ // TODO fix not to use "../../.."
+ frontend.load("../../../test/tf/Placeholder_000.pbtxt", moco::tf::Frontend::FileType::Text);
+ frontend.load("../../../test/tf/Placeholder_000.pb", moco::tf::Frontend::FileType::Binary);
+}