[caffegen] Introduce 'fill' command (#104)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 19 Apr 2018 10:41:59 +0000 (19:41 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 19 Apr 2018 10:41:59 +0000 (19:41 +0900)
This commit implements 'fill' command. Currently, only 'Input' layer
is supported.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/caffegen/src/Driver.cpp
contrib/caffegen/src/FillCommand.cpp [new file with mode: 0644]
contrib/caffegen/src/FillCommand.h [new file with mode: 0644]
contrib/caffegen/src/ParameterRandomizePass.cpp [new file with mode: 0644]
contrib/caffegen/src/ParameterRandomizePass.h [new file with mode: 0644]

index 2d1b3b1..05899a6 100644 (file)
@@ -1,3 +1,4 @@
+#include "FillCommand.h"
 #include "EncodeCommand.h"
 
 #include <nncc/foundation/Memory.h>
@@ -9,6 +10,7 @@ int main(int argc, char **argv)
 {
   std::map<std::string, std::unique_ptr<Command>> commands;
 
+  commands["fill"] = nncc::foundation::make_unique<FillCommand>();
   commands["encode"] = nncc::foundation::make_unique<EncodeCommand>();
 
   return commands.at(argv[1])->run();
diff --git a/contrib/caffegen/src/FillCommand.cpp b/contrib/caffegen/src/FillCommand.cpp
new file mode 100644 (file)
index 0000000..cc05f85
--- /dev/null
@@ -0,0 +1,58 @@
+#include "FillCommand.h"
+#include "LayerResolver.h"
+#include "NetworkBuilder.h"
+#include "ParameterRandomizePass.h"
+
+#include <caffe.pb.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <chrono>
+#include <random>
+#include <iostream>
+
+int FillCommand::run(void) const
+{
+  caffe::NetParameter param;
+
+  // Read from standard input
+  google::protobuf::io::FileInputStream is{0};
+  if (!google::protobuf::TextFormat::Parse(&is, &param))
+  {
+    std::cerr << "ERROR: Failed to parse prototxt" << std::endl;
+    return 255;
+  }
+
+  auto net = NetworkBuilder{LayerResolver{}}.build(&param);
+
+  uint32_t seed = std::chrono::system_clock::now().time_since_epoch().count();
+
+  // Allow users to override seed
+  {
+    char *env = std::getenv("SEED");
+
+    if (env)
+    {
+      seed = std::stoi(env);
+    }
+  }
+
+  std::cerr << "Use '" << seed << "' as seed" << std::endl;
+
+  // Create a random number generator
+  std::default_random_engine generator{seed};
+
+  // Randomize parameters
+  for (uint32_t n = 0; n < net->layers().size(); ++n)
+  {
+    net->layers().at(n).accept(ParameterRandomizePass{generator});
+  }
+
+  // Write to standard output
+  google::protobuf::io::FileOutputStream output(1);
+  google::protobuf::TextFormat::Print(param, &output);
+
+  return 0;
+}
diff --git a/contrib/caffegen/src/FillCommand.h b/contrib/caffegen/src/FillCommand.h
new file mode 100644 (file)
index 0000000..d9ce04e
--- /dev/null
@@ -0,0 +1,11 @@
+#ifndef __FILL_COMMAND_H__
+#define __FILL_COMMAND_H__
+
+#include "Command.h"
+
+struct FillCommand final : public Command
+{
+  int run(void) const override;
+};
+
+#endif // __FILL_COMMAND_H__
diff --git a/contrib/caffegen/src/ParameterRandomizePass.cpp b/contrib/caffegen/src/ParameterRandomizePass.cpp
new file mode 100644 (file)
index 0000000..a760142
--- /dev/null
@@ -0,0 +1,13 @@
+#include "ParameterRandomizePass.h"
+
+ParameterRandomizePass::ParameterRandomizePass(std::default_random_engine &generator)
+    : _generator{generator}
+{
+  // DO NOTHING
+}
+
+void ParameterRandomizePass::visit(InputLayer &)
+{
+  // InputLayer has no parameter to be randomized
+  return;
+}
diff --git a/contrib/caffegen/src/ParameterRandomizePass.h b/contrib/caffegen/src/ParameterRandomizePass.h
new file mode 100644 (file)
index 0000000..7c9ae8b
--- /dev/null
@@ -0,0 +1,20 @@
+#ifndef __PARAMETER_RANDOMIZE_PASS_H__
+#define __PARAMETER_RANDOMIZE_PASS_H__
+
+#include "LayerTransformPass.h"
+
+#include <random>
+
+class ParameterRandomizePass : public LayerTransformPass
+{
+public:
+  ParameterRandomizePass(std::default_random_engine &generator);
+
+public:
+  void visit(InputLayer &) override;
+
+private:
+  std::default_random_engine &_generator;
+};
+
+#endif // __PARAMETER_RANDOMIZE_PASS_H__