From 185385b87f54986968c2c18da0d573fdd8a1a937 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Wed, 6 Nov 2013 11:28:17 -0800
Subject: [PATCH] working asynchronous sgd code. may have errors.

---
 Makefile                         |  2 +-
 examples/dist_train_server.cpp   | 46 ++++++++++++++++++++++++
 include/caffe/caffe.hpp          |  1 +
 src/caffe/distributed_solver.cpp | 76 ++++++++++++++++++++++++++--------------
 4 files changed, 98 insertions(+), 27 deletions(-)
 create mode 100644 examples/dist_train_server.cpp

diff --git a/Makefile b/Makefile
index a74d8b5..ab8c870 100644
--- a/Makefile
+++ b/Makefile
@@ -44,7 +44,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 INCLUDE_DIRS := ./src ./include /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
 LIBRARY_DIRS := /usr/lib /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
-	glog mkl_rt mkl_intel_thread leveldb snappy pthread
+	glog mkl_rt mkl_intel_thread leveldb snappy pthread boost_system
 WARNINGS := -Wall
 
 COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
diff --git a/examples/dist_train_server.cpp b/examples/dist_train_server.cpp
new file mode 100644
index 0000000..b8f983b
--- /dev/null
+++ b/examples/dist_train_server.cpp
@@ -0,0 +1,46 @@
+// Copyright 2013 Yangqing Jia
+//
+// This is a simple script that allows one to quickly train a network whose
+// parameters are specified by text format protocol buffers.
+// Usage:
+//    dist_train_server solver_proto_file (server|client) [resume_point_file]
+
+#include <cuda_runtime.h>
+
+#include <cstring>
+
+#include "caffe/caffe.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  ::google::InitGoogleLogging(argv[0]);
+  if (argc < 3) {
+    LOG(ERROR) << "Usage: dist_train_server solver_proto_file "
+        << "(server|client) [resume_point_file]";
+    return 0;
+  }
+
+  Caffe::SetDevice(0);
+  Caffe::set_mode(Caffe::GPU);
+
+  SolverParameter solver_param;
+  ReadProtoFromTextFile(argv[1], &solver_param);
+
+  LOG(INFO) << "Starting Optimization";
+  shared_ptr<Solver<float> > solver;
+  if (strcmp(argv[2], "server") == 0) {
+    solver.reset(new DistributedSolverParamServer<float>(solver_param));
+  } else if (strcmp(argv[2], "client") == 0) {
+    solver.reset(new DistributedSolverParamClient<float>(solver_param));
+  }
+
+  if (argc == 4) {
+    LOG(INFO) << "Resuming from " << argv[3];
+    solver->Solve(argv[3]);
+  } else {
+    solver->Solve();
+  }
+  LOG(INFO) << "Optimization Done.";
+
+  return 0;
+}
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp
index 6557dfd..6ce38aa 100644
--- a/include/caffe/caffe.hpp
+++ b/include/caffe/caffe.hpp
@@ -11,6 +11,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/net.hpp"
 #include "caffe/solver.hpp"
+#include "caffe/distributed_solver.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/vision_layers.hpp"
diff --git a/src/caffe/distributed_solver.cpp b/src/caffe/distributed_solver.cpp
index 5e13c54..cbc7c04 100644
--- a/src/caffe/distributed_solver.cpp
+++ b/src/caffe/distributed_solver.cpp
@@ -30,6 +30,7 @@ void DistributedSolverParamServer<Dtype>::Solve(const char* resume_file) {
   }
 
   // the main loop.
+  LOG(INFO) << "Waiting for incoming updates...";
   while (this->iter_ < this->param_.max_iter()) {
     ReceiveAndSend();
     // Check if we need to do snapshot
@@ -57,18 +58,24 @@ void DistributedSolverParamServer<Dtype>::ReceiveAndSend() {
       io_s, tcp::endpoint(tcp::v4(), atoi(this->param_.tcp_port().c_str())));
   tcp::iostream data_stream;
   data_acceptor.accept(*(data_stream.rdbuf()));
-  data_stream >> send_only;
+  LOG(INFO) << "Incoming connection.";
+  data_stream.read(reinterpret_cast<char*>(&send_only), sizeof(send_only));
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   if (!send_only) {
     // Receive data
-    data_stream >> incoming_iter;
+    LOG(INFO) << "Receiving data.";
+    data_stream.read(reinterpret_cast<char*>(&incoming_iter),
+        sizeof(incoming_iter));
+    int total_received = 0;
+    LOG(INFO) << "Incoming iterations: " << incoming_iter;
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       Dtype* param_diff = net_params[param_id]->mutable_cpu_diff();
       int count = net_params[param_id]->count();
-      for (int i = 0; i < count; ++i) {
-        data_stream >> param_diff[i];
-      }
+      data_stream.read(reinterpret_cast<char*>(param_diff),
+          count * sizeof(Dtype));
+      total_received += count;
     }
+    LOG(INFO) << "Received " << total_received << " variables.";
     // Check Error
     if (!data_stream) {
       LOG(ERROR) << "Error in receiving.";
@@ -77,20 +84,23 @@ void DistributedSolverParamServer<Dtype>::ReceiveAndSend() {
       this->iter_ += incoming_iter;
       this->net_->Update();
     }
+  } else {
+    LOG(INFO) << "No incoming updates. Will simply send data.";
   }
   // Send data
-  data_stream << this->iter_;
+  LOG(INFO) << "Sending data";
+  data_stream.write(reinterpret_cast<const char*>(&(this->iter_)), sizeof(this->iter_));
+  LOG(INFO) << "Current iteration: " << this->iter_;
+  int total_sent = 0;
   for (int param_id = 0; param_id < net_params.size(); ++param_id) {
     const Dtype* param_data = net_params[param_id]->cpu_data();
     int count = net_params[param_id]->count();
-    for (int i = 0; i < count; ++i) {
-      data_stream << param_data[i];
-    }
+    data_stream.write(reinterpret_cast<const char*>(param_data),
+        sizeof(Dtype) * count);
+    total_sent += count;
   }
+  LOG(INFO) << "Sent " << total_sent << " variables.";
   data_stream.flush();
-  if (!data_stream) {
-    LOG(ERROR) << "Error in sending.";
-  }
   data_stream.close();
 }
@@ -104,7 +114,9 @@ void DistributedSolverParamClient<Dtype>::Solve(const char* resume_file) {
   PreSolve();
 
   // Send and receive once to get the current iteration and the parameters
+  LOG(INFO) << "Obtaining initial parameters.";
   SendAndReceive(true);
+  LOG(INFO) << "Initial communication finished.";
 
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
@@ -127,35 +139,49 @@ template <typename Dtype>
 void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
   tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port());
   CHECK(data_stream) << "Error in connection.";
-  data_stream << receive_only;
+  data_stream.write(reinterpret_cast<const char*>(&receive_only), sizeof(receive_only));
   if (!receive_only) {
-    data_stream << this->iter_;
+    LOG(INFO) << "Sending local changes.";
+    int local_iters = this->param_.communication_interval();
+    data_stream.write(reinterpret_cast<const char*>(&local_iters),
+        sizeof(local_iters));
+    int total_sent = 0;
     // TODO: send the accumulated gradient stored at history_, and set it to
     // zero for future accumulation
     for (int param_id = 0; param_id < this->history_.size(); ++param_id) {
       Dtype* accum_history_data = this->history_[param_id]->mutable_cpu_diff();
       int count = this->history_[param_id]->count();
-      for (int i = 0; i < count; ++i) {
-        data_stream << accum_history_data[i];
-        accum_history_data[i] = 0;
-      }
+      data_stream.write(reinterpret_cast<const char*>(accum_history_data),
+          sizeof(Dtype) * count);
+      memset(accum_history_data, 0, sizeof(Dtype) * count);
+      total_sent += count;
     }
-  }
+    LOG(INFO) << "Sent " << total_sent << " variables.";
+  }  // else {
+  //   LOG(INFO) << "Not sending local changes. Receive only.";
+  // }
   data_stream.flush();
   // Receive parameters
-  data_stream >> this->iter_;
+  LOG(INFO) << "Receiving parameters.";
+  data_stream.read(reinterpret_cast<char*>(&(this->iter_)),
+      sizeof(this->iter_));
+  LOG(INFO) << "New iteration: " << this->iter_;
+  int total_received = 0;
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   for (int param_id = 0; param_id < net_params.size(); ++param_id) {
     Dtype* param_data = net_params[param_id]->mutable_cpu_data();
     int count = net_params[param_id]->count();
-    for (int i = 0; i < count; ++i) {
-      data_stream >> param_data[i];
-    }
+    data_stream.read(reinterpret_cast<char*>(param_data),
+        sizeof(Dtype) * count);
+    total_received += count;
     // Also, let's set the param_diff to be zero so that this update does not
     // change the parameter value, since it has already been updated.
     memset(net_params[param_id]->mutable_cpu_diff(), 0,
         net_params[param_id]->count() * sizeof(Dtype));
   }
+  LOG(INFO) << "Received " << total_received << " variables.";
+  // Set the next send iter.
+  next_send_iter_ = this->iter_ + this->param_.communication_interval();
 }
@@ -184,10 +210,8 @@ void DistributedSolverParamClient<Dtype>::ComputeUpdateValue() {
     LOG(FATAL) << "Unknown caffe mode.";
   }
   // See if we need to do communication.
-  if (this->iter_ > next_send_iter_) {
-    DLOG(INFO) << "Send and receive parameters.";
+  if (this->iter_ >= next_send_iter_) {
     SendAndReceive();
-    next_send_iter_ += this->param_.communication_interval();
   }
 }
-- 
2.7.4
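
Note for reviewers: the patch replaces the formatted-text `<<`/`>>` parameter exchange with raw binary reads/writes over a boost::asio tcp::iostream (hence the new boost_system link dependency in the Makefile). For reference, below is a minimal standalone sketch of one client round trip against the wire format above. It is a hypothetical illustration, not part of the patch: the host "localhost", port "7000", buffer sizes, and iteration count are placeholder values, and error handling is elided just as it is in the patch.

    // Client side of the exchange done by ReceiveAndSend()/SendAndReceive():
    //   send:  bool receive_only; if false, int local_iters + raw gradient bytes
    //   recv:  int iter + raw parameter bytes
    #include <boost/asio.hpp>
    #include <iostream>
    #include <vector>

    using boost::asio::ip::tcp;

    int main() {
      typedef float Dtype;
      std::vector<Dtype> grad(1000, 0.1f);  // stand-in for accumulated gradients
      std::vector<Dtype> params(1000);      // receives the server's parameters

      tcp::iostream data_stream("localhost", "7000");
      if (!data_stream) {
        std::cerr << "Error in connection." << std::endl;
        return 1;
      }

      // Send: not receive-only, the local iteration count, then gradient bytes.
      bool receive_only = false;
      int local_iters = 50;
      data_stream.write(reinterpret_cast<const char*>(&receive_only),
                        sizeof(receive_only));
      data_stream.write(reinterpret_cast<const char*>(&local_iters),
                        sizeof(local_iters));
      data_stream.write(reinterpret_cast<const char*>(&grad[0]),
                        grad.size() * sizeof(Dtype));
      data_stream.flush();

      // Receive: the server's iteration count, then the updated parameters.
      int iter = 0;
      data_stream.read(reinterpret_cast<char*>(&iter), sizeof(iter));
      data_stream.read(reinterpret_cast<char*>(&params[0]),
                       params.size() * sizeof(Dtype));
      std::cout << "Server iteration: " << iter << std::endl;
      return 0;
    }

Both ends must agree on Dtype, endianness, and per-blob element counts; the patch assumes identical network definitions on server and client, so the blob sizes line up implicitly.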