Loading weights is moved from caffe.exe to the solver class, so the new "weights" solver parameter can be used not only from the command line but also when caffe is used as a library (including Python)
corrected formatting
fixed line length
more formatting corrected
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
+// SolverParameter next available ID: 43 (last added: weights)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [default = true];
+
+ // Path to caffemodel file(s) with pretrained weights to initialize
+ // finetuning. The same as the command-line --weights parameter for the
+ // caffe train command. If the command-line --weights parameter is
+ // specified, it has higher priority and overwrites this one(s).
+ // If the --snapshot command-line parameter is specified, this one(s)
+ // is ignored. If several model files are expected, they can be listed
+ // in a single weights parameter separated by ',' (like in a command
+ // string) or in repeated weights parameters separately.
+ repeated string weights = 42;
}
// A message that stores the solver snapshots
#include <string>
#include <vector>
+#include "boost/algorithm/string.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
current_step_ = 0;
}
+// Load weights from the caffemodel(s) specified in "weights" solver parameter
+// into the train and test nets.
+// model_list may name several files separated by ','; each name is trimmed
+// of surrounding whitespace before loading.
+// NOTE(review): an empty entry (e.g. from a trailing comma) is passed
+// through to CopyTrainedLayersFrom unchecked -- confirm callers never
+// produce one.
+template <typename Dtype>
+void LoadNetWeights(shared_ptr<Net<Dtype> > net,
+    const std::string& model_list) {
+  std::vector<std::string> model_names;
+  // Split on ',' so a single parameter can carry multiple model files,
+  // mirroring the command-line --weights syntax.
+  boost::split(model_names, model_list, boost::is_any_of(","));
+  for (int i = 0; i < model_names.size(); ++i) {
+    boost::trim(model_names[i]);
+    LOG(INFO) << "Finetuning from " << model_names[i];
+    net->CopyTrainedLayersFrom(model_names[i]);
+  }
+}
+
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
net_.reset(new Net<Dtype>(net_param));
+ for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
+ LoadNetWeights(net_, param_.weights(w_idx));
+ }
}
template <typename Dtype>
<< "Creating test net (#" << i << ") specified by " << sources[i];
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
+ for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
+ LoadNetWeights(test_nets_[i], param_.weights(w_idx));
+ }
}
}
for (int i = 0; i < 6; ++i) {
const string& input_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
+ "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
+ "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
"solver_type: " + std::string(old_type_vec[i]) + " ";
const string& expected_output_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
+ "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
+ "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
}
RegisterBrewFunction(device_query);
-// Load the weights from the specified caffemodel(s) into the train and
-// test nets.
-void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
- std::vector<std::string> model_names;
- boost::split(model_names, model_list, boost::is_any_of(",") );
- for (int i = 0; i < model_names.size(); ++i) {
- LOG(INFO) << "Finetuning from " << model_names[i];
- solver->net()->CopyTrainedLayersFrom(model_names[i]);
- for (int j = 0; j < solver->test_nets().size(); ++j) {
- solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
- }
- }
-}
-
// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));
+ if (FLAGS_snapshot.size()) {
+ solver_param.clear_weights();
+ } else if (FLAGS_weights.size()) {
+ solver_param.clear_weights();
+ solver_param.add_weights(FLAGS_weights);
+ }
+
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
- } else if (FLAGS_weights.size()) {
- CopyLayers(solver.get(), FLAGS_weights);
}
LOG(INFO) << "Starting Optimization";