[pycaffe] re-expose SGDSolver, and expose other solvers
author Jonathan L Long <jonlong@cs.berkeley.edu>
Tue, 6 Jan 2015 05:53:51 +0000 (21:53 -0800)
committer Jonathan L Long <jonlong@cs.berkeley.edu>
Tue, 17 Feb 2015 06:46:13 +0000 (22:46 -0800)
python/caffe/__init__.py
python/caffe/_caffe.cpp

index 9f0e12c..49f8678 100644 (file)
@@ -1,6 +1,6 @@
 from .pycaffe import Net, SGDSolver
 from ._caffe import set_mode_cpu, set_mode_gpu, set_device, \
-    set_phase_train, set_phase_test, Layer
+    set_phase_train, set_phase_test, Layer, get_solver
 from .classifier import Classifier
 from .detector import Detector
 import io
index 135178f..30e5c72 100644 (file)
@@ -119,6 +119,12 @@ void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
       PyArray_DIMS(data_arr)[0]);
 }
 
+Solver<Dtype>* GetSolverFromFile(const string& filename) {  // get_solver impl
+  SolverParameter param;
+  ReadProtoFromTextFileOrDie(filename, &param);  // reads filename into param
+  return GetSolver<Dtype>(param);  // raw ptr; get_solver uses manage_new_object
+}
+
 struct NdarrayConverterGenerator {
   template <typename T> struct apply;
 };
@@ -157,6 +163,8 @@ struct NdarrayCallPolicies : public bp::default_call_policies {
   }
 };
 
+BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
+
 BOOST_PYTHON_MODULE(_caffe) {
   // below, we prepend an underscore to methods that will be replaced
   // in Python
@@ -216,6 +224,29 @@ BOOST_PYTHON_MODULE(_caffe) {
     .add_property("type_name", bp::make_function(&Layer<Dtype>::type_name,
         bp::return_value_policy<bp::copy_const_reference>()));
 
+  bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
+    "Solver", bp::no_init)
+    .add_property("net", &Solver<Dtype>::net)
+    .add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
+          bp::return_internal_reference<>()))
+    .add_property("iter", &Solver<Dtype>::iter)
+    .def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
+          &Solver<Dtype>::Solve), SolveOverloads())
+    .def("step", &Solver<Dtype>::Step);
+
+  bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
+    shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
+        "SGDSolver", bp::init<string>());
+  bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
+    shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
+        "NesterovSolver", bp::init<string>());
+  bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
+    shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
+        "AdaGradSolver", bp::init<string>());
+
+  bp::def("get_solver", &GetSolverFromFile,
+      bp::return_value_policy<bp::manage_new_object>());
+
   // vector wrappers for all the vector types we use
   bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
     .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
@@ -225,6 +256,8 @@ BOOST_PYTHON_MODULE(_caffe) {
     .def(bp::vector_indexing_suite<vector<string> >());
   bp::class_<vector<int> >("IntVec")
     .def(bp::vector_indexing_suite<vector<int> >());
+  bp::class_<vector<shared_ptr<Net<Dtype> > > >("NetVec")
+    .def(bp::vector_indexing_suite<vector<shared_ptr<Net<Dtype> > >, true>());
 
   import_array();
 }