[moco] provide istream as input of conversion (#3334)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 22 Apr 2019 04:30:27 +0000 (13:30 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 22 Apr 2019 04:30:27 +0000 (13:30 +0900)
This will add another load() to parse from istream and a test code

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/include/moco/tf/Frontend.h
contrib/moco/lib/frontend/tf/src/Frontend.cpp
contrib/moco/lib/frontend/tf/src/Frontend.test.cpp

index f9495fa..8f48e9f 100644 (file)
@@ -19,6 +19,7 @@
 
 #include <loco.h>
 
+#include <istream>
 #include <memory>
 #include <string>
 #include <vector>
@@ -56,6 +57,7 @@ public:
 
 public:
   std::unique_ptr<loco::Graph> load(const ModelSignature &, const char *, FileType) const;
+  std::unique_ptr<loco::Graph> load(const ModelSignature &, std::istream *, FileType) const;
 };
 
 } // namespace tf
index f687144..f5da991 100644 (file)
@@ -212,5 +212,19 @@ std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, con
   return std::move(graph);
 }
 
+std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, std::istream *stream,
+                                            FileType type) const
+{
+  tensorflow::GraphDef tf_graph_def;
+
+  load_tf(stream, type, tf_graph_def);
+
+  auto graph = loco::make_graph();
+
+  convert_graph(signature, tf_graph_def, graph.get());
+
+  return std::move(graph);
+}
+
 } // namespace tf
 } // namespace moco
index 9eec957..a52d111 100644 (file)
 
 #include <gtest/gtest.h>
 
+#include <cstring>
+
+#define STRING(content) #content
+
 TEST(MocoTensotFlowFrontendTest, Dummy) { moco::tf::Frontend frontend; }
 
 TEST(TensorFlowFrontend, load_model)
@@ -31,3 +35,83 @@ TEST(TensorFlowFrontend, load_model)
   frontend.load(signature, "../../../test/tf/Placeholder_000.pb",
                 moco::tf::Frontend::FileType::Binary);
 }
+
+namespace
+{
+
+struct membuf : std::streambuf
+{
+  membuf(char const *base, size_t size)
+  {
+    char *p(const_cast<char *>(base));
+    this->setg(p, p, p + size);
+  }
+};
+
+struct imemstream : virtual membuf, std::istream
+{
+  imemstream(char const *base, size_t size)
+      : membuf(base, size), std::istream(static_cast<std::streambuf *>(this))
+  {
+  }
+};
+
+// clang-format off
+const char *basic_pbtxtdata = STRING(
+node {
+  name: "Placeholder"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 1
+        }
+        dim {
+          size: 2
+        }
+        dim {
+          size: 1
+        }
+        dim {
+          size: 2
+        }
+      }
+    }
+  }
+}
+node {
+  name: "output/identity"
+  op: "Identity"
+  input: "Placeholder"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, load_model_withio)
+{
+  moco::tf::Frontend frontend;
+  moco::tf::ModelSignature signature;
+
+  imemstream mempb(basic_pbtxtdata, std::strlen(basic_pbtxtdata));
+
+  signature.add_input("Placeholder");
+  signature.add_output("output/identity");
+
+  frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+}