return bp::object();
}
+template<typename Dtype>
+class PythonCallback: public Solver<Dtype>::Callback {
+ protected:
+ bp::object on_start_, on_gradients_ready_;
+
+ public:
+ PythonCallback(bp::object on_start, bp::object on_gradients_ready)
+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
+ virtual void on_gradients_ready() {
+ on_gradients_ready_();
+ }
+ virtual void on_start() {
+ on_start_();
+ }
+};
+template<typename Dtype>
+void Solver_add_callback(Solver<Dtype> * solver, bp::object on_start,
+ bp::object on_gradients_ready) {
+ solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
+}
+
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
BOOST_PYTHON_MODULE(_caffe) {
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
bp::return_internal_reference<>()))
.add_property("iter", &Solver<Dtype>::iter)
+ .def("add_callback", &Solver_add_callback<Dtype>)
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
&Solver<Dtype>::Solve), SolveOverloads())
.def("step", &Solver<Dtype>::Step)