--- /dev/null
+// Copyright 2013 Yangqing Jia
+//
+// This is a simple script that allows one to quickly train a network whose
+// parameters are specified by text format protocol buffers.
+// Usage:
+// train_net net_proto_file solver_proto_file [resume_point_file]
+
+#include <cuda_runtime.h>
+
+#include <cstring>
+
+#include "caffe/caffe.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  ::google::InitGoogleLogging(argv[0]);
+  if (argc < 3) {
+    LOG(ERROR) << "Usage: dist_train_server solver_proto_file (server|client) [resume_point_file]";
+    // Usage error: exit non-zero so calling scripts can detect the failure.
+    return 1;
+  }
+
+  Caffe::SetDevice(0);
+  Caffe::set_mode(Caffe::GPU);
+
+  SolverParameter solver_param;
+  ReadProtoFromTextFile(argv[1], &solver_param);
+
+  LOG(INFO) << "Starting Optimization";
+  shared_ptr<Solver<float> > solver;
+  if (strcmp(argv[2], "server") == 0) {
+    solver.reset(new DistributedSolverParamServer<float>(solver_param));
+  } else if (strcmp(argv[2], "client") == 0) {
+    solver.reset(new DistributedSolverParamClient<float>(solver_param));
+  } else {
+    // Guard: without this, an unrecognized role leaves solver null and the
+    // Solve() calls below dereference a null pointer.
+    LOG(ERROR) << "Unknown role " << argv[2] << ": expected server or client.";
+    return 1;
+  }
+
+  if (argc == 4) {
+    // The resume snapshot is argv[3]; argv[2] is the server|client role
+    // string (the original code resumed from argv[2], which is wrong).
+    LOG(INFO) << "Resuming from " << argv[3];
+    solver->Solve(argv[3]);
+  } else {
+    solver->Solve();
+  }
+  LOG(INFO) << "Optimization Done.";
+
+  return 0;
+}
}
// the main loop.
+ LOG(INFO) << "Waiting for incoming updates...";
while (this->iter_ < this->param_.max_iter()) {
ReceiveAndSend();
// Check if we need to do snapshot
io_s, tcp::endpoint(tcp::v4(), atoi(this->param_.tcp_port().c_str())));
tcp::iostream data_stream;
data_acceptor.accept(*(data_stream.rdbuf()));
- data_stream >> send_only;
+ LOG(INFO) << "Incoming connection.";
+ data_stream.read(reinterpret_cast<char*>(&send_only), sizeof(send_only));
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
if (!send_only) {
// Receive data
- data_stream >> incoming_iter;
+ LOG(INFO) << "Receiving data.";
+ data_stream.read(reinterpret_cast<char*>(&incoming_iter),
+ sizeof(incoming_iter));
+ int total_received = 0;
+ LOG(INFO) << "Incoming iterations: " << incoming_iter;
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype* param_diff = net_params[param_id]->mutable_cpu_diff();
int count = net_params[param_id]->count();
- for (int i = 0; i < count; ++i) {
- data_stream >> param_diff[i];
- }
+ data_stream.read(reinterpret_cast<char*>(param_diff),
+ count * sizeof(Dtype));
+ total_received += count;
}
+ LOG(INFO) << "Received " << total_received << " variables.";
// Check Error
if (!data_stream) {
LOG(ERROR) << "Error in receiving.";
this->iter_ += incoming_iter;
this->net_->Update();
}
+ } else {
+ LOG(INFO) << "No incoming updates. Will simply send data.";
}
// Send data
- data_stream << this->iter_;
+ LOG(INFO) << "Sending data";
+ data_stream.write(reinterpret_cast<char*>(&(this->iter_)), sizeof(this->iter_));
+ LOG(INFO) << "Current iteration: " << this->iter_;
+ int total_sent = 0;
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
const Dtype* param_data = net_params[param_id]->cpu_data();
int count = net_params[param_id]->count();
- for (int i = 0; i < count; ++i) {
- data_stream << param_data[i];
- }
+ data_stream.write(reinterpret_cast<const char*>(param_data),
+ sizeof(Dtype) * count);
+ total_sent += count;
}
+ LOG(INFO) << "Sent " << total_sent << " variables.";
data_stream.flush();
- if (!data_stream) {
- LOG(ERROR) << "Error in sending.";
- }
data_stream.close();
}
PreSolve();
// Send and receive once to get the current iteration and the parameters
+ LOG(INFO) << "Obtaining initial parameters.";
SendAndReceive(true);
+ LOG(INFO) << "Initial communication finished.";
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
void DistributedSolverParamClient<Dtype>::SendAndReceive(bool receive_only) {
tcp::iostream data_stream(this->param_.tcp_server(), this->param_.tcp_port());
CHECK(data_stream) << "Error in connection.";
- data_stream << receive_only;
+ data_stream.write(reinterpret_cast<char*>(&receive_only), sizeof(receive_only));
if (!receive_only) {
- data_stream << this->iter_;
+ LOG(INFO) << "Sending local changes.";
+ int local_iters = this->param_.communication_interval();
+ data_stream.write(reinterpret_cast<char*>(&local_iters),
+ sizeof(local_iters));
+ int total_sent = 0;
// TODO: send the accumulated gradient stored at history_, and set it to
// zero for future accumulation
for (int param_id = 0; param_id < this->history_.size(); ++param_id) {
Dtype* accum_history_data = this->history_[param_id]->mutable_cpu_diff();
int count = this->history_[param_id]->count();
- for (int i = 0; i < count; ++i) {
- data_stream << accum_history_data[i];
- accum_history_data[i] = 0;
- }
+ data_stream.write(reinterpret_cast<char*>(accum_history_data),
+ sizeof(Dtype) * count);
+ memset(accum_history_data, 0, sizeof(Dtype) * count);
+ total_sent += count;
}
-  }
+    LOG(INFO) << "Sent " << total_sent << " variables.";
+  } else {
+    // Mirror the server-side branch: log the receive-only case instead of
+    // carrying commented-out dead code.
+    LOG(INFO) << "Not sending local changes. Receive only.";
+  }
data_stream.flush();
// Receive parameters
- data_stream >> this->iter_;
+ LOG(INFO) << "Receiving parameters.";
+ data_stream.read(reinterpret_cast<char*>(&(this->iter_)),
+ sizeof(this->iter_));
+ LOG(INFO) << "New iteration: " << this->iter_;
+ int total_received = 0;
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype* param_data = net_params[param_id]->mutable_cpu_data();
int count = net_params[param_id]->count();
- for (int i = 0; i < count; ++i) {
- data_stream >> param_data[i];
- }
+ data_stream.read(reinterpret_cast<char*>(param_data),
+ sizeof(Dtype) * count);
+ total_received += count;
// Also, let's set the param_diff to be zero so that this update does not
// change the parameter value, since it has already been updated.
memset(net_params[param_id]->mutable_cpu_diff(), 0,
net_params[param_id]->count() * sizeof(Dtype));
}
+ LOG(INFO) << "Received " << total_received << " variables.";
+ // Set the next send iter.
+ next_send_iter_ = this->iter_ + this->param_.communication_interval();
}
LOG(FATAL) << "Unknown caffe mode.";
}
// See if we need to do communication.
- if (this->iter_ > next_send_iter_) {
- DLOG(INFO) << "Send and receive parameters.";
+ if (this->iter_ >= next_send_iter_) {
SendAndReceive();
- next_send_iter_ += this->param_.communication_interval();
}
}