From 3f6f655aae58d4339d3d29b62425893320a85dbe Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=A8=D0=B5=D0=B4?= =?utf8?q?=D1=8C=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Assistant=20Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 12 Sep 2018 12:57:44 +0300 Subject: [PATCH] [nnc] Add support for loading model weights to Caffegen (#1414) Added binary model with weights maker to caffegen. This actually reads weights from `.caffemodel` files. Signed-off-by: Andrei Shedko --- contrib/caffegen/src/Driver.cpp | 4 ++++ contrib/caffegen/src/MergeCommand.cpp | 42 +++++++++++++++++++++++++++++++++++ contrib/caffegen/src/MergeCommand.h | 17 ++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 contrib/caffegen/src/MergeCommand.cpp create mode 100644 contrib/caffegen/src/MergeCommand.h diff --git a/contrib/caffegen/src/Driver.cpp b/contrib/caffegen/src/Driver.cpp index 814b69f..835bd35 100644 --- a/contrib/caffegen/src/Driver.cpp +++ b/contrib/caffegen/src/Driver.cpp @@ -2,6 +2,7 @@ #include "FillCommand.h" #include "EncodeCommand.h" #include "DecodeCommand.h" +#include "MergeCommand.h" #include #include @@ -13,10 +14,13 @@ int main(int argc, char **argv) { cli::App app{argv[0]}; + // all receive data from stdin app.insert("init", nncc::foundation::make_unique()); app.insert("fill", nncc::foundation::make_unique()); app.insert("encode", nncc::foundation::make_unique()); app.insert("decode", nncc::foundation::make_unique()); + // takes 2 args: prototxt model and caffemodel weights in that order + app.insert("merge", nncc::foundation::make_unique()); return app.run(argc - 1, argv + 1); } diff --git a/contrib/caffegen/src/MergeCommand.cpp b/contrib/caffegen/src/MergeCommand.cpp new file mode 100644 index 0000000..58cc026 --- /dev/null +++ b/contrib/caffegen/src/MergeCommand.cpp @@ -0,0 +1,42 @@ +#include "MergeCommand.h" + +#include +#include +#include + +#include +#include +#include + +#include +#include + +int MergeCommand::run(int argc, const char *const *argv) const +{ + if (argc != 2) { + std::cerr << "ERROR: this command requires exactly 2 arguments" << std::endl; + return 254; + } + + std::string model_file = argv[0]; + std::string trained_file = argv[1]; + + // Load the network + caffe::Net caffe_test_net(model_file, caffe::TEST); + // Load the weights + caffe_test_net.CopyTrainedLayersFrom(trained_file); + + caffe::NetParameter net_param; + caffe_test_net.ToProto(&net_param); + + // Write binary with initialized params into standard output + google::protobuf::io::OstreamOutputStream os(&std::cout); + google::protobuf::io::CodedOutputStream coded_os{&os}; + + if (!net_param.SerializeToCodedStream(&coded_os)) + { + std::cerr << "ERROR: Failed to serialize" << std::endl; + return 255; + } + return 0; +} diff --git a/contrib/caffegen/src/MergeCommand.h b/contrib/caffegen/src/MergeCommand.h new file mode 100644 index 0000000..9ded07b --- /dev/null +++ b/contrib/caffegen/src/MergeCommand.h @@ -0,0 +1,17 @@ +#ifndef __MERGE_COMMAND_H__ +#define __MERGE_COMMAND_H__ + +#include + +/** + * @brief Takes .prototxt and .caffemodel filenames from ARGV + * and fills the model with trained wights. + * The resulting binary model with weights to be consumed by nnc is printed to StdOut + * @returns error code + */ +struct MergeCommand final : public cli::Command +{ + int run(int argc, const char * const *argv) const override; +}; + +#endif //__MERGE_COMMAND_H__ -- 2.7.4