distributed server update. bug in synchronous connections.
authorYangqing Jia <jiayq84@gmail.com>
Wed, 6 Nov 2013 23:27:26 +0000 (15:27 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 6 Nov 2013 23:27:26 +0000 (15:27 -0800)
examples/convert_imageset.cpp
examples/demo_compute_image_mean.cpp
examples/dist_train_server.cpp
include/caffe/distributed_solver.hpp
include/caffe/util/io.hpp
src/caffe/distributed_solver.cpp
src/caffe/proto/caffe.proto
src/caffe/util/io.cpp

index 6b5ee78..6a6b86d 100644 (file)
@@ -64,8 +64,10 @@ int main(int argc, char** argv) {
   char key_cstr[100];
   leveldb::WriteBatch* batch = new leveldb::WriteBatch();
   for (int line_id = 0; line_id < lines.size(); ++line_id) {
-    ReadImageToDatum(root_folder + lines[line_id].first, lines[line_id].second,
-        &datum);
+    if (!ReadImageToDatum(root_folder + lines[line_id].first, lines[line_id].second,
+        &datum)) {
+      continue;
+    };
     // sequential
     sprintf(key_cstr, "%08d_%s", line_id, lines[line_id].first.c_str());
     string value;
index b0130de..f05a8c5 100644 (file)
@@ -39,6 +39,7 @@ int main(int argc, char** argv) {
   sum_blob.set_channels(datum.channels());
   sum_blob.set_height(datum.height());
   sum_blob.set_width(datum.width());
+  const int data_size = datum.channels() * datum.height() * datum.width();
   for (int i = 0; i < datum.data().size(); ++i) {
     sum_blob.add_data(0.);
   }
@@ -47,6 +48,7 @@ int main(int argc, char** argv) {
     // just a dummy operation
     datum.ParseFromString(it->value().ToString());
     const string& data = datum.data();
+    CHECK_EQ(data.size(), data_size) << "Incorrect data field size " << data.size();
     for (int i = 0; i < data.size(); ++i) {
       sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
     }
index b8f983b..e5ae41e 100644 (file)
@@ -20,8 +20,8 @@ int main(int argc, char** argv) {
     return 0;
   }
 
-  Caffe::SetDevice(0);
-  Caffe::set_mode(Caffe::GPU);
+  //Caffe::SetDevice(0);
+  Caffe::set_mode(Caffe::CPU);
 
   SolverParameter solver_param;
   ReadProtoFromTextFile(argv[1], &solver_param);
index 4e08405..4283642 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "caffe/solver.hpp"
 
+
 namespace caffe {
 
 template <typename Dtype>
index 03df4b2..e37e740 100644 (file)
@@ -40,7 +40,7 @@ inline void WriteProtoToBinaryFile(
   WriteProtoToBinaryFile(proto, filename.c_str());
 }
 
-void ReadImageToDatum(const string& filename, const int label, Datum* datum);
+bool ReadImageToDatum(const string& filename, const int label, Datum* datum);
 
 }  // namespace caffe
 
index cbc7c04..0de1a2e 100644 (file)
@@ -28,6 +28,7 @@ void DistributedSolverParamServer<Dtype>::Solve(const char* resume_file) {
     LOG(INFO) << "Restoring previous solver status from " << resume_file;
     Solver<Dtype>::Restore(resume_file);
   }
+  next_snapshot_ = this->iter_ + this->param_.snapshot();
 
   // the main loop.
   LOG(INFO) << "Waiting for incoming updates...";
@@ -76,9 +77,10 @@ void DistributedSolverParamServer<Dtype>::ReceiveAndSend() {
       total_received += count;
     }
     LOG(INFO) << "Received " << total_received << " variables.";
-    // Check Error
-    if (!data_stream) {
-      LOG(ERROR) << "Error in receiving.";
+    // Check error: if there are any error in the receiving phase, we will not
+    // trust the passed in update.
+    if (data_stream.error()) {
+      LOG(ERROR) << "Error in receiving. Error code: " << data_stream.error().message();
     } else {
       // If the read is successful, update the network.
       this->iter_ += incoming_iter;
@@ -101,7 +103,6 @@ void DistributedSolverParamServer<Dtype>::ReceiveAndSend() {
   }
   LOG(INFO) << "Sent " << total_sent << " variables.";
   data_stream.flush();
-  data_stream.close();
 }
 
 
@@ -121,6 +122,7 @@ void DistributedSolverParamClient<Dtype>::Solve(const char* resume_file) {
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
   vector<Blob<Dtype>*> bottom_vec;
+  next_display_ = this->iter_ + this->param_.display();
   while (this->iter_++ < this->param_.max_iter()) {
     Dtype loss = this->net_->ForwardBackward(bottom_vec);
     ComputeUpdateValue();
@@ -128,7 +130,7 @@ void DistributedSolverParamClient<Dtype>::Solve(const char* resume_file) {
 
     if (this->param_.display() && this->iter_ > next_display_) {
       LOG(INFO) << "Iteration " << this->iter_ << ", loss = " << loss;
-      next_display_ += this->param_.display();
+      next_display_ = this->iter_ + this->param_.display();
     }
   }
   LOG(INFO) << "Optimization Done.";
@@ -138,7 +140,9 @@ void DistributedSolverParamClient<Dtype>::Solve(const char* resume_file) {
 template <typename Dtype>
 void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
   tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port());
-  CHECK(data_stream) << "Error in connection.";
+  if (!data_stream) {
+    LOG(FATAL) << "Unable to connect. Error code: " << data_stream.error().message();
+  }
   data_stream.write(reinterpret_cast<char*>(&receive_only), sizeof(receive_only));
   if (!receive_only) {
     LOG(INFO) << "Sending local changes.";
@@ -157,6 +161,8 @@ void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
       total_sent += count;
     }
     LOG(INFO) << "Sent " << total_sent << " variables.";
+    CHECK(!data_stream.error()) << "Error in sending. Error code: "
+        << data_stream.error().message();
   }// else {
   //  LOG(INFO) << "Not sending local changes. Receive only.";
   //}
@@ -179,6 +185,8 @@ void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
     memset(net_params[param_id]->mutable_cpu_diff(), 0,
         net_params[param_id]->count() * sizeof(Dtype));
   }
+  CHECK(!data_stream.error()) << "Error in communication. Error code: "
+      << data_stream.error().message();
   LOG(INFO) << "Received " << total_received << " variables.";
   // Set the next send iter.
   next_send_iter_ = this->iter_ + this->param_.communication_interval();
index cbe306d..939df91 100644 (file)
@@ -134,4 +134,4 @@ message SolverState {
   optional int32 iter = 1; // The current iteration
   optional string learned_net = 2; // The file that stores the learned net.
   repeated BlobProto history = 3; // The history for sgd solvers
-}
\ No newline at end of file
+}
index a3c520f..9bca5be 100644 (file)
@@ -66,10 +66,13 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
 }
 
 
-void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
+bool ReadImageToDatum(const string& filename, const int label, Datum* datum) {
   cv::Mat cv_img;
   cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
-  CHECK(cv_img.data) << "Could not open or find the image.";
+  if (!cv_img.data) {
+    LOG(ERROR) << "Could not open or find file " << filename;
+    return false;
+  }
   datum->set_channels(3);
   datum->set_height(cv_img.rows);
   datum->set_width(cv_img.cols);
@@ -84,6 +87,7 @@ void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
       }
     }
   }
+  return true;
 }
 
 }  // namespace caffe