Weight parameter in solver is used in caffe.exe
authoriovodov <b@ovdv.ru>
Mon, 18 Dec 2017 19:39:03 +0000 (22:39 +0300)
committeriovodov <b@ovdv.ru>
Sat, 10 Feb 2018 18:51:51 +0000 (21:51 +0300)
Loading weights is moved from caffe.exe to the solver class, so the new "weights" solver parameter can be used not only from the command line but also when caffe is used as a library (including Python)

corrected formatting

fixed line length

more formatting corrected

src/caffe/proto/caffe.proto
src/caffe/solver.cpp
src/caffe/test/test_upgrade_proto.cpp
tools/caffe.cpp

index c96966b..4567861 100644 (file)
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
+// SolverParameter next available ID: 43 (last added: weights)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -241,6 +241,16 @@ message SolverParameter {
 
   // Overlap compute and communication for data parallel training
   optional bool layer_wise_reduce = 41 [default = true];
+
+  // Path to caffemodel file(s) with pretrained weights to initialize finetuning.
+  // The same as the command line --weights parameter for the caffe train command.
+  // If the command line --weights parameter is specified, it has higher priority
+  // and overwrites this one(s).
+  // If the --snapshot command line parameter is specified, this one(s) are ignored.
+  // If several model files are expected, they can be listed in one
+  // weights parameter separated by ',' (like in a command string) or
+  // in repeated weights parameters separately.
+  repeated string weights = 42;
 }
 
 // A message that stores the solver snapshots
index 0442693..d229acf 100644 (file)
@@ -3,6 +3,7 @@
 #include <string>
 #include <vector>
 
+#include "boost/algorithm/string.hpp"
 #include "caffe/solver.hpp"
 #include "caffe/util/format.hpp"
 #include "caffe/util/hdf5.hpp"
@@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
   current_step_ = 0;
 }
 
+// Load weights from the caffemodel(s) specified in "weights" solver parameter
+// into the train and test nets.
+template <typename Dtype>
+void LoadNetWeights(shared_ptr<Net<Dtype> > net,
+    const std::string& model_list) {
+  std::vector<std::string> model_names;
+  boost::split(model_names, model_list, boost::is_any_of(","));
+  for (int i = 0; i < model_names.size(); ++i) {
+    boost::trim(model_names[i]);
+    LOG(INFO) << "Finetuning from " << model_names[i];
+    net->CopyTrainedLayersFrom(model_names[i]);
+  }
+}
+
 template <typename Dtype>
 void Solver<Dtype>::InitTrainNet() {
   const int num_train_nets = param_.has_net() + param_.has_net_param() +
@@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() {
   net_state.MergeFrom(param_.train_state());
   net_param.mutable_state()->CopyFrom(net_state);
   net_.reset(new Net<Dtype>(net_param));
+  for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
+    LoadNetWeights(net_, param_.weights(w_idx));
+  }
 }
 
 template <typename Dtype>
@@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() {
         << "Creating test net (#" << i << ") specified by " << sources[i];
     test_nets_[i].reset(new Net<Dtype>(net_params[i]));
     test_nets_[i]->set_debug_info(param_.debug_info());
+    for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
+      LoadNetWeights(test_nets_[i], param_.weights(w_idx));
+    }
   }
 }
 
index 9dcc2aa..769112e 100644 (file)
@@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
   for (int i = 0; i < 6; ++i) {
     const string& input_proto =
         "net: 'examples/mnist/lenet_train_test.prototxt' "
+        "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
+        "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
         "test_iter: 100 "
         "test_interval: 500 "
         "base_lr: 0.01 "
@@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
         "solver_type: " + std::string(old_type_vec[i]) + " ";
     const string& expected_output_proto =
         "net: 'examples/mnist/lenet_train_test.prototxt' "
+        "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
+        "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
         "test_iter: 100 "
         "test_interval: 500 "
         "base_lr: 0.01 "
index 3587d8a..389cfb8 100644 (file)
@@ -146,20 +146,6 @@ int device_query() {
 }
 RegisterBrewFunction(device_query);
 
-// Load the weights from the specified caffemodel(s) into the train and
-// test nets.
-void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
-  std::vector<std::string> model_names;
-  boost::split(model_names, model_list, boost::is_any_of(",") );
-  for (int i = 0; i < model_names.size(); ++i) {
-    LOG(INFO) << "Finetuning from " << model_names[i];
-    solver->net()->CopyTrainedLayersFrom(model_names[i]);
-    for (int j = 0; j < solver->test_nets().size(); ++j) {
-      solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
-    }
-  }
-}
-
 // Translate the signal effect the user specified on the command-line to the
 // corresponding enumeration.
 caffe::SolverAction::Enum GetRequestedAction(
@@ -233,6 +219,13 @@ int train() {
         GetRequestedAction(FLAGS_sigint_effect),
         GetRequestedAction(FLAGS_sighup_effect));
 
+  if (FLAGS_snapshot.size()) {
+    solver_param.clear_weights();
+  } else if (FLAGS_weights.size()) {
+    solver_param.clear_weights();
+    solver_param.add_weights(FLAGS_weights);
+  }
+
   shared_ptr<caffe::Solver<float> >
       solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
 
@@ -241,8 +234,6 @@ int train() {
   if (FLAGS_snapshot.size()) {
     LOG(INFO) << "Resuming from " << FLAGS_snapshot;
     solver->Restore(FLAGS_snapshot.c_str());
-  } else if (FLAGS_weights.size()) {
-    CopyLayers(solver.get(), FLAGS_weights);
   }
 
   LOG(INFO) << "Starting Optimization";