Exposing solver callbacks to python
authorphilkr <philkr@users.noreply.github.com>
Thu, 3 Sep 2015 21:28:55 +0000 (14:28 -0700)
committerphilkr <philkr@users.noreply.github.com>
Fri, 17 Jun 2016 21:37:24 +0000 (14:37 -0700)
python/caffe/_caffe.cpp

index 48a0c8f..334088e 100644 (file)
@@ -228,6 +228,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
   return bp::object();
 }
 
+template<typename Dtype>
+class PythonCallback: public Solver<Dtype>::Callback {
+ protected:
+  bp::object on_start_, on_gradients_ready_;
+
+ public:
+  PythonCallback(bp::object on_start, bp::object on_gradients_ready)
+    : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
+  virtual void on_gradients_ready() {
+    on_gradients_ready_();
+  }
+  virtual void on_start() {
+    on_start_();
+  }
+};
+template<typename Dtype>
+void Solver_add_callback(Solver<Dtype> * solver, bp::object on_start,
+  bp::object on_gradients_ready) {
+  solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
+}
+
 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
 
 BOOST_PYTHON_MODULE(_caffe) {
@@ -317,6 +338,7 @@ BOOST_PYTHON_MODULE(_caffe) {
     .add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
           bp::return_internal_reference<>()))
     .add_property("iter", &Solver<Dtype>::iter)
+    .def("add_callback", &Solver_add_callback<Dtype>)
     .def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
           &Solver<Dtype>::Solve), SolveOverloads())
     .def("step", &Solver<Dtype>::Step)