PyArray_DIMS(data_arr)[0]);
}
+Solver<Dtype>* GetSolverFromFile(const string& filename) {
+ SolverParameter param;
+ ReadProtoFromTextFileOrDie(filename, ¶m);
+ return GetSolver<Dtype>(param);
+}
+
struct NdarrayConverterGenerator {
template <typename T> struct apply;
};
}
};
+BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
+
BOOST_PYTHON_MODULE(_caffe) {
// below, we prepend an underscore to methods that will be replaced
// in Python
.add_property("type_name", bp::make_function(&Layer<Dtype>::type_name,
bp::return_value_policy<bp::copy_const_reference>()));
+ bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
+ "Solver", bp::no_init)
+ .add_property("net", &Solver<Dtype>::net)
+ .add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
+ bp::return_internal_reference<>()))
+ .add_property("iter", &Solver<Dtype>::iter)
+ .def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
+ &Solver<Dtype>::Solve), SolveOverloads())
+ .def("step", &Solver<Dtype>::Step);
+
+ bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
+ shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
+ "SGDSolver", bp::init<string>());
+ bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
+ shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
+ "NesterovSolver", bp::init<string>());
+ bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
+ shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
+ "AdaGradSolver", bp::init<string>());
+
+ bp::def("get_solver", &GetSolverFromFile,
+ bp::return_value_policy<bp::manage_new_object>());
+
// vector wrappers for all the vector types we use
bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
.def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
.def(bp::vector_indexing_suite<vector<string> >());
bp::class_<vector<int> >("IntVec")
.def(bp::vector_indexing_suite<vector<int> >());
+ bp::class_<vector<shared_ptr<Net<Dtype> > > >("NetVec")
+ .def(bp::vector_indexing_suite<vector<shared_ptr<Net<Dtype> > >, true>());
import_array();
}