solver minor change
authorYangqing Jia <jiayq84@gmail.com>
Mon, 4 Nov 2013 19:13:19 +0000 (11:13 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 4 Nov 2013 19:13:19 +0000 (11:13 -0800)
include/caffe/solver.hpp
src/caffe/solver.cpp

index 168f4b4..2b40242 100644 (file)
@@ -13,7 +13,7 @@ class Solver {
   explicit Solver(const SolverParameter& param);
   // The main entry of the solver function. In default, iter will be zero. Pass
   // in a non-zero iter number to resume training for a pre-trained net.
-  void Solve(const char* resume_file = NULL);
+  virtual void Solve(const char* resume_file = NULL);
   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
   virtual ~Solver() {}
 
@@ -38,8 +38,8 @@ class Solver {
   virtual void RestoreSolverState(const SolverState& state) = 0;
   SolverParameter param_;
   int iter_;
-  Net<Dtype>* net_;
-  Net<Dtype>* test_net_;
+  shared_ptr<Net<Dtype> > net_;
+  shared_ptr<Net<Dtype> > test_net_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };
index 07b46a4..3562960 100644 (file)
@@ -19,7 +19,7 @@ namespace caffe {
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const SolverParameter& param)
-    : param_(param), net_(NULL), test_net_(NULL) {
+    : param_(param), net_(), test_net_() {
   // Scaffolding code
   NetParameter train_net_param;
   ReadProtoFromTextFile(param_.train_net(), &train_net_param);
@@ -27,12 +27,12 @@ Solver<Dtype>::Solver(const SolverParameter& param)
   // a dummy bottom_vec instance to initialize the networks.
   vector<Blob<Dtype>*> bottom_vec;
   LOG(INFO) << "Creating training net.";
-  net_ = new Net<Dtype>(train_net_param, bottom_vec);
+  net_.reset(new Net<Dtype>(train_net_param, bottom_vec));
   if (param_.has_test_net()) {
     LOG(INFO) << "Creating testing net.";
     NetParameter test_net_param;
     ReadProtoFromTextFile(param_.test_net(), &test_net_param);
-    test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
+    test_net_.reset(new Net<Dtype>(test_net_param, bottom_vec));
     CHECK_GT(param_.test_iter(), 0);
     CHECK_GT(param_.test_interval(), 0);
   }
@@ -83,7 +83,7 @@ void Solver<Dtype>::Test() {
   LOG(INFO) << "Testing net";
   NetParameter net_param;
   net_->ToProto(&net_param);
-  CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
+  CHECK_NOTNULL(test_net_.get())->CopyTrainedLayersFrom(net_param);
   vector<Dtype> test_score;
   vector<Blob<Dtype>*> bottom_vec;
   for (int i = 0; i < param_.test_iter(); ++i) {
@@ -138,8 +138,10 @@ void Solver<Dtype>::Restore(const char* state_file) {
   SolverState state;
   NetParameter net_param;
   ReadProtoFromBinaryFile(state_file, &state);
-  ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
-  net_->CopyTrainedLayersFrom(net_param);
+  if (state.has_learned_net()) {
+    ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+    net_->CopyTrainedLayersFrom(net_param);
+  }
   iter_ = state.iter();
   RestoreSolverState(state);
 }