pycaffe: store a shared_ptr<CaffeNet> in SGDSolver
authorJonathan L Long <jonlong@cs.berkeley.edu>
Fri, 25 Apr 2014 21:29:50 +0000 (14:29 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Fri, 2 May 2014 20:25:51 +0000 (13:25 -0700)
Doing this, rather than constructing the CaffeNet wrapper every time,
will allow the wrapper to hold references that last at least as long as
SGDSolver (which will be necessary to ensure that data used by
MemoryDataLayer doesn't get freed).

python/caffe/_caffe.cpp

index 1a44deb..853ddbe 100644 (file)
@@ -301,9 +301,12 @@ class CaffeSGDSolver {
     // exception if param_file can't be opened
     CheckFile(param_file);
     solver_.reset(new SGDSolver<float>(param_file));
+    // we need to explicitly store the net wrapper, rather than constructing
+    // it on the fly, so that it can hold references to Python objects
+    net_.reset(new CaffeNet(solver_->net()));
   }
 
-  CaffeNet net() { return CaffeNet(solver_->net()); }
+  shared_ptr<CaffeNet> net() { return net_; }
   void Solve() { return solver_->Solve(); }
   void SolveResume(const string& resume_file) {
     CheckFile(resume_file);
@@ -311,6 +314,7 @@ class CaffeSGDSolver {
   }
 
  protected:
+  shared_ptr<CaffeNet> net_;
   shared_ptr<SGDSolver<float> > solver_;
 };