distributed solver still having bugs. Pausing for now...
author     Yangqing Jia <jiayq84@gmail.com>
           Thu, 7 Nov 2013 19:06:20 +0000 (11:06 -0800)
committer  Yangqing Jia <jiayq84@gmail.com>
           Thu, 7 Nov 2013 19:06:20 +0000 (11:06 -0800)
include/caffe/distributed_solver.hpp
src/caffe/common.cpp
src/caffe/distributed_solver.cpp

diff --git a/include/caffe/distributed_solver.hpp b/include/caffe/distributed_solver.hpp
index 37f3f08..e840ff5 100644
--- a/include/caffe/distributed_solver.hpp
+++ b/include/caffe/distributed_solver.hpp
@@ -12,6 +12,8 @@
 
 #include "caffe/solver.hpp"
 
+using boost::asio::ip::tcp;
+
 
 namespace caffe {
 
@@ -29,7 +31,8 @@ class DistributedSolverParamServer : public Solver<Dtype> {
   virtual void SnapshotSolverState(SolverState* state) {}
   virtual void RestoreSolverState(const SolverState& state) {}
   // The function that implements the actual communication.
-  void ReceiveAndSend(boost::asio::io_service& io_s);
+  void ReceiveAndSend(boost::asio::io_service& io_s,
+      tcp::acceptor& data_acceptor);
 
   int next_snapshot_;
 };
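
The signature change above threads a shared tcp::acceptor into ReceiveAndSend, so the listening port is bound once by the caller and reused for every incoming connection rather than re-bound on each call. A minimal standalone sketch of that pattern with Boost.Asio (the port 7000 and the loop count are placeholders, not values from this commit):

    #include <boost/asio.hpp>
    #include <iostream>

    using boost::asio::ip::tcp;

    int main() {
      boost::asio::io_service io_s;
      // Bind and listen once, outside the serving loop.
      tcp::acceptor acceptor(io_s, tcp::endpoint(tcp::v4(), 7000));
      for (int i = 0; i < 3; ++i) {
        tcp::socket socket(io_s);
        acceptor.accept(socket);  // blocks until a client connects
        std::cout << "accepted connection " << i << std::endl;
        // the socket closes when it goes out of scope at the end of the iteration
      }
      return 0;
    }
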
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index a70a280..54ab476 100644
--- a/src/caffe/common.cpp
+++ b/src/caffe/common.cpp
@@ -20,6 +20,16 @@ inline bool StillFresh() {
   return (difftime(time(NULL), mktime(&fresh_time)) < 0);
 }
 
+
+long cluster_seedgen(void) {
+  long s, seed, pid;
+  pid = getpid();
+  s = time(NULL);
+  seed = abs(((s * 181) * ((pid - 83) * 359)) % 104729);
+  return seed;
+}
+
+
 Caffe::Caffe()
     : mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
       curand_generator_(NULL), vsl_stream_(NULL) {
@@ -36,13 +46,13 @@ Caffe::Caffe()
   // Try to create a curand handler.
   if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)
       != CURAND_STATUS_SUCCESS ||
-      curandSetPseudoRandomGeneratorSeed(curand_generator_, time(NULL))
+      curandSetPseudoRandomGeneratorSeed(curand_generator_, cluster_seedgen())
       != CURAND_STATUS_SUCCESS) {
     LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
   }
   // Try to create a vsl stream. This should almost always work, but we will
   // check it anyway.
-  if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, time(NULL)) != VSL_STATUS_OK) {
+  if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
     LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
         << "won't be available.";
   }
@@ -89,7 +99,7 @@ void Caffe::SetDevice(const int device_id) {
   CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
       CURAND_RNG_PSEUDO_DEFAULT));
   CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
-      time(NULL)));
+      cluster_seedgen()));
 }
 
 void Caffe::DeviceQuery() {
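
In common.cpp, cluster_seedgen() replaces time(NULL) as the seed for the cuRAND and VSL generators: workers launched on a cluster within the same second would otherwise draw identical random streams, and mixing in the process id separates them. A rough standalone illustration of the idea (it repeats the hash from the hunk above, with std::labs spelled out for the long argument):

    #include <cstdio>
    #include <cstdlib>
    #include <ctime>
    #include <unistd.h>

    // Same mixing scheme as the cluster_seedgen() added above: combine the
    // wall-clock second with the process id and reduce modulo a prime.
    long cluster_seedgen(void) {
      long s = time(NULL);
      long pid = getpid();
      return std::labs(((s * 181) * ((pid - 83) * 359)) % 104729);
    }

    int main() {
      // Two processes started in the same second share time(NULL) but not the seed.
      std::printf("time-only seed: %ld\n", static_cast<long>(time(NULL)));
      std::printf("pid-mixed seed: %ld\n", cluster_seedgen());
      return 0;
    }
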
diff --git a/src/caffe/distributed_solver.cpp b/src/caffe/distributed_solver.cpp
index 6596838..c3bd826 100644
--- a/src/caffe/distributed_solver.cpp
+++ b/src/caffe/distributed_solver.cpp
@@ -33,8 +33,10 @@ void DistributedSolverParamServer<Dtype>::Solve(const char* resume_file) {
   // the main loop.
   LOG(INFO) << "Waiting for incoming updates...";
   boost::asio::io_service io_s;
+  tcp::acceptor data_acceptor(
+      io_s, tcp::endpoint(tcp::v4(), atoi(this->param_.tcp_port().c_str())));
   while (this->iter_ < this->param_.max_iter()) {
-    ReceiveAndSend(io_s);
+    ReceiveAndSend(io_s, data_acceptor);
     // Check if we need to do snapshot
     if (this->param_.snapshot() && this->iter_ > next_snapshot_) {
       Solver<Dtype>::Snapshot();
@@ -52,58 +54,55 @@ void DistributedSolverParamServer<Dtype>::Solve(const char* resume_file) {
 // client.
 template <typename Dtype>
 void DistributedSolverParamServer<Dtype>::ReceiveAndSend(
-    boost::asio::io_service& io_s) {
+    boost::asio::io_service& io_s, tcp::acceptor& data_acceptor) {
   bool send_only;
   int incoming_iter;
 
-  tcp::acceptor data_acceptor(
-      io_s, tcp::endpoint(tcp::v4(), atoi(this->param_.tcp_port().c_str())));
-  tcp::iostream data_stream;
-  data_acceptor.accept(*(data_stream.rdbuf()));
+  tcp::socket socket(io_s);
+  data_acceptor.accept(socket);
   LOG(INFO) << "Incoming connection.";
-  data_stream.read(reinterpret_cast<char*>(&send_only), sizeof(send_only));
+  boost::asio::read(socket,
+      boost::asio::buffer(reinterpret_cast<void*>(&send_only), sizeof(send_only)));
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   if (!send_only) {
     // Receive data
     LOG(INFO) << "Receiving data.";
-    data_stream.read(reinterpret_cast<char*>(&incoming_iter),
-        sizeof(incoming_iter));
+    boost::asio::read(socket,
+        boost::asio::buffer(reinterpret_cast<void*>(&incoming_iter), sizeof(incoming_iter)));
     int total_received = 0;
     LOG(INFO) << "Incoming iterations: " << incoming_iter;
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       Dtype* param_diff = net_params[param_id]->mutable_cpu_diff();
       int count = net_params[param_id]->count();
-      data_stream.read(reinterpret_cast<char*>(param_diff),
-          count * sizeof(Dtype));
+      boost::asio::read(socket,
+          boost::asio::buffer(reinterpret_cast<void*>(param_diff),
+                count * sizeof(Dtype)));
       total_received += count;
     }
     LOG(INFO) << "Received " << total_received << " variables.";
     // Check error: if there are any error in the receiving phase, we will not
     // trust the passed in update.
-    if (data_stream.error()) {
-      LOG(ERROR) << "Error in receiving. Error code: " << data_stream.error().message();
-    } else {
-      // If the read is successful, update the network.
-      this->iter_ += incoming_iter;
-      this->net_->Update();
-    }
+    this->iter_ += incoming_iter;
+    this->net_->Update();
   } else {
     LOG(INFO) << "No incoming updates. Will simply send data.";
   }
   // Send data
   LOG(INFO) << "Sending data";
-  data_stream.write(reinterpret_cast<char*>(&(this->iter_)), sizeof(this->iter_));
+  boost::asio::write(socket,
+      boost::asio::buffer(reinterpret_cast<char*>(&(this->iter_)),
+          sizeof(this->iter_)));
   LOG(INFO) << "Current iteration: " << this->iter_;
   int total_sent = 0;
   for (int param_id = 0; param_id < net_params.size(); ++param_id) {
     const Dtype* param_data = net_params[param_id]->cpu_data();
     int count = net_params[param_id]->count();
-    data_stream.write(reinterpret_cast<const char*>(param_data),
-        sizeof(Dtype) * count);
+    boost::asio::write(socket,
+        boost::asio::buffer(reinterpret_cast<const char*>(param_data),
+            sizeof(Dtype) * count));
     total_sent += count;
   }
   LOG(INFO) << "Sent " << total_sent << " variables.";
-  data_stream.flush();
 }
 
 
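In the hunk above, the server drops tcp::iostream in favor of a raw tcp::socket with boost::asio::read and boost::asio::write, which block until exactly the requested number of bytes have been transferred and throw on failure; note that the explicit data_stream.error() check is gone, so the incoming update is now applied unconditionally. A toy loopback sketch of the same raw-buffer framing, one int count followed by that many floats (port 7001 and the payload are placeholders, not the solver's actual wire format, which sends a bool flag, an iteration count, and every parameter blob in turn):

    #include <boost/asio.hpp>
    #include <iostream>
    #include <thread>
    #include <vector>

    using boost::asio::ip::tcp;

    // Write a length header followed by the raw float payload.
    void send_blob(tcp::socket& socket, const std::vector<float>& data) {
      int count = static_cast<int>(data.size());
      boost::asio::write(socket, boost::asio::buffer(&count, sizeof(count)));
      boost::asio::write(socket,
          boost::asio::buffer(data.data(), count * sizeof(float)));
    }

    // Read the header, then exactly count * sizeof(float) bytes of payload.
    std::vector<float> recv_blob(tcp::socket& socket) {
      int count = 0;
      boost::asio::read(socket, boost::asio::buffer(&count, sizeof(count)));
      std::vector<float> data(count);
      boost::asio::read(socket,
          boost::asio::buffer(data.data(), count * sizeof(float)));
      return data;
    }

    int main() {
      boost::asio::io_service io_s;
      tcp::acceptor acceptor(io_s, tcp::endpoint(tcp::v4(), 7001));

      // Client: connect over loopback in a background thread and send a blob.
      std::thread client([&io_s]() {
        tcp::socket socket(io_s);
        socket.connect(tcp::endpoint(boost::asio::ip::address_v4::loopback(), 7001));
        send_blob(socket, std::vector<float>(4, 0.5f));
      });

      // Server: accept the connection and read the blob back.
      tcp::socket socket(io_s);
      acceptor.accept(socket);
      std::vector<float> received = recv_blob(socket);
      std::cout << "received " << received.size() << " floats" << std::endl;
      client.join();
      return 0;
    }
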
@@ -140,52 +139,60 @@ void DistributedSolverParamClient<Dtype>::Solve(const char* resume_file) {
 
 template <typename Dtype>
 void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
-  tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port());
-  if (!data_stream) {
-    LOG(FATAL) << "Unable to connect. Error code: " << data_stream.error().message();
+  boost::asio::io_service io_s;
+  tcp::resolver resolver(io_s);
+  tcp::resolver::query query(
+      this->param_.tcp_server(), this->param_.tcp_port());
+  tcp::resolver::iterator endpoint_iterator = resolver.resolve(query);
+  tcp::resolver::iterator end;
+  tcp::socket socket(io_s);
+  boost::system::error_code error = boost::asio::error::host_not_found;
+  while (error && endpoint_iterator != end) {
+    socket.close();
+    socket.connect(*endpoint_iterator++, error);
+  }
+  if (error) {
+    LOG(FATAL) << "Unable to connect. Error: " << error.message();
   }
-  data_stream.write(reinterpret_cast<char*>(&receive_only), sizeof(receive_only));
+  boost::asio::write(socket,
+      boost::asio::buffer(reinterpret_cast<char*>(&receive_only), sizeof(receive_only)));
   if (!receive_only) {
     LOG(INFO) << "Sending local changes.";
     int local_iters = this->param_.communication_interval();
-    data_stream.write(reinterpret_cast<char*>(&local_iters),
-        sizeof(local_iters));
+    boost::asio::write(socket,
+        boost::asio::buffer(reinterpret_cast<char*>(&local_iters), sizeof(local_iters)));
     int total_sent = 0;
     // TODO: send the accumulated gradient stored at history_, and set it to
     // zero for future accumulation
     for (int param_id = 0; param_id < this->history_.size(); ++param_id) {
       Dtype* accum_history_data = this->history_[param_id]->mutable_cpu_diff();
       int count = this->history_[param_id]->count();
-      data_stream.write(reinterpret_cast<char*>(accum_history_data),
-          sizeof(Dtype) * count);
+      boost::asio::write(socket,
+          boost::asio::buffer(reinterpret_cast<char*>(accum_history_data),
+              sizeof(Dtype) * count));
       memset(accum_history_data, 0, sizeof(Dtype) * count);
       total_sent += count;
     }
     LOG(INFO) << "Sent " << total_sent << " variables.";
-    CHECK(!data_stream.error()) << "Error in sending. Error code: "
-        << data_stream.error().message();
   }
-  data_stream.flush();
   // Receive parameters
   LOG(INFO) << "Receiving parameters.";
-  data_stream.read(reinterpret_cast<char*>(&(this->iter_)),
-      sizeof(this->iter_));
+  boost::asio::read(socket,
+      boost::asio::buffer(reinterpret_cast<char*>(&(this->iter_)), sizeof(this->iter_)));
   LOG(INFO) << "New iteration: " << this->iter_;
   int total_received = 0;
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   for (int param_id = 0; param_id < net_params.size(); ++param_id) {
     Dtype* param_data = net_params[param_id]->mutable_cpu_data();
     int count = net_params[param_id]->count();
-    data_stream.read(reinterpret_cast<char*>(param_data),
-        sizeof(Dtype) * count);
+    boost::asio::read(socket,
+        boost::asio::buffer(reinterpret_cast<char*>(param_data), sizeof(Dtype) * count));
     total_received += count;
     // Also, let's set the param_diff to be zero so that this update does not
     // change the parameter value, since it has already been updated.
     memset(net_params[param_id]->mutable_cpu_diff(), 0,
         net_params[param_id]->count() * sizeof(Dtype));
   }
-  CHECK(!data_stream.error()) << "Error in communication. Error code: "
-      << data_stream.error().message();
   LOG(INFO) << "Received " << total_received << " variables.";
   // Set the next send iter.
   next_send_iter_ = this->iter_ + this->param_.communication_interval();
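
On the client side, the implicit tcp::iostream connection is replaced by an explicit resolve-and-connect loop that walks every endpoint returned for the server name until one connect attempt succeeds. A self-contained sketch of just that loop, with placeholder host and port values (the solver itself takes these from its tcp_server and tcp_port parameters):

    #include <boost/asio.hpp>
    #include <iostream>
    #include <string>

    using boost::asio::ip::tcp;

    // Resolve host:port and try each returned endpoint until one connects,
    // mirroring the client-side loop above.
    bool connect_to_server(boost::asio::io_service& io_s, tcp::socket& socket,
                           const std::string& host, const std::string& port) {
      tcp::resolver resolver(io_s);
      tcp::resolver::query query(host, port);
      tcp::resolver::iterator endpoint_iterator = resolver.resolve(query);
      tcp::resolver::iterator end;
      boost::system::error_code error = boost::asio::error::host_not_found;
      while (error && endpoint_iterator != end) {
        socket.close();
        socket.connect(*endpoint_iterator++, error);
      }
      if (error) {
        std::cerr << "Unable to connect: " << error.message() << std::endl;
        return false;
      }
      return true;
    }

    int main() {
      boost::asio::io_service io_s;
      tcp::socket socket(io_s);
      // "localhost" and "7000" are placeholder values for illustration only.
      if (connect_to_server(io_s, socket, "localhost", "7000")) {
        std::cout << "connected to " << socket.remote_endpoint() << std::endl;
      }
      return 0;
    }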