[caffegen] Add 'init' command (#918)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 7 Aug 2018 03:40:42 +0000 (12:40 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 7 Aug 2018 03:40:42 +0000 (12:40 +0900)
This commit introduces 'init' command to caffegen. Unlike 'fill' command,
'init' uses Caffe itself to initialize parameters with random values.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/caffegen/CMakeLists.txt
contrib/caffegen/src/Driver.cpp
contrib/caffegen/src/InitCommand.cpp [new file with mode: 0644]
contrib/caffegen/src/InitCommand.h [new file with mode: 0644]

index 15f28aa..fc7129c 100644 (file)
@@ -4,6 +4,12 @@ if(NOT CaffeProto_FOUND)
   return()
 endif(NOT CaffeProto_FOUND)
 
+nncc_find_package(Caffe QUIET)
+
+if(NOT Caffe_FOUND)
+  return()
+endif(NOT Caffe_FOUND)
+
 file(GLOB_RECURSE SOURCES "src/*.cpp")
 
 add_executable(caffegen ${SOURCES})
@@ -11,3 +17,4 @@ target_include_directories(caffegen PRIVATE include)
 target_link_libraries(caffegen nncc_foundation)
 target_link_libraries(caffegen cli)
 target_link_libraries(caffegen caffeproto)
+target_link_libraries(caffegen caffe)
index 9e3239f..814b69f 100644 (file)
@@ -1,3 +1,4 @@
+#include "InitCommand.h"
 #include "FillCommand.h"
 #include "EncodeCommand.h"
 #include "DecodeCommand.h"
@@ -12,6 +13,7 @@ int main(int argc, char **argv)
 {
   cli::App app{argv[0]};
 
+  app.insert("init", nncc::foundation::make_unique<InitCommand>());
   app.insert("fill", nncc::foundation::make_unique<FillCommand>());
   app.insert("encode", nncc::foundation::make_unique<EncodeCommand>());
   app.insert("decode", nncc::foundation::make_unique<DecodeCommand>());
diff --git a/contrib/caffegen/src/InitCommand.cpp b/contrib/caffegen/src/InitCommand.cpp
new file mode 100644 (file)
index 0000000..184c7d2
--- /dev/null
@@ -0,0 +1,49 @@
+#include "InitCommand.h"
+
+#include <caffe/net.hpp>
+#include <caffe/util/upgrade_proto.hpp>
+#include <caffe/proto/caffe.pb.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <iostream>
+
+int InitCommand::run(int, const char * const *) const
+{
+  // Read prototxt from standard input
+  ::caffe::NetParameter in;
+  {
+    google::protobuf::io::FileInputStream is{0};
+    if (!google::protobuf::TextFormat::Parse(&is, &in))
+    {
+      std::cerr << "ERROR: Failed to parse prototxt" << std::endl;
+      return 255;
+    }
+  }
+
+  // Upgrade prototxt if necessary
+  if (::caffe::NetNeedsUpgrade(in))
+  {
+    if (!::caffe::UpgradeNetAsNeeded("<stdin>", &in))
+    {
+      std::cerr << "ERROR: Failed to upgrade prototxt" << std::endl;
+      return 255;
+    }
+  }
+
+  ::caffe::Net<float> net(in);
+
+  // Extract initialized parameters
+  ::caffe::NetParameter out;
+  {
+    net.ToProto(&out);
+  }
+
+  // Write initialized parameters to standard output
+  google::protobuf::io::FileOutputStream os(1);
+  google::protobuf::TextFormat::Print(out, &os);
+
+  return 0;
+}
diff --git a/contrib/caffegen/src/InitCommand.h b/contrib/caffegen/src/InitCommand.h
new file mode 100644 (file)
index 0000000..1334239
--- /dev/null
@@ -0,0 +1,11 @@
+#ifndef __INIT_COMMAND_H__
+#define __INIT_COMMAND_H__
+
+#include <cli/Command.h>
+
+struct InitCommand final : public cli::Command
+{
+  int run(int argc, const char * const *argv) const override;
+};
+
+#endif // __INIT_COMMAND_H__