From 865a60ca09dadb0d794238ba2836f06658734010 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Fri, 27 Sep 2013 16:59:43 -0700
Subject: [PATCH] updated a bunch of things, ready to test if it breaks things

---
 src/caffe/blob.cpp                |  42 +++++++++++++-
 src/caffe/blob.hpp                |   4 ++
 src/caffe/net.cpp                 |  30 ++++++++--
 src/caffe/net.hpp                 |  10 ++++
 src/caffe/optimization/solver.cpp | 113 ++++++++++++++++++++++++++++++++++++++
 src/caffe/optimization/solver.hpp |  25 ++++++++-
 src/caffe/proto/caffe.proto       |   2 +
 src/caffe/util/math_functions.cpp |  10 ++++
 src/caffe/util/math_functions.hpp |   3 +
 9 files changed, 230 insertions(+), 9 deletions(-)
 create mode 100644 src/caffe/optimization/solver.cpp

diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 6162740..35e5b04 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -1,5 +1,6 @@
 // Copyright 2013 Yangqing Jia
 
+#include <cstring>
 #include <cublas_v2.h>
 
 #include "caffe/blob.hpp"
@@ -11,6 +12,7 @@ namespace caffe {
 template <typename Dtype>
 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
     const int width) {
+  int old_count = count_;
   CHECK_GE(num, 0);
   CHECK_GE(channels, 0);
   CHECK_GE(height, 0);
@@ -21,8 +23,10 @@
   width_ = width;
   count_ = num_ * channels_ * height_ * width_;
   if (count_) {
-    data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
-    diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    if (old_count != count_) {
+      data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+      diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    }
   } else {
     data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
     diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
@@ -91,6 +95,40 @@ void Blob<Dtype>::Update() {
 }
 
 template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob<Dtype>& source, bool copy_diff, bool reshape) {
+  if (num_ != source.num() || channels_ != source.channels() ||
+      height_ != source.height() || width_ != source.width()) {
+    if (reshape) {
+      Reshape(source.num(), source.channels(), source.height(), source.width());
+    } else {
+      LOG(FATAL) << "Trying to copy blobs of different sizes.";
+    }
+  }
+  switch (Caffe::mode()) {
+  case Caffe::GPU:
+    if (copy_diff) {
+      CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
+          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+    } else {
+      CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
+          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+    }
+    break;
+  case Caffe::CPU:
+    if (copy_diff) {
+      memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+          sizeof(Dtype) * count_);
+    } else {
+      memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+          sizeof(Dtype) * count_);
+    }
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode.";
+  }
+}
+
+template <typename Dtype>
 void Blob<Dtype>::FromProto(const BlobProto& proto) {
   Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
   // copy data
diff --git a/src/caffe/blob.hpp b/src/caffe/blob.hpp
index f0e19c2..f31d3b0 100644
--- a/src/caffe/blob.hpp
+++ b/src/caffe/blob.hpp
@@ -29,6 +29,10 @@ class Blob {
       const int w = 0) const {
     return ((n * channels_ + c) * height_ + h) * width_ + w;
   }
+  // Copy from source. If copy_diff is false, we copy the data; if copy_diff
+  // is true, we copy the diff.
+  void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
+      bool reshape = false);
   inline Dtype data_at(const int n, const int c, const int h,
       const int w) const {
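Usage-wise, the new CopyFrom() replaces the manual memcpy pattern that Net::Forward used before this patch. A minimal sketch, not part of the diff; it assumes the Blob constructors and Caffe::set_mode() as they already exist in this tree:

#include "caffe/blob.hpp"
#include "caffe/common.hpp"

void copy_example() {
  caffe::Caffe::set_mode(caffe::Caffe::CPU);
  caffe::Blob<float> source(2, 3, 4, 5);
  caffe::Blob<float> target;  // default-constructed, shape 0x0x0x0
  // reshape = true lets the target adopt the source shape before copying data.
  target.CopyFrom(source, false, true);
  // copy_diff = true copies the gradient buffer instead of the data buffer.
  target.CopyFrom(source, true);
}
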
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index c6dfce1..22d2743 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -41,6 +41,8 @@ Net<Dtype>::Net(const NetParameter& param,
   // For each layer, set up their input and output
   bottom_vecs_.resize(param.layers_size());
   top_vecs_.resize(param.layers_size());
+  bottom_id_vecs_.resize(param.layers_size());
+  top_id_vecs_.resize(param.layers_size());
   for (int i = 0; i < param.layers_size(); ++i) {
     const LayerConnection& layer_connection = param.layers(i);
     const LayerParameter& layer_param = layer_connection.layer();
@@ -57,6 +59,7 @@
       LOG(INFO) << layer_param.name() << " <- " << blob_name;
       bottom_vecs_[i].push_back(
           blobs_[blob_name_to_idx[blob_name]].get());
+      bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
       available_blobs.erase(blob_name);
     }
     for (int j = 0; j < layer_connection.top_size(); ++j) {
@@ -71,6 +74,7 @@
       blob_name_to_idx[blob_name] = blob_names_.size() - 1;
       available_blobs.insert(blob_name);
       top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
+      top_id_vecs_[i].push_back(blob_names_.size() - 1);
     }
   }
   LOG(INFO) << "Checking top blobs.";
@@ -95,7 +99,7 @@
   for (int i = 0; i < layers_.size(); ++i) {
     LOG(INFO) << "Setting up " << layer_names_[i];
     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
-    vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i].params();
+    vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i]->params();
     for (int j = 0; j < layer_params.size(); ++j) {
       params_.push_back(layer_params[j]);
     }
@@ -109,15 +113,14 @@ void Net<Dtype>::Forward(const vector<Blob<Dtype>*> & bottom,
     vector<Blob<Dtype>*>* top) {
   // Copy bottom to internal bottom
   for (int i = 0; i < bottom.size(); ++i) {
-    memcpy(blobs_[net_input_blob_indices_[i]]->mutable_cpu_data(),
-        bottom[i]->cpu_data(), sizeof(Dtype) * bottom[i]->count());
+    blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
   }
   for (int i = 0; i < layers_.size(); ++i) {
     layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
   }
   // Copy internal top to top
   for (int i = 0; i < (*top).size(); ++i) {
-    NOT_IMPLEMENTED;
+    (*top)[i]->CopyFrom(*blobs_[net_output_blob_indices_[i]]);
   }
 }
 
@@ -167,11 +170,26 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
     param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
   }
-  for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
-    param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+  for (int i = 0; i < net_output_blob_indices_.size(); ++i) {
+    param->add_top(blob_names_[net_output_blob_indices_[i]]);
   }
   for (int i = 0; i < layers_.size(); ++i) {
     LayerConnection* layer_connection = param->add_layers();
+    for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
+      layer_connection->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
+    }
+    for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
+      layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
+    }
+    LayerParameter* layer_parameter = layer_connection->mutable_layer();
+    layers_[i]->ToProto(layer_parameter);
+  }
+}
+
+template <typename Dtype>
+void Net<Dtype>::Update() {
+  for (int i = 0; i < params_.size(); ++i) {
+    params_[i]->Update();
   }
 }
 
diff --git a/src/caffe/net.hpp b/src/caffe/net.hpp
index 719267c..1f1a803 100644
--- a/src/caffe/net.hpp
+++ b/src/caffe/net.hpp
@@ -31,6 +31,12 @@ class Net {
   // computes the gradient w.r.t. the parameters, and the data has already
   // been provided during the forward pass.
   Dtype Backward();
+  Dtype ForwardBackward(const vector<Blob<Dtype>*> & bottom,
+      vector<Blob<Dtype>*>* top) {
+    Forward(bottom, top);
+    return Backward();
+  }
+
   // For an already initialized net, CopyTrainedLayersFrom() copies the already
   // trained layers from another net parameter instance.
   void CopyTrainedLayersFrom(const NetParameter& param);
@@ -49,6 +55,8 @@ class Net {
   inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
   // returns the parameters
   vector<shared_ptr<Blob<Dtype> > >& params() { return params_; };
+  // Updates the network
+  void Update();
 
  protected:
   // Individual layers in the net
@@ -61,9 +69,11 @@ class Net {
   // bottom_vecs stores the vectors containing the input for each layer, except
   // for the first layer whose bottom vec is provided by the network's input.
   vector<vector<Blob<Dtype>*> > bottom_vecs_;
+  vector<vector<int> > bottom_id_vecs_;
   // top_vecs stores the vectors containing the output for each layer, except
   // for the last layer (likewise)
   vector<vector<Blob<Dtype>*> > top_vecs_;
+  vector<vector<int> > top_id_vecs_;
   // blob indices for the input and the output of the net.
   vector<int> net_input_blob_indices_;
   vector<int> net_output_blob_indices_;
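Taken together, ForwardBackward() and Update() reduce one training step to the short loop the solver below runs. A rough sketch, assuming a net whose layers supply their own data (so the bottom/top vectors stay empty) and that Blob::Update() folds each parameter's diff into its data as in this tree:

#include <vector>
#include "caffe/blob.hpp"
#include "caffe/net.hpp"

template <typename Dtype>
Dtype train_step(caffe::Net<Dtype>* net) {
  std::vector<caffe::Blob<Dtype>*> bottom_vec;  // left empty on purpose
  std::vector<caffe::Blob<Dtype>*> top_vec;
  Dtype loss = net->ForwardBackward(bottom_vec, &top_vec);  // fills the diffs
  // ... scale the diffs by a learning rate here (see ComputeUpdateValue) ...
  net->Update();  // every parameter blob applies its diff to its data
  return loss;
}
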
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
new file mode 100644
index 0000000..b9055d2
--- /dev/null
+++ b/src/caffe/optimization/solver.cpp
@@ -0,0 +1,113 @@
+// Copyright Yangqing Jia 2013
+
+#include <algorithm>
+#include <cmath>
+#include <fstream>
+#include <sstream>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/net.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using std::stringstream;
+using std::ofstream;
+
+namespace caffe {
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(Net<Dtype>* net) {
+  net_ = net;
+  LOG(INFO) << "Solving net " << net_->name();
+  iter_ = 0;
+  // For a network that is trained by the solver, no bottom or top vecs
+  // should be given, and we will just provide dummy vecs.
+  vector<Blob<Dtype>*> bottom_vec;
+  vector<Blob<Dtype>*> top_vec;
+  while (iter_++ < param_.max_iter()) {
+    Dtype loss = net_->ForwardBackward(bottom_vec, &top_vec);
+    ComputeUpdateValue();
+    net->Update();
+
+    // Check if we need to do snapshot
+    if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
+      // TODO(Yangqing): snapshot
+    }
+    LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+  }
+  LOG(INFO) << "Optimization Done.";
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Snapshot(bool is_final) {
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  stringstream ss;
+  ss << param_.snapshot_prefix();
+  if (is_final) {
+    ss << "_final";
+  } else {
+    ss << "_iter_" << iter_;
+  }
+  ofstream output_file;
+  output_file.open(ss.str().c_str());
+  CHECK(net_param.SerializeToOstream(&output_file));
+  output_file.close();
+}
+
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLearningRate() {
+  Dtype rate;
+  const string& lr_policy = this->param_.lr_policy();
+  if (lr_policy == "fixed") {
+    rate = this->param_.base_lr();
+  } else if (lr_policy == "exp") {
+    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
+  } else if (lr_policy == "inv") {
+    rate = this->param_.base_lr() *
+        pow(Dtype(1) + this->param_.gamma() * this->iter_,
+            this->param_.power());
+  } else {
+    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+  }
+  rate = std::min(std::max(rate, Dtype(this->param_.min_lr())),
+      Dtype(this->param_.max_lr()));
+  return rate;
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue() {
+  // First of all, see if we need to initialize the history
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  if (this->iter_ == 1 && this->param_.momentum() > 0) {
+    LOG(INFO) << "Using momentum " << this->param_.momentum();
+    for (int i = 0; i < net_params.size(); ++i) {
+      const Blob<Dtype>* net_param = net_params[i].get();
+      history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+          net_param->num(), net_param->channels(), net_param->height(),
+          net_param->width())));
+    }
+  }
+  // get the learning rate and scale the diffs by it; Net::Update then
+  // applies the scaled diff to the data.
+  Dtype rate = GetLearningRate();
+  if (this->param_.momentum() == 0) {
+    for (int i = 0; i < net_params.size(); ++i) {
+      switch (Caffe::mode()) {
+      case Caffe::CPU:
+        caffe_scal(net_params[i]->count(), rate,
+            net_params[i]->mutable_cpu_diff());
+        break;
+      case Caffe::GPU:
+        caffe_gpu_scal(net_params[i]->count(), rate,
+            net_params[i]->mutable_gpu_diff());
+        break;
+      default:
+        LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+      }
+    }
+  } else {
+    NOT_IMPLEMENTED;
+  }
+}
+
+
+INSTANTIATE_CLASS(Solver);
+INSTANTIATE_CLASS(SGDSolver);
+
+}  // namespace caffe
\ No newline at end of file
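For reference, a standalone restatement of the three learning-rate policies GetLearningRate() supports, with hypothetical parameter values and the clamping to [min_lr, max_lr] left out; e.g. the "inv" schedule below gives roughly 0.006 at iteration 10000:

#include <cmath>
#include <string>

// Hypothetical settings, only to make the schedules concrete.
double example_rate(const std::string& policy, int iter) {
  const double base_lr = 0.01;
  if (policy == "fixed") return base_lr;                        // constant
  if (policy == "exp") return base_lr * std::pow(0.999, iter);  // gamma^iter, gamma near 1
  if (policy == "inv") return base_lr * std::pow(1.0 + 0.0001 * iter, -0.75);  // gamma = 1e-4, power = -0.75
  return base_lr;  // unknown policy: fall back to the base rate
}
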
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index 0c680e3..0a78d88 100644
--- a/src/caffe/optimization/solver.hpp
+++ b/src/caffe/optimization/solver.hpp
@@ -3,16 +3,39 @@
 
 namespace caffe {
 
+template <typename Dtype>
 class Solver {
  public:
   explicit Solver(const SolverParameter& param)
       : param_(param) {}
-  void Solve(Net<float>* net);
+  // The main entry of the solver function.
+  void Solve(Net<Dtype>* net);
 
  protected:
+  // Get the update value for the current iteration.
+  virtual void ComputeUpdateValue() = 0;
+  void Snapshot(bool is_final = false);
   SolverParameter param_;
+  int iter_;
+  Net<Dtype>* net_;
+
+  DISABLE_COPY_AND_ASSIGN(Solver);
 };
 
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+  explicit SGDSolver(const SolverParameter& param)
+      : Solver<Dtype>(param) {}
+
+ protected:
+  Dtype GetLearningRate();
+  virtual void ComputeUpdateValue();
+  // history maintains the historical momentum data.
+  vector<shared_ptr<Blob<Dtype> > > history_;
+};
+
+
 }  // namespace caffe
 
 #endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
\ No newline at end of file
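A sketch of how the new solver classes would be driven (hypothetical values; the setters are the protobuf-generated ones for the SolverParameter fields the solver code reads, including the snapshot_prefix field added below and the min_lr/max_lr fields the learning-rate clamp assumes):

#include "caffe/net.hpp"
#include "caffe/optimization/solver.hpp"
#include "caffe/proto/caffe.pb.h"

void run_sgd(caffe::Net<float>* net) {
  caffe::SolverParameter param;
  param.set_base_lr(0.01);
  param.set_lr_policy("fixed");
  param.set_min_lr(0.0);
  param.set_max_lr(0.1);           // GetLearningRate() clamps into [min_lr, max_lr]
  param.set_max_iter(10000);
  param.set_snapshot(1000);        // snapshot every 1000 iterations (still a TODO above)
  param.set_momentum(0.0);         // momentum > 0 is still NOT_IMPLEMENTED
  param.set_snapshot_prefix("caffe_train");
  caffe::SGDSolver<float> solver(param);
  solver.Solve(net);
}
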
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 732c2ee..9d691d2 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -89,4 +89,6 @@ message SolverParameter {
   optional float gamma = 8; // The parameter to compute the learning rate.
   optional float power = 9; // The parameter to compute the learning rate.
   optional float momentum = 10; // The momentum value.
+
+  optional string snapshot_prefix = 11; // The prefix for the snapshot.
 }
\ No newline at end of file
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 1949a70..7cd3b26 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -124,6 +124,16 @@ void caffe_scal<double>(const int N, const double alpha, double *X) {
 }
 
 template <>
+void caffe_gpu_scal<float>(const int N, const float alpha, float *X) {
+  CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
+void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
+  CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
 void caffe_sqr<float>(const int n, const float* a, float* y) {
   vsSqr(n, a, y);
 }
diff --git a/src/caffe/util/math_functions.hpp b/src/caffe/util/math_functions.hpp
index 822ef31..f09afe3 100644
--- a/src/caffe/util/math_functions.hpp
+++ b/src/caffe/util/math_functions.hpp
@@ -46,6 +46,9 @@ template <typename Dtype>
 void caffe_scal(const int N, const Dtype alpha, Dtype *X);
 
 template <typename Dtype>
+void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X);
+
+template <typename Dtype>
 void caffe_sqr(const int N, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
-- 
2.7.4
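Finally, a small sketch of the new caffe_gpu_scal wrapper: it scales a device buffer in place (X := alpha * X), mirroring the CPU-side caffe_scal. The blob accessor used here is assumed from the existing Blob interface:

#include "caffe/blob.hpp"
#include "caffe/util/math_functions.hpp"

// Scale a blob's gradient on the GPU, as SGDSolver::ComputeUpdateValue does
// in its Caffe::GPU branch.
void scale_gpu_diff(caffe::Blob<float>* blob, float alpha) {
  caffe::caffe_gpu_scal<float>(blob->count(), alpha, blob->mutable_gpu_diff());
}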