From 5014ec28c57e31d651df83e0a788a54abf06d8f6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 19 Apr 2018 19:41:59 +0900 Subject: [PATCH] [caffegen] Introduce 'fill' command (#104) This commit implements 'fill' command. Currently, only 'Input' layer is supported. Signed-off-by: Jonghyun Park --- contrib/caffegen/src/Driver.cpp | 2 + contrib/caffegen/src/FillCommand.cpp | 58 +++++++++++++++++++++++++ contrib/caffegen/src/FillCommand.h | 11 +++++ contrib/caffegen/src/ParameterRandomizePass.cpp | 13 ++++++ contrib/caffegen/src/ParameterRandomizePass.h | 20 +++++++++ 5 files changed, 104 insertions(+) create mode 100644 contrib/caffegen/src/FillCommand.cpp create mode 100644 contrib/caffegen/src/FillCommand.h create mode 100644 contrib/caffegen/src/ParameterRandomizePass.cpp create mode 100644 contrib/caffegen/src/ParameterRandomizePass.h diff --git a/contrib/caffegen/src/Driver.cpp b/contrib/caffegen/src/Driver.cpp index 2d1b3b1..05899a6 100644 --- a/contrib/caffegen/src/Driver.cpp +++ b/contrib/caffegen/src/Driver.cpp @@ -1,3 +1,4 @@ +#include "FillCommand.h" #include "EncodeCommand.h" #include @@ -9,6 +10,7 @@ int main(int argc, char **argv) { std::map> commands; + commands["fill"] = nncc::foundation::make_unique(); commands["encode"] = nncc::foundation::make_unique(); return commands.at(argv[1])->run(); diff --git a/contrib/caffegen/src/FillCommand.cpp b/contrib/caffegen/src/FillCommand.cpp new file mode 100644 index 0000000..cc05f85 --- /dev/null +++ b/contrib/caffegen/src/FillCommand.cpp @@ -0,0 +1,58 @@ +#include "FillCommand.h" +#include "LayerResolver.h" +#include "NetworkBuilder.h" +#include "ParameterRandomizePass.h" + +#include + +#include +#include +#include + +#include +#include +#include + +int FillCommand::run(void) const +{ + caffe::NetParameter param; + + // Read from standard input + google::protobuf::io::FileInputStream is{0}; + if (!google::protobuf::TextFormat::Parse(&is, ¶m)) + { + std::cerr << "ERROR: Failed to parse prototxt" << std::endl; + return 255; + } + + auto net = NetworkBuilder{LayerResolver{}}.build(¶m); + + uint32_t seed = std::chrono::system_clock::now().time_since_epoch().count(); + + // Allow users to override seed + { + char *env = std::getenv("SEED"); + + if (env) + { + seed = std::stoi(env); + } + } + + std::cerr << "Use '" << seed << "' as seed" << std::endl; + + // Create a random number generator + std::default_random_engine generator{seed}; + + // Randomize parameters + for (uint32_t n = 0; n < net->layers().size(); ++n) + { + net->layers().at(n).accept(ParameterRandomizePass{generator}); + } + + // Write to standard output + google::protobuf::io::FileOutputStream output(1); + google::protobuf::TextFormat::Print(param, &output); + + return 0; +} diff --git a/contrib/caffegen/src/FillCommand.h b/contrib/caffegen/src/FillCommand.h new file mode 100644 index 0000000..d9ce04e --- /dev/null +++ b/contrib/caffegen/src/FillCommand.h @@ -0,0 +1,11 @@ +#ifndef __FILL_COMMAND_H__ +#define __FILL_COMMAND_H__ + +#include "Command.h" + +struct FillCommand final : public Command +{ + int run(void) const override; +}; + +#endif // __FILL_COMMAND_H__ diff --git a/contrib/caffegen/src/ParameterRandomizePass.cpp b/contrib/caffegen/src/ParameterRandomizePass.cpp new file mode 100644 index 0000000..a760142 --- /dev/null +++ b/contrib/caffegen/src/ParameterRandomizePass.cpp @@ -0,0 +1,13 @@ +#include "ParameterRandomizePass.h" + +ParameterRandomizePass::ParameterRandomizePass(std::default_random_engine &generator) + : _generator{generator} +{ + // DO NOTHING +} + +void ParameterRandomizePass::visit(InputLayer &) +{ + // InputLayer has no parameter to be randomized + return; +} diff --git a/contrib/caffegen/src/ParameterRandomizePass.h b/contrib/caffegen/src/ParameterRandomizePass.h new file mode 100644 index 0000000..7c9ae8b --- /dev/null +++ b/contrib/caffegen/src/ParameterRandomizePass.h @@ -0,0 +1,20 @@ +#ifndef __PARAMETER_RANDOMIZE_PASS_H__ +#define __PARAMETER_RANDOMIZE_PASS_H__ + +#include "LayerTransformPass.h" + +#include + +class ParameterRandomizePass : public LayerTransformPass +{ +public: + ParameterRandomizePass(std::default_random_engine &generator); + +public: + void visit(InputLayer &) override; + +private: + std::default_random_engine &_generator; +}; + +#endif // __PARAMETER_RANDOMIZE_PASS_H__ -- 2.7.4