From: qipeng
Date: Sat, 19 Jul 2014 21:14:21 +0000 (-0700)
Subject: Solver switching support & implementation of Nesterov's accelerated gradient and...
X-Git-Tag: submit/tizen/20180823.020014~620^2~62^2~22
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9d10569a79d4dc1a0aa3f42a0bd5214a085ab7c9;p=platform%2Fupstream%2Fcaffeonacl.git

Solver switching support & implementation of Nesterov's accelerated gradient
and AdaGrad
---

diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 9012c5d..87c6563 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -66,17 +66,62 @@ class SGDSolver : public Solver<Dtype> {
       : Solver<Dtype>(param_file) {}
 
  protected:
-  virtual void PreSolve();
+  void PreSolve();
   Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
-  virtual void SnapshotSolverState(SolverState * state);
-  virtual void RestoreSolverState(const SolverState& state);
+  void SnapshotSolverState(SolverState * state);
+  void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
-  vector<shared_ptr<Blob<Dtype> > > history_;
+  // update maintains update related data and is not needed in snapshots.
+  vector<shared_ptr<Blob<Dtype> > > history_, update_;
 
   DISABLE_COPY_AND_ASSIGN(SGDSolver);
 };
 
+template <typename Dtype>
+class NesterovSolver : public SGDSolver<Dtype> {
+ public:
+  explicit NesterovSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) {}
+  explicit NesterovSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) {}
+
+ protected:
+  virtual void ComputeUpdateValue();
+
+  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
+};
+
+template <typename Dtype>
+class AdaGradSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaGradSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) {}
+  explicit AdaGradSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) {}
+
+ protected:
+  virtual void ComputeUpdateValue();
+
+  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
+};
+
+template <typename Dtype>
+Solver<Dtype>* GetSolver(const SolverParameter& param) {
+  SolverParameter_SolverType type = param.solver_type();
+
+  switch (type) {
+  case SolverParameter_SolverType_SGD:
+      return new SGDSolver<Dtype>(param);
+  case SolverParameter_SolverType_NESTEROV:
+      return new NesterovSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAGRAD:
+      return new AdaGradSolver<Dtype>(param);
+  default:
+      LOG(FATAL) << "Unknown SolverType: " << type;
+  }
+}
+
 } // namespace caffe
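For context, a minimal sketch of how a caller could drive the new GetSolver
factory (not part of this commit; the train() wrapper and file path are
illustrative, and ReadProtoFromTextFileOrDie is caffe's existing proto I/O
helper):

#include <string>

#include "caffe/caffe.hpp"
#include "caffe/util/io.hpp"

// Hypothetical caller: picks SGD / Nesterov / AdaGrad at runtime from the
// solver definition instead of hard-coding SGDSolver<float>.
int train(const std::string& solver_file) {
  caffe::SolverParameter solver_param;
  caffe::ReadProtoFromTextFileOrDie(solver_file, &solver_param);
  caffe::shared_ptr<caffe::Solver<float> >
      solver(caffe::GetSolver<float>(solver_param));
  solver->Solve();  // dispatches through the virtual ComputeUpdateValue()
  return 0;
}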
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 80582b3..5632f24 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -377,11 +377,15 @@ void SGDSolver<Dtype>::PreSolve() {
   // Initialize the history
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   history_.clear();
+  update_.clear();
   for (int i = 0; i < net_params.size(); ++i) {
     const Blob<Dtype>* net_param = net_params[i].get();
     history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
         net_param->num(), net_param->channels(), net_param->height(),
         net_param->width())));
+    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
   }
 }
 
@@ -470,7 +474,192 @@ void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
   }
 }
 
+template <typename Dtype>
+void NesterovSolver<Dtype>::ComputeUpdateValue() {
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  vector<float>& net_params_lr = this->net_->params_lr();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  // get the learning rate
+  Dtype rate = this->GetLearningRate();
+  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+  }
+  Dtype momentum = this->param_.momentum();
+  Dtype weight_decay = this->param_.weight_decay();
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      // save history momentum for stepping back
+      caffe_copy(net_params[param_id]->count(),
+          this->history_[param_id]->cpu_data(),
+          this->update_[param_id]->mutable_cpu_data());
+
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+          net_params[param_id]->cpu_diff(), momentum,
+          this->history_[param_id]->mutable_cpu_data());
+      if (local_decay) {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay * local_rate,
+            net_params[param_id]->cpu_data(),
+            this->history_[param_id]->mutable_cpu_data());
+      }
+      // compute update: step back then over step
+      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+          this->history_[param_id]->cpu_data(), -momentum,
+          this->update_[param_id]->mutable_cpu_data());
+
+      // copy
+      caffe_copy(net_params[param_id]->count(),
+          this->update_[param_id]->cpu_data(),
+          net_params[param_id]->mutable_cpu_diff());
+    }
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      // save history momentum for stepping back
+      caffe_copy(net_params[param_id]->count(),
+          this->history_[param_id]->gpu_data(),
+          this->update_[param_id]->mutable_gpu_data());
+
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+          net_params[param_id]->gpu_diff(), momentum,
+          this->history_[param_id]->mutable_gpu_data());
+      if (local_decay) {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay * local_rate,
+            net_params[param_id]->gpu_data(),
+            this->history_[param_id]->mutable_gpu_data());
+      }
+      // compute update: step back then over step
+      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+          this->history_[param_id]->gpu_data(), -momentum,
+          this->update_[param_id]->mutable_gpu_data());
+
+      // copy
+      caffe_copy(net_params[param_id]->count(),
+          this->update_[param_id]->gpu_data(),
+          net_params[param_id]->mutable_gpu_diff());
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void AdaGradSolver<Dtype>::ComputeUpdateValue() {
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  vector<float>& net_params_lr = this->net_->params_lr();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  // get the learning rate
+  Dtype rate = this->GetLearningRate();
+  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+  }
+  Dtype weight_decay = this->param_.weight_decay();
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+
+      if (local_decay) {
+        // add weight decay to the gradient
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      }
+
+      // compute square of gradient in update
+      caffe_powx(net_params[param_id]->count(),
+          net_params[param_id]->cpu_diff(), Dtype(2),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // update history
+      caffe_add(net_params[param_id]->count(),
+          this->update_[param_id]->cpu_data(),
+          this->history_[param_id]->cpu_data(),
+          this->history_[param_id]->mutable_cpu_data());
+
+      // prepare update
+      caffe_powx(net_params[param_id]->count(),
+          this->history_[param_id]->cpu_data(), Dtype(-0.5),
+          this->update_[param_id]->mutable_cpu_data());
+
+      caffe_mul(net_params[param_id]->count(),
+          net_params[param_id]->cpu_diff(),
+          this->update_[param_id]->cpu_data(),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // scale and copy
+      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+          this->update_[param_id]->cpu_data(), Dtype(0),
+          net_params[param_id]->mutable_cpu_diff());
+    }
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+
+      if (local_decay) {
+        // add weight decay to the gradient
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      }
+
+      // compute square of gradient in update
+      caffe_gpu_powx(net_params[param_id]->count(),
+          net_params[param_id]->gpu_diff(), Dtype(2),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // update history
+      caffe_gpu_add(net_params[param_id]->count(),
+          this->update_[param_id]->gpu_data(),
+          this->history_[param_id]->gpu_data(),
+          this->history_[param_id]->mutable_gpu_data());
+
+      // prepare update
+      caffe_gpu_powx(net_params[param_id]->count(),
+          this->history_[param_id]->gpu_data(), Dtype(-0.5),
+          this->update_[param_id]->mutable_gpu_data());
+
+      caffe_gpu_mul(net_params[param_id]->count(),
+          net_params[param_id]->gpu_diff(),
+          this->update_[param_id]->gpu_data(),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // scale and copy
+      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+          this->update_[param_id]->gpu_data(), Dtype(0),
+          net_params[param_id]->mutable_gpu_diff());
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
+INSTANTIATE_CLASS(NesterovSolver);
+INSTANTIATE_CLASS(AdaGradSolver);
 
 } // namespace caffe
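In update-rule form, the two solvers above implement the following (a minimal
single-parameter sketch, not part of this commit; variable names are
illustrative and the blob/blas plumbing is elided):

#include <cmath>
#include <cstdio>

// Single-weight versions of the vectorized updates above.
int main() {
  const float momentum = 0.9f, local_rate = 0.01f;
  const float grad = 0.5f;  // current gradient for one weight (decay folded in)

  // Nesterov: history takes the usual momentum step, but the applied update
  // "steps back then over-steps", exactly as ComputeUpdateValue does.
  float history = 0.2f;                // accumulated momentum
  const float history_prev = history;  // saved copy (the role of update_)
  history = momentum * history + local_rate * grad;
  const float nesterov_step =
      (1 + momentum) * history - momentum * history_prev;

  // AdaGrad: history accumulates squared gradients; the step is the gradient
  // scaled by the inverse square root of that accumulator.
  float grad_sq_sum = 0.04f;           // accumulated grad^2
  grad_sq_sum += grad * grad;
  const float adagrad_step = local_rate * grad / std::sqrt(grad_sq_sum);

  std::printf("nesterov step = %f, adagrad step = %f\n",
              nesterov_step, adagrad_step);
  return 0;
}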
diff --git a/tools/train_net.cpp b/tools/train_net.cpp
index 622bca3..1176759 100644
--- a/tools/train_net.cpp
+++ b/tools/train_net.cpp
@@ -1,6 +1,9 @@
 #include "caffe/caffe.hpp"
 
+using namespace caffe;  // NOLINT(build/namespaces)
+
 int main(int argc, char** argv) {
+  LOG(FATAL) << "Deprecated. Use caffe train --solver=... "
+      "[--snapshot=...] instead.";
   return 0;
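With the deprecated tool gone, solver selection lives entirely in the solver
definition. A hypothetical prototxt follows; the solver_type field is inferred
from param.solver_type() above, and the matching caffe.proto change is not
shown in this diff:

# hypothetical lenet_solver.prototxt
net: "examples/mnist/lenet_train_test.prototxt"
base_lr: 0.01
momentum: 0.95
weight_decay: 0.0005
max_iter: 10000
solver_type: NESTEROV   # or SGD (default) / ADAGRAD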