split p2psync::run()
authorJun Shi <junshi@yahoo-inc.com>
Fri, 22 Jan 2016 17:58:37 +0000 (09:58 -0800)
committerJun Shi <junshi@yahoo-inc.com>
Sat, 5 Mar 2016 15:07:21 +0000 (07:07 -0800)
include/caffe/parallel.hpp
src/caffe/parallel.cpp
src/caffe/test/test_gradient_based_solver.cpp
tools/caffe.cpp

index 85fc2b5..6c496c8 100644 (file)
@@ -93,7 +93,10 @@ class P2PSync : public GPUParams<Dtype>, public Solver<Dtype>::Callback,
     return solver_;
   }
 
-  void run(const vector<int>& gpus);
+  void Run(const vector<int>& gpus);
+  void Prepare(const vector<int>& gpus,
+               vector<shared_ptr<P2PSync<Dtype> > >* syncs);
+  inline const int initial_iter() const { return initial_iter_; }
 
  protected:
   void on_start();
index 62f5d73..5bc41c6 100644 (file)
@@ -380,7 +380,8 @@ void P2PSync<Dtype>::on_gradients_ready() {
 }
 
 template<typename Dtype>
-void P2PSync<Dtype>::run(const vector<int>& gpus) {
+void P2PSync<Dtype>::Prepare(const vector<int>& gpus,
+            vector<shared_ptr<P2PSync<Dtype> > >* syncs) {
   // Pair devices for map-reduce synchronization
   vector<DevicePair> pairs;
   DevicePair::compute(gpus, &pairs);
@@ -391,15 +392,14 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
   LOG(INFO)<< "GPUs pairs " << s.str();
 
   SolverParameter param(solver_->param());
-  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
 
   // Build the GPU tree by finding the parent for each solver
   for (int attempts = 0; attempts < pairs.size(); ++attempts) {
     for (int i = 1; i < pairs.size(); ++i) {
-      if (!syncs[i].get()) {
+      if (!syncs->at(i).get()) {
         P2PSync<Dtype>* parent = NULL;
-        for (int j = 0; j < syncs.size(); ++j) {
-          P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get();
+        for (int j = 0; j < syncs->size(); ++j) {
+          P2PSync<Dtype>* sync = j == 0 ? this : syncs->at(j).get();
           if (sync) {
             const SolverParameter& p = sync->solver()->param();
             if (p.device_id() == pairs[i].parent()) {
@@ -409,12 +409,18 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
         }
         if (parent) {
           param.set_device_id(pairs[i].device());
-          syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param));
-          parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get());
+          syncs->at(i).reset(new P2PSync<Dtype>(solver_, parent, param));
+          parent->children_.push_back((P2PSync<Dtype>*) syncs->at(i).get());
         }
       }
     }
   }
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::Run(const vector<int>& gpus) {
+  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
+  Prepare(gpus, &syncs);
 
   LOG(INFO)<< "Starting Optimization";
 
index 09ec3a7..975a8f0 100644 (file)
@@ -204,7 +204,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       Caffe::set_solver_count(gpus.size());
       this->sync_.reset(new P2PSync<Dtype>(
           this->solver_, NULL, this->solver_->param()));
-      this->sync_->run(gpus);
+      this->sync_->Run(gpus);
       Caffe::set_solver_count(1);
     }
     if (snapshot) {
index 95b2f82..5d9331f 100644 (file)
@@ -214,7 +214,7 @@ int train() {
 
   if (gpus.size() > 1) {
     caffe::P2PSync<float> sync(solver, NULL, solver->param());
-    sync.run(gpus);
+    sync.Run(gpus);
   } else {
     LOG(INFO) << "Starting Optimization";
     solver->Solve();