pycaffe: expose SGDSolver.solve
authorJonathan L Long <jonlong@cs.berkeley.edu>
Wed, 2 Apr 2014 21:40:09 +0000 (14:40 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Sat, 5 Apr 2014 06:39:56 +0000 (23:39 -0700)
python/caffe/_caffe.cpp

index 9949430..1a0974f 100644 (file)
@@ -293,6 +293,11 @@ class CaffeSGDSolver {
   }
 
   CaffeNet net() { return CaffeNet(solver_->net()); }
+  void Solve() { return solver_->Solve(); }
+  void SolveResume(const string& resume_file) {
+    CheckFile(resume_file);
+    return solver_->Solve(resume_file);
+  }
 
  protected:
   shared_ptr<SGDSolver<float> > solver_;
@@ -333,7 +338,9 @@ BOOST_PYTHON_MODULE(_caffe) {
 
   boost::python::class_<CaffeSGDSolver, boost::noncopyable>(
       "SGDSolver", boost::python::init<string>())
-      .add_property("net", &CaffeSGDSolver::net);
+      .add_property("net", &CaffeSGDSolver::net)
+      .def("solve",        &CaffeSGDSolver::Solve)
+      .def("solve",        &CaffeSGDSolver::SolveResume);
 
   boost::python::class_<vector<CaffeBlob> >("BlobVec")
       .def(vector_indexing_suite<vector<CaffeBlob>, true>());