From 591c36bba0e033623f0da6dc55923fa403641839 Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Wed, 6 Nov 2013 15:27:26 -0800 Subject: [PATCH] distributed server update. bug in synchronous connections. --- examples/convert_imageset.cpp | 6 ++++-- examples/demo_compute_image_mean.cpp | 2 ++ examples/dist_train_server.cpp | 4 ++-- include/caffe/distributed_solver.hpp | 1 + include/caffe/util/io.hpp | 2 +- src/caffe/distributed_solver.cpp | 20 ++++++++++++++------ src/caffe/proto/caffe.proto | 2 +- src/caffe/util/io.cpp | 8 ++++++-- 8 files changed, 31 insertions(+), 14 deletions(-) diff --git a/examples/convert_imageset.cpp b/examples/convert_imageset.cpp index 6b5ee78..6a6b86d 100644 --- a/examples/convert_imageset.cpp +++ b/examples/convert_imageset.cpp @@ -64,8 +64,10 @@ int main(int argc, char** argv) { char key_cstr[100]; leveldb::WriteBatch* batch = new leveldb::WriteBatch(); for (int line_id = 0; line_id < lines.size(); ++line_id) { - ReadImageToDatum(root_folder + lines[line_id].first, lines[line_id].second, - &datum); + if (!ReadImageToDatum(root_folder + lines[line_id].first, lines[line_id].second, + &datum)) { + continue; + }; // sequential sprintf(key_cstr, "%08d_%s", line_id, lines[line_id].first.c_str()); string value; diff --git a/examples/demo_compute_image_mean.cpp b/examples/demo_compute_image_mean.cpp index b0130de..f05a8c5 100644 --- a/examples/demo_compute_image_mean.cpp +++ b/examples/demo_compute_image_mean.cpp @@ -39,6 +39,7 @@ int main(int argc, char** argv) { sum_blob.set_channels(datum.channels()); sum_blob.set_height(datum.height()); sum_blob.set_width(datum.width()); + const int data_size = datum.channels() * datum.height() * datum.width(); for (int i = 0; i < datum.data().size(); ++i) { sum_blob.add_data(0.); } @@ -47,6 +48,7 @@ int main(int argc, char** argv) { // just a dummy operation datum.ParseFromString(it->value().ToString()); const string& data = datum.data(); + CHECK_EQ(data.size(), data_size) << "Incorrect data field size " << 
data.size(); for (int i = 0; i < data.size(); ++i) { sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]); } diff --git a/examples/dist_train_server.cpp b/examples/dist_train_server.cpp index b8f983b..e5ae41e 100644 --- a/examples/dist_train_server.cpp +++ b/examples/dist_train_server.cpp @@ -20,8 +20,8 @@ int main(int argc, char** argv) { return 0; } - Caffe::SetDevice(0); - Caffe::set_mode(Caffe::GPU); + //Caffe::SetDevice(0); + Caffe::set_mode(Caffe::CPU); SolverParameter solver_param; ReadProtoFromTextFile(argv[1], &solver_param); diff --git a/include/caffe/distributed_solver.hpp b/include/caffe/distributed_solver.hpp index 4e08405..4283642 100644 --- a/include/caffe/distributed_solver.hpp +++ b/include/caffe/distributed_solver.hpp @@ -11,6 +11,7 @@ #include "caffe/solver.hpp" + namespace caffe { template diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 03df4b2..e37e740 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -40,7 +40,7 @@ inline void WriteProtoToBinaryFile( WriteProtoToBinaryFile(proto, filename.c_str()); } -void ReadImageToDatum(const string& filename, const int label, Datum* datum); +bool ReadImageToDatum(const string& filename, const int label, Datum* datum); } // namespace caffe diff --git a/src/caffe/distributed_solver.cpp b/src/caffe/distributed_solver.cpp index cbc7c04..0de1a2e 100644 --- a/src/caffe/distributed_solver.cpp +++ b/src/caffe/distributed_solver.cpp @@ -28,6 +28,7 @@ void DistributedSolverParamServer::Solve(const char* resume_file) { LOG(INFO) << "Restoring previous solver status from " << resume_file; Solver::Restore(resume_file); } + next_snapshot_ = this->iter_ + this->param_.snapshot(); // the main loop. 
LOG(INFO) << "Waiting for incoming updates..."; @@ -76,9 +77,10 @@ void DistributedSolverParamServer::ReceiveAndSend() { total_received += count; } LOG(INFO) << "Received " << total_received << " variables."; - // Check Error - if (!data_stream) { - LOG(ERROR) << "Error in receiving."; + // Check error: if there is any error in the receiving phase, we will not + // trust the passed-in update. + if (data_stream.error()) { + LOG(ERROR) << "Error in receiving. Error code: " << data_stream.error().message(); } else { // If the read is successful, update the network. this->iter_ += incoming_iter; @@ -101,7 +103,6 @@ void DistributedSolverParamServer::ReceiveAndSend() { } LOG(INFO) << "Sent " << total_sent << " variables."; data_stream.flush(); - data_stream.close(); } @@ -121,6 +122,7 @@ void DistributedSolverParamClient::Solve(const char* resume_file) { // For a network that is trained by the solver, no bottom or top vecs // should be given, and we will just provide dummy vecs. vector*> bottom_vec; + next_display_ = this->iter_ + this->param_.display(); while (this->iter_++ < this->param_.max_iter()) { Dtype loss = this->net_->ForwardBackward(bottom_vec); ComputeUpdateValue(); @@ -128,7 +130,7 @@ void DistributedSolverParamClient::Solve(const char* resume_file) { if (this->param_.display() && this->iter_ > next_display_) { LOG(INFO) << "Iteration " << this->iter_ << ", loss = " << loss; - next_display_ += this->param_.display(); + next_display_ = this->iter_ + this->param_.display(); } } LOG(INFO) << "Optimization Done."; @@ -138,7 +140,9 @@ void DistributedSolverParamClient::Solve(const char* resume_file) { template void DistributedSolverParamClient::SendAndReceive(bool receive_only) { tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port()); - CHECK(data_stream) << "Error in connection."; + if (!data_stream) { + LOG(FATAL) << "Unable to connect. 
Error code: " << data_stream.error().message(); + } data_stream.write(reinterpret_cast(&receive_only), sizeof(receive_only)); if (!receive_only) { LOG(INFO) << "Sending local changes."; @@ -157,6 +161,8 @@ void DistributedSolverParamClient::SendAndReceive(bool receive_only) { total_sent += count; } LOG(INFO) << "Sent " << total_sent << " variables."; + CHECK(!data_stream.error()) << "Error in sending. Error code: " + << data_stream.error().message(); }// else { // LOG(INFO) << "Not sending local changes. Receive only."; //} @@ -179,6 +185,8 @@ void DistributedSolverParamClient::SendAndReceive(bool receive_only) { memset(net_params[param_id]->mutable_cpu_diff(), 0, net_params[param_id]->count() * sizeof(Dtype)); } + CHECK(!data_stream.error()) << "Error in communication. Error code: " + << data_stream.error().message(); LOG(INFO) << "Received " << total_received << " variables."; // Set the next send iter. next_send_iter_ = this->iter_ + this->param_.communication_interval(); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index cbe306d..939df91 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -134,4 +134,4 @@ message SolverState { optional int32 iter = 1; // The current iteration optional string learned_net = 2; // The file that stores the learned net. 
repeated BlobProto history = 3; // The history for sgd solvers -} \ No newline at end of file +} diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index a3c520f..9bca5be 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -66,10 +66,13 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) { } -void ReadImageToDatum(const string& filename, const int label, Datum* datum) { +bool ReadImageToDatum(const string& filename, const int label, Datum* datum) { cv::Mat cv_img; cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR); - CHECK(cv_img.data) << "Could not open or find the image."; + if (!cv_img.data) { + LOG(ERROR) << "Could not open or find file " << filename; + return false; + } datum->set_channels(3); datum->set_height(cv_img.rows); datum->set_width(cv_img.cols); @@ -84,6 +87,7 @@ void ReadImageToDatum(const string& filename, const int label, Datum* datum) { } } } + return true; } } // namespace caffe -- 2.7.4