Python Multi-GPU
[platform/upstream/caffeonacl.git] / python / caffe / _caffe.cpp
index 0a86045..04dac23 100644 (file)
@@ -267,12 +267,12 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
 }
 
 template<typename Dtype>
-class PythonCallback: public Solver<Dtype>::Callback {
+class SolverCallback: public Solver<Dtype>::Callback {
  protected:
   bp::object on_start_, on_gradients_ready_;
 
  public:
-  PythonCallback(bp::object on_start, bp::object on_gradients_ready)
+  SolverCallback(bp::object on_start, bp::object on_gradients_ready)
     : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
   virtual void on_gradients_ready() {
     on_gradients_ready_();
@@ -284,9 +284,61 @@ class PythonCallback: public Solver<Dtype>::Callback {
 template<typename Dtype>
 void Solver_add_callback(Solver<Dtype> * solver, bp::object on_start,
   bp::object on_gradients_ready) {
-  solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
+  solver->add_callback(new SolverCallback<Dtype>(on_start, on_gradients_ready));
 }
 
+// Seems boost cannot call the base method directly
+void Solver_add_nccl(SGDSolver<Dtype>* solver
+#ifdef USE_NCCL
+  , NCCL<Dtype>* nccl
+#endif
+) {
+#ifdef USE_NCCL
+  solver->add_callback(nccl);
+#endif
+}
+
+template<typename Dtype>
+class NetCallback: public Net<Dtype>::Callback {
+ public:
+  explicit NetCallback(bp::object run) : run_(run) {}
+
+ protected:
+  virtual void run(int layer) {
+    run_(layer);
+  }
+  bp::object run_;
+};
+void Net_before_forward(Net<Dtype>* net, bp::object run) {
+  net->add_before_forward(new NetCallback<Dtype>(run));
+}
+void Net_after_forward(Net<Dtype>* net, bp::object run) {
+  net->add_after_forward(new NetCallback<Dtype>(run));
+}
+void Net_before_backward(Net<Dtype>* net, bp::object run) {
+  net->add_before_backward(new NetCallback<Dtype>(run));
+}
+void Net_after_backward(Net<Dtype>* net, bp::object run) {
+  net->add_after_backward(new NetCallback<Dtype>(run));
+}
+
+void Net_add_nccl(Net<Dtype>* net
+#ifdef USE_NCCL
+  , NCCL<Dtype>* nccl
+#endif
+) {
+#ifdef USE_NCCL
+  net->add_after_backward(nccl);
+#endif
+}
+#ifndef USE_NCCL
+template<typename Dtype>
+class NCCL {
+ public:
+  NCCL(shared_ptr<Solver<Dtype> > solver, const string& uid) {}
+};
+#endif
+
 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
 
 BOOST_PYTHON_MODULE(_caffe) {
@@ -303,6 +355,10 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::def("set_mode_gpu", &set_mode_gpu);
   bp::def("set_random_seed", &set_random_seed);
   bp::def("set_device", &Caffe::SetDevice);
+  bp::def("solver_count", &Caffe::solver_count);
+  bp::def("set_solver_count", &Caffe::set_solver_count);
+  bp::def("solver_rank", &Caffe::solver_rank);
+  bp::def("set_solver_rank", &Caffe::set_solver_rank);
 
   bp::def("layer_type_list", &LayerRegistry<Dtype>::LayerTypeList);
 
@@ -346,7 +402,12 @@ BOOST_PYTHON_MODULE(_caffe) {
         bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
     .def("save", &Net_Save)
     .def("save_hdf5", &Net_SaveHDF5)
-    .def("load_hdf5", &Net_LoadHDF5);
+    .def("load_hdf5", &Net_LoadHDF5)
+    .def("before_forward", &Net_before_forward)
+    .def("after_forward", &Net_after_forward)
+    .def("before_backward", &Net_before_backward)
+    .def("after_backward", &Net_after_backward)
+    .def("after_backward", &Net_add_nccl);
   BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>);
 
   bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
@@ -378,6 +439,10 @@ BOOST_PYTHON_MODULE(_caffe) {
     .add_property("type", bp::make_function(&Layer<Dtype>::type));
   BP_REGISTER_SHARED_PTR_TO_PYTHON(Layer<Dtype>);
 
+  bp::class_<SolverParameter>("SolverParameter", bp::no_init)
+    .add_property("max_iter", &SolverParameter::max_iter)
+    .add_property("display", &SolverParameter::display)
+    .add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce);
   bp::class_<LayerParameter>("LayerParameter", bp::no_init);
 
   bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
@@ -387,11 +452,14 @@ BOOST_PYTHON_MODULE(_caffe) {
           bp::return_internal_reference<>()))
     .add_property("iter", &Solver<Dtype>::iter)
     .def("add_callback", &Solver_add_callback<Dtype>)
+    .def("add_callback", &Solver_add_nccl)
     .def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
           &Solver<Dtype>::Solve), SolveOverloads())
     .def("step", &Solver<Dtype>::Step)
     .def("restore", &Solver<Dtype>::Restore)
-    .def("snapshot", &Solver<Dtype>::Snapshot);
+    .def("snapshot", &Solver<Dtype>::Snapshot)
+    .add_property("param", bp::make_function(&Solver<Dtype>::param,
+              bp::return_value_policy<bp::copy_const_reference>()));
   BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);
 
   bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
@@ -435,6 +503,24 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::class_<vector<bool> >("BoolVec")
     .def(bp::vector_indexing_suite<vector<bool> >());
 
+  bp::class_<NCCL<Dtype>, shared_ptr<NCCL<Dtype> >,
+    boost::noncopyable>("NCCL",
+                        bp::init<shared_ptr<Solver<Dtype> >, const string&>())
+#ifdef USE_NCCL
+    .def("new_uid", &NCCL<Dtype>::new_uid).staticmethod("new_uid")
+    .def("bcast", &NCCL<Dtype>::Broadcast)
+#endif
+    /* NOLINT_NEXT_LINE(whitespace/semicolon) */
+  ;
+  BP_REGISTER_SHARED_PTR_TO_PYTHON(NCCL<Dtype>);
+
+  bp::class_<Timer, shared_ptr<Timer>, boost::noncopyable>(
+    "Timer", bp::init<>())
+    .def("start", &Timer::Start)
+    .def("stop", &Timer::Stop)
+    .add_property("ms", &Timer::MilliSeconds);
+  BP_REGISTER_SHARED_PTR_TO_PYTHON(Timer);
+
   // boost python expects a void (missing) return value, while import_array
   // returns NULL for python3. import_array1() forces a void return value.
   import_array1();