Working asynchronous SGD code; may have errors.
author Yangqing Jia <jiayq84@gmail.com>
Wed, 6 Nov 2013 19:28:17 +0000 (11:28 -0800)
committer Yangqing Jia <jiayq84@gmail.com>
Wed, 6 Nov 2013 19:28:17 +0000 (11:28 -0800)
Makefile
examples/dist_train_server.cpp [new file with mode: 0644]
include/caffe/caffe.hpp
src/caffe/distributed_solver.cpp

diff --git a/Makefile b/Makefile
index a74d8b5..ab8c870 100644
--- a/Makefile
+++ b/Makefile
@@ -44,7 +44,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 INCLUDE_DIRS := ./src ./include /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
 LIBRARY_DIRS := /usr/lib /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
-       glog mkl_rt mkl_intel_thread leveldb snappy pthread
+       glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system
 WARNINGS := -Wall
 
 COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
diff --git a/examples/dist_train_server.cpp b/examples/dist_train_server.cpp
new file mode 100644
index 0000000..b8f983b
--- /dev/null
+++ b/examples/dist_train_server.cpp
@@ -0,0 +1,46 @@
+// Copyright 2013 Yangqing Jia
+//
+// This is a simple script that runs either the parameter server or a worker
+// client for distributed (asynchronous SGD) training, with the solver
+// specified by a text format protocol buffer.
+// Usage:
+//    dist_train_server solver_proto_file (server|client) [resume_point_file]
+
+#include <cuda_runtime.h>
+
+#include <cstring>
+
+#include "caffe/caffe.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  ::google::InitGoogleLogging(argv[0]);
+  if (argc < 3) {
+    LOG(ERROR) << "Usage: dist_train_server solver_proto_file (server|client) [resume_point_file]";
+    return 0;
+  }
+
+  Caffe::SetDevice(0);
+  Caffe::set_mode(Caffe::GPU);
+
+  SolverParameter solver_param;
+  ReadProtoFromTextFile(argv[1], &solver_param);
+
+  LOG(INFO) << "Starting Optimization";
+  shared_ptr<Solver<float> > solver;
+  if (strcmp(argv[2], "server") == 0) {
+    solver.reset(new DistributedSolverParamServer<float>(solver_param));
+  } else if (strcmp(argv[2], "client") == 0) {
+    solver.reset(new DistributedSolverParamClient<float>(solver_param));
+  } else {
+    LOG(FATAL) << "Unknown role " << argv[2] << ": expected server or client.";
+  }
+
+  if (argc == 4) {
+    LOG(INFO) << "Resuming from " << argv[3];
+    solver->Solve(argv[3]);
+  } else {
+    solver->Solve();
+  }
+  LOG(INFO) << "Optimization Done.";
+
+  return 0;
+}
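For a rough idea of how this example is meant to be run (the binary and file names below are hypothetical; the solver prototxt is expected to set the tcp_server, tcp_port and communication_interval fields that the solvers read from SolverParameter):

    # on the parameter-server machine
    ./dist_train_server lenet_solver.prototxt server
    # on each worker machine, pointing at the server via tcp_server/tcp_port
    ./dist_train_server lenet_solver.prototxt client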
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp
index 6557dfd..6ce38aa 100644
--- a/include/caffe/caffe.hpp
+++ b/include/caffe/caffe.hpp
@@ -11,6 +11,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/net.hpp"
 #include "caffe/solver.hpp"
+#include "caffe/distributed_solver.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/vision_layers.hpp"
 
diff --git a/src/caffe/distributed_solver.cpp b/src/caffe/distributed_solver.cpp
index 5e13c54..cbc7c04 100644
--- a/src/caffe/distributed_solver.cpp
+++ b/src/caffe/distributed_solver.cpp
@@ -30,6 +30,7 @@ void DistributedSolverParamServer<Dtype>::Solve(const char* resume_file) {
   }
 
   // the main loop.
+  LOG(INFO) << "Waiting for incoming updates...";
   while (this->iter_ < this->param_.max_iter()) {
     ReceiveAndSend();
     // Check if we need to do snapshot
@@ -57,18 +58,24 @@ void DistributedSolverParamServer<Dtype>::ReceiveAndSend() {
       io_s, tcp::endpoint(tcp::v4(), atoi(this->param_.tcp_port().c_str())));
   tcp::iostream data_stream;
   data_acceptor.accept(*(data_stream.rdbuf()));
-  data_stream >> send_only;
+  LOG(INFO) << "Incoming connection.";
+  data_stream.read(reinterpret_cast<char*>(&send_only), sizeof(send_only));
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   if (!send_only) {
     // Receive data
-    data_stream >> incoming_iter;
+    LOG(INFO) << "Receiving data.";
+    data_stream.read(reinterpret_cast<char*>(&incoming_iter),
+        sizeof(incoming_iter));
+    int total_received = 0;
+    LOG(INFO) << "Incoming iterations: " << incoming_iter;
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       Dtype* param_diff = net_params[param_id]->mutable_cpu_diff();
       int count = net_params[param_id]->count();
-      for (int i = 0; i < count; ++i) {
-        data_stream >> param_diff[i];
-      }
+      data_stream.read(reinterpret_cast<char*>(param_diff),
+          count * sizeof(Dtype));
+      total_received += count;
     }
+    LOG(INFO) << "Received " << total_received << " variables.";
     // Check Error
     if (!data_stream) {
       LOG(ERROR) << "Error in receiving.";
@@ -77,20 +84,23 @@ void DistributedSolverParamServer<Dtype>::ReceiveAndSend() {
       this->iter_ += incoming_iter;
       this->net_->Update();
     }
+  } else {
+    LOG(INFO) << "No incoming updates. Will simply send data.";
   }
   // Send data
-  data_stream << this->iter_;
+  LOG(INFO) << "Sending data";
+  data_stream.write(reinterpret_cast<char*>(&(this->iter_)), sizeof(this->iter_));
+  LOG(INFO) << "Current iteration: " << this->iter_;
+  int total_sent = 0;
   for (int param_id = 0; param_id < net_params.size(); ++param_id) {
     const Dtype* param_data = net_params[param_id]->cpu_data();
     int count = net_params[param_id]->count();
-    for (int i = 0; i < count; ++i) {
-      data_stream << param_data[i];
-    }
+    data_stream.write(reinterpret_cast<const char*>(param_data),
+        sizeof(Dtype) * count);
+    total_sent += count;
   }
+  LOG(INFO) << "Sent " << total_sent << " variables.";
   data_stream.flush();
-  if (!data_stream) {
-    LOG(ERROR) << "Error in sending.";
-  }
   data_stream.close();
 }
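A side note on the switch from formatted stream operators (<< and >>) to raw read()/write() above: the parameters now travel as raw bytes over a boost::asio tcp::iostream, which is also why boost_system was added to LIBRARIES in the Makefile. A minimal self-contained sketch of that pattern (not part of the commit; it assumes Dtype = float and that both ends already agree on the element count, as the solver code does via net_->params()):

    #include <boost/asio.hpp>

    #include <string>
    #include <vector>

    using boost::asio::ip::tcp;

    // Server side: accept one connection and read count floats as raw bytes,
    // mirroring data_stream.read(...) in ReceiveAndSend() above.
    void AcceptAndRead(int port, std::vector<float>* buf, int count) {
      boost::asio::io_service io_s;
      tcp::acceptor acceptor(io_s, tcp::endpoint(tcp::v4(), port));
      tcp::iostream stream;
      acceptor.accept(*(stream.rdbuf()));
      buf->resize(count);
      stream.read(reinterpret_cast<char*>(&(*buf)[0]), sizeof(float) * count);
    }

    // Client side: connect and write the buffer as raw bytes, mirroring
    // data_stream.write(...) in SendAndReceive() below.
    void ConnectAndWrite(const std::string& host, const std::string& port,
                         const std::vector<float>& buf) {
      tcp::iostream stream(host, port);
      stream.write(reinterpret_cast<const char*>(&buf[0]),
                   sizeof(float) * buf.size());
      stream.flush();
    }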
 
@@ -104,7 +114,9 @@ void DistributedSolverParamClient<Dtype>::Solve(const char* resume_file) {
   PreSolve();
 
   // Send and receive once to get the current iteration and the parameters
+  LOG(INFO) << "Obtaining initial parameters.";
   SendAndReceive(true);
+  LOG(INFO) << "Initial communication finished.";
 
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
@@ -127,35 +139,49 @@ template <typename Dtype>
 void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
   tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port());
   CHECK(data_stream) << "Error in connection.";
-  data_stream << receive_only;
+  data_stream.write(reinterpret_cast<char*>(&receive_only), sizeof(receive_only));
   if (!receive_only) {
-    data_stream << this->iter_;
+    LOG(INFO) << "Sending local changes.";
+    int local_iters = this->param_.communication_interval();
+    data_stream.write(reinterpret_cast<char*>(&local_iters),
+        sizeof(local_iters));
+    int total_sent = 0;
     // TODO: send the accumulated gradient stored at history_, and set it to
     // zero for future accumulation
     for (int param_id = 0; param_id < this->history_.size(); ++param_id) {
       Dtype* accum_history_data = this->history_[param_id]->mutable_cpu_diff();
       int count = this->history_[param_id]->count();
-      for (int i = 0; i < count; ++i) {
-        data_stream << accum_history_data[i];
-        accum_history_data[i] = 0;
-      }
+      data_stream.write(reinterpret_cast<char*>(accum_history_data),
+          sizeof(Dtype) * count);
+      memset(accum_history_data, 0, sizeof(Dtype) * count);
+      total_sent += count;
     }
-  }
+    LOG(INFO) << "Sent " << total_sent << " variables.";
+  } else {
+    LOG(INFO) << "Not sending local changes. Receive only.";
+  }
   data_stream.flush();
   // Receive parameters
-  data_stream >> this->iter_;
+  LOG(INFO) << "Receiving parameters.";
+  data_stream.read(reinterpret_cast<char*>(&(this->iter_)),
+      sizeof(this->iter_));
+  LOG(INFO) << "New iteration: " << this->iter_;
+  int total_received = 0;
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   for (int param_id = 0; param_id < net_params.size(); ++param_id) {
     Dtype* param_data = net_params[param_id]->mutable_cpu_data();
     int count = net_params[param_id]->count();
-    for (int i = 0; i < count; ++i) {
-      data_stream >> param_data[i];
-    }
+    data_stream.read(reinterpret_cast<char*>(param_data),
+        sizeof(Dtype) * count);
+    total_received += count;
     // Also, let's set the param_diff to be zero so that this update does not
     // change the parameter value, since it has already been updated.
     memset(net_params[param_id]->mutable_cpu_diff(), 0,
         net_params[param_id]->count() * sizeof(Dtype));
   }
+  LOG(INFO) << "Received " << total_received << " variables.";
+  // Set the next send iter.
+  next_send_iter_ = this->iter_ + this->param_.communication_interval();
 }
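Reading this client code together with the server's ReceiveAndSend() above, each TCP connection carries one exchange of the following raw fields (the boolean is named receive_only on the client and send_only on the server, but it is the same byte on the wire):

    client -> server:  bool   receive_only
                       if false:  int    local_iters (communication_interval)
                                  Dtype  accumulated diff, one block per param blob
    server -> client:  int    iter_  (the server's global iteration count)
                       Dtype  current data, one block per param blob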
 
 
@@ -184,10 +210,8 @@ void DistributedSolverParamClient<Dtype>::ComputeUpdateValue() {
     LOG(FATAL) << "Unknown caffe mode.";
   }
   // See if we need to do communication.
-  if (this->iter_ > next_send_iter_) {
-    DLOG(INFO) << "Send and receive parameters.";
+  if (this->iter_ >= next_send_iter_) {
     SendAndReceive();
-    next_send_iter_ += this->param_.communication_interval();
   }
 }