From: Ronghang Hu
Date: Fri, 25 Sep 2015 00:11:07 +0000 (-0700)
Subject: Split solver code into one file per solver class
X-Git-Tag: submit/tizen/20180823.020014~325^2~3
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b822a702d19d4fbebbc91198a991f91c34e60650;p=platform%2Fupstream%2Fcaffeonacl.git

Split solver code into one file per solver class
---

diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp
new file mode 100644
index 0000000..6bf1d70
--- /dev/null
+++ b/include/caffe/sgd_solvers.hpp
@@ -0,0 +1,142 @@
+#ifndef CAFFE_SGD_SOLVERS_HPP_
+#define CAFFE_SGD_SOLVERS_HPP_
+
+#include <string>
+#include <vector>
+
+#include "caffe/solver.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Optimizes the parameters of a Net using
+ *        stochastic gradient descent (SGD) with momentum.
+ */
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+  explicit SGDSolver(const SolverParameter& param)
+      : Solver<Dtype>(param) { PreSolve(); }
+  explicit SGDSolver(const string& param_file)
+      : Solver<Dtype>(param_file) { PreSolve(); }
+
+  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
+
+ protected:
+  void PreSolve();
+  Dtype GetLearningRate();
+  virtual void ApplyUpdate();
+  virtual void Normalize(int param_id);
+  virtual void Regularize(int param_id);
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  virtual void ClipGradients();
+  virtual void SnapshotSolverState(const string& model_filename);
+  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
+  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
+  virtual void RestoreSolverStateFromHDF5(const string& state_file);
+  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
+  // history maintains the historical momentum data.
+  // update maintains update related data and is not needed in snapshots.
+  // temp maintains other information that might be needed in computation
+  //   of gradients/updates and is not needed in snapshots.
+  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
+
+  DISABLE_COPY_AND_ASSIGN(SGDSolver);
+};
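
Note: SGDSolver's ComputeUpdateValue (moved into src/caffe/solvers/
sgd_solver.cpp below) realizes the classic momentum update
V_{t+1} = mu * V_t + alpha * grad(W_t), W_{t+1} = W_t - V_{t+1}, with
history_ storing V. A minimal standalone sketch of that rule, assuming a
raw float array and an illustrative function name that are not part of
the patch:

    // One momentum-SGD step; `history` plays the role of history_, and the
    // final subtraction is what Net::Update() performs on the parameter diff.
    void sgd_momentum_step(int n, float local_rate, float momentum,
                           const float* grad, float* history, float* weight) {
      for (int i = 0; i < n; ++i) {
        history[i] = momentum * history[i] + local_rate * grad[i];
        weight[i] -= history[i];
      }
    }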
+
+template <typename Dtype>
+class NesterovSolver : public SGDSolver<Dtype> {
+ public:
+  explicit NesterovSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) {}
+  explicit NesterovSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) {}
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
+};
+
+template <typename Dtype>
+class AdaGradSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaGradSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit AdaGradSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with AdaGrad.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
+};
+
+
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+  explicit RMSPropSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit RMSPropSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.momentum())
+        << "Momentum cannot be used with RMSProp.";
+    CHECK_GE(this->param_.rms_decay(), 0)
+        << "rms_decay should lie between 0 and 1.";
+    CHECK_LT(this->param_.rms_decay(), 1)
+        << "rms_decay should lie between 0 and 1.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
+template <typename Dtype>
+class AdaDeltaSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaDeltaSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
+  explicit AdaDeltaSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
+
+ protected:
+  void AdaDeltaPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
+};
+
+/**
+ * @brief AdamSolver, an algorithm for first-order gradient-based optimization
+ *        of stochastic objective functions, based on adaptive estimates of
+ *        lower-order moments. Described in [1].
+ *
+ * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
+ *     arXiv preprint arXiv:1412.6980v8 (2014).
+ */
+template <typename Dtype>
+class AdamSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdamSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { AdamPreSolve();}
+  explicit AdamSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+
+ protected:
+  void AdamPreSolve();
+  virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+  DISABLE_COPY_AND_ASSIGN(AdamSolver);
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_SGD_SOLVERS_HPP_
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 2ecf539..a045ccf 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -1,5 +1,5 @@
-#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
-#define CAFFE_OPTIMIZATION_SOLVER_HPP_
+#ifndef CAFFE_SOLVER_HPP_
+#define CAFFE_SOLVER_HPP_
 #include <boost/function.hpp>
 #include <string>
 #include <vector>
@@ -148,158 +148,10 @@ class WorkerSolver : public Solver<Dtype> {
   }
 };
 
-/**
- * @brief Optimizes the parameters of a Net using
- *        stochastic gradient descent (SGD) with momentum.
- */ -template -class SGDSolver : public Solver { - public: - explicit SGDSolver(const SolverParameter& param) - : Solver(param) { PreSolve(); } - explicit SGDSolver(const string& param_file) - : Solver(param_file) { PreSolve(); } - - const vector > >& history() { return history_; } - - protected: - void PreSolve(); - Dtype GetLearningRate(); - virtual void ApplyUpdate(); - virtual void Normalize(int param_id); - virtual void Regularize(int param_id); - virtual void ComputeUpdateValue(int param_id, Dtype rate); - virtual void ClipGradients(); - virtual void SnapshotSolverState(const string& model_filename); - virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); - virtual void SnapshotSolverStateToHDF5(const string& model_filename); - virtual void RestoreSolverStateFromHDF5(const string& state_file); - virtual void RestoreSolverStateFromBinaryProto(const string& state_file); - // history maintains the historical momentum data. - // update maintains update related data and is not needed in snapshots. - // temp maintains other information that might be needed in computation - // of gradients/updates and is not needed in snapshots - vector > > history_, update_, temp_; - - DISABLE_COPY_AND_ASSIGN(SGDSolver); -}; - +// The solver factory function template -class NesterovSolver : public SGDSolver { - public: - explicit NesterovSolver(const SolverParameter& param) - : SGDSolver(param) {} - explicit NesterovSolver(const string& param_file) - : SGDSolver(param_file) {} - - protected: - virtual void ComputeUpdateValue(int param_id, Dtype rate); - - DISABLE_COPY_AND_ASSIGN(NesterovSolver); -}; - -template -class AdaGradSolver : public SGDSolver { - public: - explicit AdaGradSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } - explicit AdaGradSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } - - protected: - virtual void ComputeUpdateValue(int param_id, Dtype rate); - void constructor_sanity_check() { - CHECK_EQ(0, this->param_.momentum()) - << "Momentum cannot be used with AdaGrad."; - } - - DISABLE_COPY_AND_ASSIGN(AdaGradSolver); -}; - - -template -class RMSPropSolver : public SGDSolver { - public: - explicit RMSPropSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } - explicit RMSPropSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } - - protected: - virtual void ComputeUpdateValue(int param_id, Dtype rate); - void constructor_sanity_check() { - CHECK_EQ(0, this->param_.momentum()) - << "Momentum cannot be used with RMSProp."; - CHECK_GE(this->param_.rms_decay(), 0) - << "rms_decay should lie between 0 and 1."; - CHECK_LT(this->param_.rms_decay(), 1) - << "rms_decay should lie between 0 and 1."; - } - - DISABLE_COPY_AND_ASSIGN(RMSPropSolver); -}; - -template -class AdaDeltaSolver : public SGDSolver { - public: - explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver(param) { AdaDeltaPreSolve(); } - explicit AdaDeltaSolver(const string& param_file) - : SGDSolver(param_file) { AdaDeltaPreSolve(); } - - protected: - void AdaDeltaPreSolve(); - virtual void ComputeUpdateValue(int param_id, Dtype rate); - - DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); -}; - -/** - * @brief AdamSolver, an algorithm for first-order gradient-based optimization - * of stochastic objective functions, based on adaptive estimates of - * lower-order moments. Described in [1]. - * - * [1] D. P. Kingma and J. L. 
Ba, "ADAM: A Method for Stochastic Optimization." - * arXiv preprint arXiv:1412.6980v8 (2014). - */ -template -class AdamSolver : public SGDSolver { - public: - explicit AdamSolver(const SolverParameter& param) - : SGDSolver(param) { AdamPreSolve();} - explicit AdamSolver(const string& param_file) - : SGDSolver(param_file) { AdamPreSolve(); } - - protected: - void AdamPreSolve(); - virtual void ComputeUpdateValue(int param_id, Dtype rate); - - DISABLE_COPY_AND_ASSIGN(AdamSolver); -}; - -template -Solver* GetSolver(const SolverParameter& param) { - SolverParameter_SolverType type = param.solver_type(); - - switch (type) { - case SolverParameter_SolverType_SGD: - return new SGDSolver(param); - case SolverParameter_SolverType_NESTEROV: - return new NesterovSolver(param); - case SolverParameter_SolverType_ADAGRAD: - return new AdaGradSolver(param); - case SolverParameter_SolverType_RMSPROP: - return new RMSPropSolver(param); - case SolverParameter_SolverType_ADADELTA: - return new AdaDeltaSolver(param); - case SolverParameter_SolverType_ADAM: - return new AdamSolver(param); - default: - LOG(FATAL) << "Unknown SolverType: " << type; - } - return (Solver*) NULL; -} +Solver* GetSolver(const SolverParameter& param); } // namespace caffe -#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_ +#endif // CAFFE_SOLVER_HPP_ diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index ccd5776..0e38dee 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -16,6 +16,7 @@ #include "caffe/caffe.hpp" #include "caffe/python_layer.hpp" +#include "caffe/sgd_solvers.hpp" // Temporary solution for numpy < 1.7 versions: old macro, no promises. // You're strongly advised to upgrade to >= 1.7. diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 12c13dd..016a028 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -1,18 +1,11 @@ #include -#include #include #include -#include "hdf5.h" -#include "hdf5_hl.h" - -#include "caffe/net.hpp" -#include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/hdf5.hpp" #include "caffe/util/io.hpp" -#include "caffe/util/math_functions.hpp" #include "caffe/util/upgrade_proto.hpp" namespace caffe { @@ -492,810 +485,6 @@ void Solver::Restore(const char* state_file) { } } -// Return the current learning rate. The currently implemented learning rate -// policies are as follows: -// - fixed: always return base_lr. -// - step: return base_lr * gamma ^ (floor(iter / step)) -// - exp: return base_lr * gamma ^ iter -// - inv: return base_lr * (1 + gamma * iter) ^ (- power) -// - multistep: similar to step but it allows non uniform steps defined by -// stepvalue -// - poly: the effective learning rate follows a polynomial decay, to be -// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) -// - sigmoid: the effective learning rate follows a sigmod decay -// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) -// -// where base_lr, max_iter, gamma, step, stepvalue and power are defined -// in the solver parameter protocol buffer, and iter is the current iteration. 
-template -Dtype SGDSolver::GetLearningRate() { - Dtype rate; - const string& lr_policy = this->param_.lr_policy(); - if (lr_policy == "fixed") { - rate = this->param_.base_lr(); - } else if (lr_policy == "step") { - this->current_step_ = this->iter_ / this->param_.stepsize(); - rate = this->param_.base_lr() * - pow(this->param_.gamma(), this->current_step_); - } else if (lr_policy == "exp") { - rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); - } else if (lr_policy == "inv") { - rate = this->param_.base_lr() * - pow(Dtype(1) + this->param_.gamma() * this->iter_, - - this->param_.power()); - } else if (lr_policy == "multistep") { - if (this->current_step_ < this->param_.stepvalue_size() && - this->iter_ >= this->param_.stepvalue(this->current_step_)) { - this->current_step_++; - LOG(INFO) << "MultiStep Status: Iteration " << - this->iter_ << ", step = " << this->current_step_; - } - rate = this->param_.base_lr() * - pow(this->param_.gamma(), this->current_step_); - } else if (lr_policy == "poly") { - rate = this->param_.base_lr() * pow(Dtype(1.) - - (Dtype(this->iter_) / Dtype(this->param_.max_iter())), - this->param_.power()); - } else if (lr_policy == "sigmoid") { - rate = this->param_.base_lr() * (Dtype(1.) / - (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) - - Dtype(this->param_.stepsize()))))); - } else { - LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; - } - return rate; -} - -template -void SGDSolver::PreSolve() { - // Initialize the history - const vector*>& net_params = this->net_->learnable_params(); - history_.clear(); - update_.clear(); - temp_.clear(); - for (int i = 0; i < net_params.size(); ++i) { - const vector& shape = net_params[i]->shape(); - history_.push_back(shared_ptr >(new Blob(shape))); - update_.push_back(shared_ptr >(new Blob(shape))); - temp_.push_back(shared_ptr >(new Blob(shape))); - } -} - -template -void SGDSolver::ClipGradients() { - const Dtype clip_gradients = this->param_.clip_gradients(); - if (clip_gradients < 0) { return; } - const vector*>& net_params = this->net_->learnable_params(); - Dtype sumsq_diff = 0; - for (int i = 0; i < net_params.size(); ++i) { - sumsq_diff += net_params[i]->sumsq_diff(); - } - const Dtype l2norm_diff = std::sqrt(sumsq_diff); - if (l2norm_diff > clip_gradients) { - Dtype scale_factor = clip_gradients / l2norm_diff; - LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm " - << l2norm_diff << " > " << clip_gradients << ") " - << "by scale factor " << scale_factor; - for (int i = 0; i < net_params.size(); ++i) { - net_params[i]->scale_diff(scale_factor); - } - } -} - -template -void SGDSolver::ApplyUpdate() { - CHECK(Caffe::root_solver()); - Dtype rate = GetLearningRate(); - if (this->param_.display() && this->iter_ % this->param_.display() == 0) { - LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; - } - ClipGradients(); - for (int param_id = 0; param_id < this->net_->learnable_params().size(); - ++param_id) { - Normalize(param_id); - Regularize(param_id); - ComputeUpdateValue(param_id, rate); - } - this->net_->Update(); -} - -template -void SGDSolver::Normalize(int param_id) { - if (this->param_.iter_size() == 1) { return; } - // Scale gradient to counterbalance accumulation. - const vector*>& net_params = this->net_->learnable_params(); - const Dtype accum_normalization = Dtype(1.) 
/ this->param_.iter_size(); - switch (Caffe::mode()) { - case Caffe::CPU: { - caffe_scal(net_params[param_id]->count(), accum_normalization, - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - caffe_gpu_scal(net_params[param_id]->count(), accum_normalization, - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void SGDSolver::Regularize(int param_id) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); - Dtype weight_decay = this->param_.weight_decay(); - string regularization_type = this->param_.regularization_type(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * net_params_lr[param_id]; - // Compute the update to history, then copy it to the parameter diff. 
- switch (Caffe::mode()) { - case Caffe::CPU: { - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - history_[param_id]->mutable_cpu_data()); - caffe_copy(net_params[param_id]->count(), - history_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - history_[param_id]->mutable_gpu_data()); - caffe_copy(net_params[param_id]->count(), - history_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void SGDSolver::SnapshotSolverState(const string& model_filename) { - switch (this->param_.snapshot_format()) { - case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: - SnapshotSolverStateToBinaryProto(model_filename); - break; - case caffe::SolverParameter_SnapshotFormat_HDF5: - SnapshotSolverStateToHDF5(model_filename); - break; - default: - LOG(FATAL) << "Unsupported snapshot format."; - } -} - -template -void SGDSolver::SnapshotSolverStateToBinaryProto( - const string& model_filename) { - SolverState state; - state.set_iter(this->iter_); - state.set_learned_net(model_filename); - state.set_current_step(this->current_step_); - state.clear_history(); - for (int i = 0; i < history_.size(); ++i) { - // Add history - BlobProto* history_blob = state.add_history(); - history_[i]->ToProto(history_blob); - } - string snapshot_filename = Solver::SnapshotFilename(".solverstate"); - LOG(INFO) - << "Snapshotting solver state to binary proto file " << snapshot_filename; - WriteProtoToBinaryFile(state, snapshot_filename.c_str()); -} - -template -void SGDSolver::SnapshotSolverStateToHDF5( - const string& model_filename) { - string snapshot_filename = - Solver::SnapshotFilename(".solverstate.h5"); - LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename; - hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC, - H5P_DEFAULT, H5P_DEFAULT); - CHECK_GE(file_hid, 0) - << "Couldn't open " << snapshot_filename << " to save solver state."; - hdf5_save_int(file_hid, "iter", this->iter_); - hdf5_save_string(file_hid, "learned_net", model_filename); - hdf5_save_int(file_hid, "current_step", this->current_step_); - hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT, - H5P_DEFAULT); - CHECK_GE(history_hid, 0) - << "Error saving solver state to " << snapshot_filename << "."; - for (int i = 0; i < history_.size(); ++i) { - ostringstream oss; - oss << i; - hdf5_save_nd_dataset(history_hid, oss.str(), *history_[i]); - } - H5Gclose(history_hid); - H5Fclose(file_hid); -} - -template -void SGDSolver::RestoreSolverStateFromBinaryProto( - const string& state_file) { - SolverState state; - ReadProtoFromBinaryFile(state_file, &state); - this->iter_ = state.iter(); - if (state.has_learned_net()) { - NetParameter net_param; - ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param); - this->net_->CopyTrainedLayersFrom(net_param); - } - this->current_step_ = state.current_step(); - CHECK_EQ(state.history_size(), history_.size()) - << "Incorrect length of history blobs."; - LOG(INFO) << "SGDSolver: restoring history"; - for (int i = 0; i < history_.size(); ++i) { - history_[i]->FromProto(state.history(i)); - } -} - -template -void SGDSolver::RestoreSolverStateFromHDF5(const string& 
state_file) { - hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); - CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file; - this->iter_ = hdf5_load_int(file_hid, "iter"); - if (H5LTfind_dataset(file_hid, "learned_net")) { - string learned_net = hdf5_load_string(file_hid, "learned_net"); - this->net_->CopyTrainedLayersFrom(learned_net); - } - this->current_step_ = hdf5_load_int(file_hid, "current_step"); - hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT); - CHECK_GE(history_hid, 0) << "Error reading history from " << state_file; - int state_history_size = hdf5_get_num_links(history_hid); - CHECK_EQ(state_history_size, history_.size()) - << "Incorrect length of history blobs."; - for (int i = 0; i < history_.size(); ++i) { - ostringstream oss; - oss << i; - hdf5_load_nd_dataset(history_hid, oss.str().c_str(), 0, - kMaxBlobAxes, history_[i].get()); - } - H5Gclose(history_hid); - H5Fclose(file_hid); -} - -template -void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { - CHECK(Caffe::root_solver()); - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * net_params_lr[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - // save history momentum for stepping back - caffe_copy(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - this->history_[param_id]->mutable_cpu_data()); - - // compute update: step back then over step - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->cpu_data(), -momentum, - this->update_[param_id]->mutable_cpu_data()); - - // copy - caffe_copy(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // save history momentum for stepping back - caffe_copy(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // update history - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - this->history_[param_id]->mutable_gpu_data()); - - // compute update: step back then over step - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->gpu_data(), -momentum, - this->update_[param_id]->mutable_gpu_data()); - - // copy - caffe_copy(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { - CHECK(Caffe::root_solver()); - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype delta = this->param_.delta(); - Dtype local_rate = rate * net_params_lr[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_add(net_params[param_id]->count(), - 
this->update_[param_id]->cpu_data(), - this->history_[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); - - // prepare update - caffe_powx(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_cpu_data()); - - caffe_div(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), - this->update_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // scale and copy - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->cpu_data(), Dtype(0), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history - caffe_gpu_add(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - this->history_[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); - - // prepare update - caffe_gpu_powx(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_div(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), - this->update_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // scale and copy - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->gpu_data(), Dtype(0), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - - // get the learning rate - Dtype delta = this->param_.delta(); - Dtype rms_decay = this->param_.rms_decay(); - Dtype local_rate = rate * net_params_lr[param_id]; - - switch (Caffe::mode()) { - case Caffe::CPU: - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_cpu_axpby(net_params[param_id] -> count(), - Dtype(1-rms_decay), this->update_[param_id]->cpu_data(), - rms_decay, this->history_[param_id]-> mutable_cpu_data()); - - // prepare update - caffe_powx(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_cpu_data()); - - caffe_div(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // scale and copy - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->cpu_data(), Dtype(0), - net_params[param_id]->mutable_cpu_diff()); - break; - case Caffe::GPU: -#ifndef CPU_ONLY - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - 
- // update history - caffe_gpu_axpby(net_params[param_id] -> count(), - Dtype(1-rms_decay), this->update_[param_id]->gpu_data(), - rms_decay, this->history_[param_id]-> mutable_gpu_data()); - - // prepare update - caffe_gpu_powx(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_div(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->gpu_data(), Dtype(0), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void AdaDeltaSolver::AdaDeltaPreSolve() { - // Add the extra history entries for AdaDelta after those from - // SGDSolver::PreSolve - const vector*>& net_params = this->net_->learnable_params(); - for (int i = 0; i < net_params.size(); ++i) { - const vector& shape = net_params[i]->shape(); - this->history_.push_back( - shared_ptr >(new Blob(shape))); - } -} - -template -void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype delta = this->param_.delta(); - Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * net_params_lr[param_id]; - size_t update_history_offset = net_params.size(); - switch (Caffe::mode()) { - case Caffe::CPU: { - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history of gradients - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->cpu_data(), momentum, - this->history_[param_id]->mutable_cpu_data()); - - // add delta to history to guard against dividing by zero later - caffe_set(net_params[param_id]->count(), delta, - this->temp_[param_id]->mutable_cpu_data()); - - caffe_add(net_params[param_id]->count(), - this->temp_[param_id]->cpu_data(), - this->history_[update_history_offset + param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add(net_params[param_id]->count(), - this->temp_[param_id]->cpu_data(), - this->history_[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - - // divide history of updates by history of gradients - caffe_div(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - this->temp_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // jointly compute the RMS of both for update and gradient history - caffe_powx(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - // compute the update - caffe_mul(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), - this->update_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - - // compute square of update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history of updates - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - 
this->update_[param_id]->cpu_data(), momentum, - this->history_[update_history_offset + param_id]->mutable_cpu_data()); - - // apply learning rate - caffe_cpu_scale(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history of gradients - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->gpu_data(), momentum, - this->history_[param_id]->mutable_gpu_data()); - - // add delta to history to guard against dividing by zero later - caffe_gpu_set(net_params[param_id]->count(), delta, - this->temp_[param_id]->mutable_gpu_data()); - - caffe_gpu_add(net_params[param_id]->count(), - this->temp_[param_id]->gpu_data(), - this->history_[update_history_offset + param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add(net_params[param_id]->count(), - this->temp_[param_id]->gpu_data(), - this->history_[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - - // divide history of updates by history of gradients - caffe_gpu_div(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - this->temp_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // jointly compute the RMS of both for update and gradient history - caffe_gpu_powx(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - // compute the update and copy to net_diff - caffe_gpu_mul(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), - this->update_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - - // compute square of update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history of updates - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->gpu_data(), momentum, - this->history_[update_history_offset + param_id]->mutable_gpu_data()); - - // apply learning rate - caffe_gpu_scale(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void AdamSolver::AdamPreSolve() { - // Add the extra history entries for Adam after those from - // SGDSolver::PreSolve - const vector*>& net_params = this->net_->learnable_params(); - for (int i = 0; i < net_params.size(); ++i) { - const vector& shape = net_params[i]->shape(); - this->history_.push_back( - shared_ptr >(new Blob(shape))); - } -} - -template -void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype local_rate = rate * net_params_lr[param_id]; - const Dtype beta1 = this->param_.momentum(); - const Dtype beta2 = this->param_.momentum2(); - - // we create aliases for convenience - size_t update_history_offset = net_params.size(); - Blob* val_m = this->history_[param_id].get(); - Blob* val_v = this->history_[param_id + update_history_offset].get(); - Blob* val_t = 
this->temp_[param_id].get();
-
-  const int t = this->iter_ + 1;
-  const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
-      (Dtype(1.) - pow(beta1, t));
-  const int N = net_params[param_id]->count();
-  const Dtype eps_hat = this->param_.delta();
-
-  switch (Caffe::mode()) {
-    case Caffe::CPU: {
-    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
-    caffe_cpu_axpby(N, Dtype(1)-beta1,
-        net_params[param_id]->cpu_diff(), beta1,
-        val_m->mutable_cpu_data());
-
-    // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
-    caffe_mul(N,
-        net_params[param_id]->cpu_diff(),
-        net_params[param_id]->cpu_diff(),
-        val_t->mutable_cpu_data());
-    caffe_cpu_axpby(N, Dtype(1)-beta2,
-        val_t->cpu_data(), beta2,
-        val_v->mutable_cpu_data());
-
-    // set update
-    caffe_powx(N,
-        val_v->cpu_data(), Dtype(0.5),
-        val_t->mutable_cpu_data());
-    caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
-    caffe_div(N,
-        val_m->cpu_data(),
-        val_t->cpu_data(),
-        val_t->mutable_cpu_data());
-
-    caffe_cpu_scale(N, local_rate*correction,
-        val_t->cpu_data(),
-        net_params[param_id]->mutable_cpu_diff());
-    break;
-    }
-    case Caffe::GPU: {
-#ifndef CPU_ONLY
-    // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
-    caffe_gpu_axpby(N, Dtype(1)-beta1,
-        net_params[param_id]->gpu_diff(), beta1,
-        val_m->mutable_gpu_data());
-
-    // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
-    caffe_gpu_mul(N,
-        net_params[param_id]->gpu_diff(),
-        net_params[param_id]->gpu_diff(),
-        val_t->mutable_gpu_data());
-    caffe_gpu_axpby(N, Dtype(1)-beta2,
-        val_t->gpu_data(), beta2,
-        val_v->mutable_gpu_data());
-
-    // set update
-    caffe_gpu_powx(N,
-        val_v->gpu_data(), Dtype(0.5),
-        val_t->mutable_gpu_data());
-    caffe_gpu_add_scalar(N, eps_hat,
-        val_t->mutable_gpu_data());
-    caffe_gpu_div(N,
-        val_m->gpu_data(),
-        val_t->gpu_data(),
-        val_t->mutable_gpu_data());
-
-    caffe_gpu_scale(N, local_rate*correction,
-        val_t->gpu_data(),
-        net_params[param_id]->mutable_gpu_diff());
-#else
-    NO_GPU;
-#endif
-    break;
-    }
-    default:
-      LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-  }
-}
-
 INSTANTIATE_CLASS(Solver);
-INSTANTIATE_CLASS(SGDSolver);
-INSTANTIATE_CLASS(NesterovSolver);
-INSTANTIATE_CLASS(AdaGradSolver);
-INSTANTIATE_CLASS(RMSPropSolver);
-INSTANTIATE_CLASS(AdaDeltaSolver);
-INSTANTIATE_CLASS(AdamSolver);
 
 }  // namespace caffe
diff --git a/src/caffe/solver_factory.cpp b/src/caffe/solver_factory.cpp
new file mode 100644
index 0000000..f78fab2
--- /dev/null
+++ b/src/caffe/solver_factory.cpp
@@ -0,0 +1,32 @@
+#include "caffe/solver.hpp"
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+Solver<Dtype>* GetSolver(const SolverParameter& param) {
+  SolverParameter_SolverType type = param.solver_type();
+
+  switch (type) {
+  case SolverParameter_SolverType_SGD:
+    return new SGDSolver<Dtype>(param);
+  case SolverParameter_SolverType_NESTEROV:
+    return new NesterovSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAGRAD:
+    return new AdaGradSolver<Dtype>(param);
+  case SolverParameter_SolverType_RMSPROP:
+    return new RMSPropSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADADELTA:
+    return new AdaDeltaSolver<Dtype>(param);
+  case SolverParameter_SolverType_ADAM:
+    return new AdamSolver<Dtype>(param);
+  default:
+    LOG(FATAL) << "Unknown SolverType: " << type;
+  }
+  return (Solver<Dtype>*) NULL;
+}
+
+template Solver<float>* GetSolver(const SolverParameter& param);
+template Solver<double>* GetSolver(const SolverParameter& param);
+
+}  // namespace caffe
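
Note: with the factory in its own translation unit, callers depend only on
the GetSolver declaration in solver.hpp rather than on the concrete solver
headers. A minimal usage sketch along the lines of tools/caffe.cpp (the
prototxt path is illustrative, not part of the patch):

    #include "caffe/caffe.hpp"

    int main() {
      caffe::SolverParameter solver_param;
      caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param);
      boost::shared_ptr<caffe::Solver<float> >
          solver(caffe::GetSolver<float>(solver_param));
      solver->Solve();
      return 0;
    }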
diff --git a/src/caffe/solvers/adadelta_solver.cpp b/src/caffe/solvers/adadelta_solver.cpp
new file mode 100644
index 0000000..45cd4eb
--- /dev/null
+++ b/src/caffe/solvers/adadelta_solver.cpp
@@ -0,0 +1,155 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
+  // Add the extra history entries for AdaDelta after those from
+  // SGDSolver::PreSolve
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    this->history_.push_back(
+            shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype delta = this->param_.delta();
+  Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  size_t update_history_offset = net_params.size();
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    // compute square of gradient in update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of gradients
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[param_id]->mutable_cpu_data());
+
+    // add delta to history to guard against dividing by zero later
+    caffe_set(net_params[param_id]->count(), delta,
+        this->temp_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[update_history_offset + param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    caffe_add(net_params[param_id]->count(),
+        this->temp_[param_id]->cpu_data(),
+        this->history_[param_id]->cpu_data(),
+        this->temp_[param_id]->mutable_cpu_data());
+
+    // divide history of updates by history of gradients
+    caffe_div(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(),
+        this->temp_[param_id]->cpu_data(),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // jointly compute the RMS of both for update and gradient history
+    caffe_powx(net_params[param_id]->count(),
+        this->update_[param_id]->cpu_data(), Dtype(0.5),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // compute the update
+    caffe_mul(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(),
+        this->update_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+
+    // compute square of update
+    caffe_powx(net_params[param_id]->count(),
+        net_params[param_id]->cpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_cpu_data());
+
+    // update history of updates
+    caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->cpu_data(), momentum,
+        this->history_[update_history_offset + param_id]->mutable_cpu_data());
+
+    // apply learning rate
+    caffe_cpu_scale(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    // compute square of gradient in update
+    caffe_gpu_powx(net_params[param_id]->count(),
+        net_params[param_id]->gpu_diff(), Dtype(2),
+        this->update_[param_id]->mutable_gpu_data());
+
+    // update history of gradients
+    caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+        this->update_[param_id]->gpu_data(), momentum,
+        this->history_[param_id]->mutable_gpu_data());
+
+    // add delta to history to guard against dividing by zero later
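+    // (Note, not in the original patch: because delta is added to both
+    //  running averages, the multiplier computed below is
+    //    sqrt((E[dx^2] + delta) / (E[g^2] + delta)),
+    //  AdaDelta's ratio of update RMS to gradient RMS.)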
caffe_gpu_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[update_history_offset + param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + + // divide history of updates by history of gradients + caffe_gpu_div(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->temp_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_gpu_powx(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + // compute the update and copy to net_diff + caffe_gpu_mul(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + + // compute square of update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history of updates + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_gpu_data()); + + // apply learning rate + caffe_gpu_scale(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(AdaDeltaSolver); + +} // namespace caffe diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp new file mode 100644 index 0000000..627d816 --- /dev/null +++ b/src/caffe/solvers/adagrad_solver.cpp @@ -0,0 +1,88 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype delta = this->param_.delta(); + Dtype local_rate = rate * net_params_lr[param_id]; + switch (Caffe::mode()) { + case Caffe::CPU: { + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_add(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef 
CPU_ONLY + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_add(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // scale and copy + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), Dtype(0), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(AdaGradSolver); + +} // namespace caffe diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp new file mode 100644 index 0000000..8c334f6 --- /dev/null +++ b/src/caffe/solvers/adam_solver.cpp @@ -0,0 +1,112 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void AdamSolver::AdamPreSolve() { + // Add the extra history entries for Adam after those from + // SGDSolver::PreSolve + const vector*>& net_params = this->net_->learnable_params(); + for (int i = 0; i < net_params.size(); ++i) { + const vector& shape = net_params[i]->shape(); + this->history_.push_back( + shared_ptr >(new Blob(shape))); + } +} + +template +void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype local_rate = rate * net_params_lr[param_id]; + const Dtype beta1 = this->param_.momentum(); + const Dtype beta2 = this->param_.momentum2(); + + // we create aliases for convenience + size_t update_history_offset = net_params.size(); + Blob* val_m = this->history_[param_id].get(); + Blob* val_v = this->history_[param_id + update_history_offset].get(); + Blob* val_t = this->temp_[param_id].get(); + + const int t = this->iter_ + 1; + const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / + (Dtype(1.) 
- pow(beta1, t)); + const int N = net_params[param_id]->count(); + const Dtype eps_hat = this->param_.delta(); + + switch (Caffe::mode()) { + case Caffe::CPU: { + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_cpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->cpu_diff(), beta1, + val_m->mutable_cpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_mul(N, + net_params[param_id]->cpu_diff(), + net_params[param_id]->cpu_diff(), + val_t->mutable_cpu_data()); + caffe_cpu_axpby(N, Dtype(1)-beta2, + val_t->cpu_data(), beta2, + val_v->mutable_cpu_data()); + + // set update + caffe_powx(N, + val_v->cpu_data(), Dtype(0.5), + val_t->mutable_cpu_data()); + caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data()); + caffe_div(N, + val_m->cpu_data(), + val_t->cpu_data(), + val_t->mutable_cpu_data()); + + caffe_cpu_scale(N, local_rate*correction, + val_t->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_gpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->gpu_diff(), beta1, + val_m->mutable_gpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_gpu_mul(N, + net_params[param_id]->gpu_diff(), + net_params[param_id]->gpu_diff(), + val_t->mutable_gpu_data()); + caffe_gpu_axpby(N, Dtype(1)-beta2, + val_t->gpu_data(), beta2, + val_v->mutable_gpu_data()); + + // set update + caffe_gpu_powx(N, + val_v->gpu_data(), Dtype(0.5), + val_t->mutable_gpu_data()); + caffe_gpu_add_scalar(N, eps_hat, + val_t->mutable_gpu_data()); + caffe_gpu_div(N, + val_m->gpu_data(), + val_t->gpu_data(), + val_t->mutable_gpu_data()); + + caffe_gpu_scale(N, local_rate*correction, + val_t->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(AdamSolver); + +} // namespace caffe diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp new file mode 100644 index 0000000..8135ee2 --- /dev/null +++ b/src/caffe/solvers/nesterov_solver.cpp @@ -0,0 +1,70 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype momentum = this->param_.momentum(); + Dtype local_rate = rate * net_params_lr[param_id]; + switch (Caffe::mode()) { + case Caffe::CPU: { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // compute update: step back then over step + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->cpu_data(), -momentum, + this->update_[param_id]->mutable_cpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), + 
this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // compute update: step back then over step + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->gpu_data(), -momentum, + this->update_[param_id]->mutable_gpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(NesterovSolver); + +} // namespace caffe diff --git a/src/caffe/solvers/rmsprop_solver.cpp b/src/caffe/solvers/rmsprop_solver.cpp new file mode 100644 index 0000000..96d1b3d --- /dev/null +++ b/src/caffe/solvers/rmsprop_solver.cpp @@ -0,0 +1,84 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + + // get the learning rate + Dtype delta = this->param_.delta(); + Dtype rms_decay = this->param_.rms_decay(); + Dtype local_rate = rate * net_params_lr[param_id]; + + switch (Caffe::mode()) { + case Caffe::CPU: + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_cpu_axpby(net_params[param_id] -> count(), + Dtype(1-rms_decay), this->update_[param_id]->cpu_data(), + rms_decay, this->history_[param_id]-> mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); + break; + case Caffe::GPU: +#ifndef CPU_ONLY + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_axpby(net_params[param_id] -> count(), + Dtype(1-rms_decay), this->update_[param_id]->gpu_data(), + rms_decay, this->history_[param_id]-> mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), Dtype(0), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << 
Caffe::mode(); + } +} + +INSTANTIATE_CLASS(RMSPropSolver); + +} // namespace caffe diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp new file mode 100644 index 0000000..89ef5ec --- /dev/null +++ b/src/caffe/solvers/sgd_solver.cpp @@ -0,0 +1,347 @@ +#include +#include + +#include "caffe/sgd_solvers.hpp" +#include "caffe/util/hdf5.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +namespace caffe { + +// Return the current learning rate. The currently implemented learning rate +// policies are as follows: +// - fixed: always return base_lr. +// - step: return base_lr * gamma ^ (floor(iter / step)) +// - exp: return base_lr * gamma ^ iter +// - inv: return base_lr * (1 + gamma * iter) ^ (- power) +// - multistep: similar to step but it allows non uniform steps defined by +// stepvalue +// - poly: the effective learning rate follows a polynomial decay, to be +// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) +// - sigmoid: the effective learning rate follows a sigmod decay +// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) +// +// where base_lr, max_iter, gamma, step, stepvalue and power are defined +// in the solver parameter protocol buffer, and iter is the current iteration. +template +Dtype SGDSolver::GetLearningRate() { + Dtype rate; + const string& lr_policy = this->param_.lr_policy(); + if (lr_policy == "fixed") { + rate = this->param_.base_lr(); + } else if (lr_policy == "step") { + this->current_step_ = this->iter_ / this->param_.stepsize(); + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "exp") { + rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); + } else if (lr_policy == "inv") { + rate = this->param_.base_lr() * + pow(Dtype(1) + this->param_.gamma() * this->iter_, + - this->param_.power()); + } else if (lr_policy == "multistep") { + if (this->current_step_ < this->param_.stepvalue_size() && + this->iter_ >= this->param_.stepvalue(this->current_step_)) { + this->current_step_++; + LOG(INFO) << "MultiStep Status: Iteration " << + this->iter_ << ", step = " << this->current_step_; + } + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "poly") { + rate = this->param_.base_lr() * pow(Dtype(1.) - + (Dtype(this->iter_) / Dtype(this->param_.max_iter())), + this->param_.power()); + } else if (lr_policy == "sigmoid") { + rate = this->param_.base_lr() * (Dtype(1.) / + (Dtype(1.) 
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
new file mode 100644
index 0000000..89ef5ec
--- /dev/null
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -0,0 +1,347 @@
+#include <string>
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+#include "caffe/util/hdf5.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
+
+namespace caffe {
+
+// Return the current learning rate. The currently implemented learning rate
+// policies are as follows:
+//    - fixed: always return base_lr.
+//    - step: return base_lr * gamma ^ (floor(iter / stepsize))
+//    - exp: return base_lr * gamma ^ iter
+//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+//    - multistep: similar to step but allows non-uniform steps defined by
+//      stepvalue
+//    - poly: the effective learning rate follows a polynomial decay, reaching
+//      zero at max_iter. return base_lr * (1 - iter/max_iter) ^ (power)
+//    - sigmoid: the effective learning rate follows a sigmoid decay
+//      return base_lr * ( 1/(1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, stepsize, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLearningRate() {
+  Dtype rate;
+  const string& lr_policy = this->param_.lr_policy();
+  if (lr_policy == "fixed") {
+    rate = this->param_.base_lr();
+  } else if (lr_policy == "step") {
+    this->current_step_ = this->iter_ / this->param_.stepsize();
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "exp") {
+    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
+  } else if (lr_policy == "inv") {
+    rate = this->param_.base_lr() *
+        pow(Dtype(1) + this->param_.gamma() * this->iter_,
+            - this->param_.power());
+  } else if (lr_policy == "multistep") {
+    if (this->current_step_ < this->param_.stepvalue_size() &&
+          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+      this->current_step_++;
+      LOG(INFO) << "MultiStep Status: Iteration " <<
+          this->iter_ << ", step = " << this->current_step_;
+    }
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "poly") {
+    rate = this->param_.base_lr() * pow(Dtype(1.) -
+        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+        this->param_.power());
+  } else if (lr_policy == "sigmoid") {
+    rate = this->param_.base_lr() * (Dtype(1.) /
+        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+          Dtype(this->param_.stepsize())))));
+  } else {
+    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+  }
+  return rate;
+}
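As a worked example of the step policy (hypothetical values, not from the patch): with base_lr = 0.01, gamma = 0.1, and stepsize = 100000, the rate stays at 0.01 for the first 100000 iterations, drops to 0.001, then to 0.0001. A standalone sketch mirroring the formula above:

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    // Evaluate the "step" schedule at a few sample iterations.
    int main() {
      const double base_lr = 0.01, gamma = 0.1;
      const int stepsize = 100000;
      for (int iter : {0, 99999, 100000, 250000}) {
        const int current_step = iter / stepsize;
        const double rate = base_lr * std::pow(gamma, current_step);
        std::printf("iter %d -> lr %g\n", iter, rate);  // 0.01, 0.01, 0.001, 0.0001
      }
      return 0;
    }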
+
+template <typename Dtype>
+void SGDSolver<Dtype>::PreSolve() {
+  // Initialize the history
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  history_.clear();
+  update_.clear();
+  temp_.clear();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const vector<int>& shape = net_params[i]->shape();
+    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+    update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ClipGradients() {
+  const Dtype clip_gradients = this->param_.clip_gradients();
+  if (clip_gradients < 0) { return; }
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  Dtype sumsq_diff = 0;
+  for (int i = 0; i < net_params.size(); ++i) {
+    sumsq_diff += net_params[i]->sumsq_diff();
+  }
+  const Dtype l2norm_diff = std::sqrt(sumsq_diff);
+  if (l2norm_diff > clip_gradients) {
+    Dtype scale_factor = clip_gradients / l2norm_diff;
+    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
+        << l2norm_diff << " > " << clip_gradients << ") "
+        << "by scale factor " << scale_factor;
+    for (int i = 0; i < net_params.size(); ++i) {
+      net_params[i]->scale_diff(scale_factor);
+    }
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ApplyUpdate() {
+  CHECK(Caffe::root_solver());
+  Dtype rate = GetLearningRate();
+  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+  }
+  ClipGradients();
+  for (int param_id = 0; param_id < this->net_->learnable_params().size();
+       ++param_id) {
+    Normalize(param_id);
+    Regularize(param_id);
+    ComputeUpdateValue(param_id, rate);
+  }
+  this->net_->Update();
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Normalize(int param_id) {
+  if (this->param_.iter_size() == 1) { return; }
+  // Scale gradient to counterbalance accumulation.
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    caffe_scal(net_params[param_id]->count(), accum_normalization,
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Regularize(int param_id) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_weight_decay =
+      this->net_->params_weight_decay();
+  Dtype weight_decay = this->param_.weight_decay();
+  string regularization_type = this->param_.regularization_type();
+  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    if (local_decay) {
+      if (regularization_type == "L2") {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      } else if (regularization_type == "L1") {
+        caffe_cpu_sign(net_params[param_id]->count(),
+            net_params[param_id]->cpu_data(),
+            temp_[param_id]->mutable_cpu_data());
+        caffe_axpy(net_params[param_id]->count(),
+            local_decay,
+            temp_[param_id]->cpu_data(),
+            net_params[param_id]->mutable_cpu_diff());
+      } else {
+        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+      }
+    }
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    if (local_decay) {
+      if (regularization_type == "L2") {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            net_params[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      } else if (regularization_type == "L1") {
+        caffe_gpu_sign(net_params[param_id]->count(),
+            net_params[param_id]->gpu_data(),
+            temp_[param_id]->mutable_gpu_data());
+        caffe_gpu_axpy(net_params[param_id]->count(),
+            local_decay,
+            temp_[param_id]->gpu_data(),
+            net_params[param_id]->mutable_gpu_diff());
+      } else {
+        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+      }
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+  const vector<float>& net_params_lr = this->net_->params_lr();
+  Dtype momentum = this->param_.momentum();
+  Dtype local_rate = rate * net_params_lr[param_id];
+  // Compute the update to history, then copy it to the parameter diff.
+  switch (Caffe::mode()) {
+  case Caffe::CPU: {
+    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->cpu_diff(), momentum,
+        history_[param_id]->mutable_cpu_data());
+    caffe_copy(net_params[param_id]->count(),
+        history_[param_id]->cpu_data(),
+        net_params[param_id]->mutable_cpu_diff());
+    break;
+  }
+  case Caffe::GPU: {
+#ifndef CPU_ONLY
+    caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+        net_params[param_id]->gpu_diff(), momentum,
+        history_[param_id]->mutable_gpu_data());
+    caffe_copy(net_params[param_id]->count(),
+        history_[param_id]->gpu_data(),
+        net_params[param_id]->mutable_gpu_diff());
+#else
+    NO_GPU;
+#endif
+    break;
+  }
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
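In scalar form, the momentum SGD update above reduces to the following (a sketch with hypothetical names, not the library API; Net::Update() later subtracts the returned diff from the parameter data):

    // Per-element view of SGDSolver<Dtype>::ComputeUpdateValue above.
    inline float sgd_momentum_step(float* h, float g,
                                   float local_rate, float momentum) {
      *h = momentum * *h + local_rate * g;  // caffe_cpu_axpby into history_
      return *h;                            // caffe_copy into the parameter diff
    }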
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
+  switch (this->param_.snapshot_format()) {
+    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+      SnapshotSolverStateToBinaryProto(model_filename);
+      break;
+    case caffe::SolverParameter_SnapshotFormat_HDF5:
+      SnapshotSolverStateToHDF5(model_filename);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported snapshot format.";
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
+    const string& model_filename) {
+  SolverState state;
+  state.set_iter(this->iter_);
+  state.set_learned_net(model_filename);
+  state.set_current_step(this->current_step_);
+  state.clear_history();
+  for (int i = 0; i < history_.size(); ++i) {
+    // Add history
+    BlobProto* history_blob = state.add_history();
+    history_[i]->ToProto(history_blob);
+  }
+  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
+  LOG(INFO)
+    << "Snapshotting solver state to binary proto file " << snapshot_filename;
+  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
+    const string& model_filename) {
+  string snapshot_filename =
+      Solver<Dtype>::SnapshotFilename(".solverstate.h5");
+  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
+  hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
+      H5P_DEFAULT, H5P_DEFAULT);
+  CHECK_GE(file_hid, 0)
+      << "Couldn't open " << snapshot_filename << " to save solver state.";
+  hdf5_save_int(file_hid, "iter", this->iter_);
+  hdf5_save_string(file_hid, "learned_net", model_filename);
+  hdf5_save_int(file_hid, "current_step", this->current_step_);
+  hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
+      H5P_DEFAULT);
+  CHECK_GE(history_hid, 0)
+      << "Error saving solver state to " << snapshot_filename << ".";
+  for (int i = 0; i < history_.size(); ++i) {
+    ostringstream oss;
+    oss << i;
+    hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
+  }
+  H5Gclose(history_hid);
+  H5Fclose(file_hid);
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
+    const string& state_file) {
+  SolverState state;
+  ReadProtoFromBinaryFile(state_file, &state);
+  this->iter_ = state.iter();
+  if (state.has_learned_net()) {
+    NetParameter net_param;
+    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
+    this->net_->CopyTrainedLayersFrom(net_param);
+  }
+  this->current_step_ = state.current_step();
+  CHECK_EQ(state.history_size(), history_.size())
+      << "Incorrect length of history blobs.";
+  LOG(INFO) << "SGDSolver: restoring history";
+  for (int i = 0; i < history_.size(); ++i) {
+    history_[i]->FromProto(state.history(i));
+  }
+}
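Which snapshot branch runs is chosen by snapshot_format in the solver definition. A hypothetical SolverParameter prototxt fragment (field names are real; the values and path are made up):

    snapshot: 5000                  # take a snapshot every 5000 iterations
    snapshot_prefix: "models/net"   # hypothetical output path prefix
    snapshot_format: HDF5           # or BINARYPROTO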
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+  hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
+  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
+  this->iter_ = hdf5_load_int(file_hid, "iter");
+  if (H5LTfind_dataset(file_hid, "learned_net")) {
+    string learned_net = hdf5_load_string(file_hid, "learned_net");
+    this->net_->CopyTrainedLayersFrom(learned_net);
+  }
+  this->current_step_ = hdf5_load_int(file_hid, "current_step");
+  hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
+  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
+  int state_history_size = hdf5_get_num_links(history_hid);
+  CHECK_EQ(state_history_size, history_.size())
+      << "Incorrect length of history blobs.";
+  for (int i = 0; i < history_.size(); ++i) {
+    ostringstream oss;
+    oss << i;
+    hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
+        kMaxBlobAxes, history_[i].get());
+  }
+  H5Gclose(history_hid);
+  H5Fclose(file_hid);
+}
+
+INSTANTIATE_CLASS(SGDSolver);
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 7ad7467..1767ad3 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -10,7 +10,7 @@
 #include "caffe/common.hpp"
 #include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
-#include "caffe/solver.hpp"
+#include "caffe/sgd_solvers.hpp"
 #include "caffe/util/io.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
diff --git a/src/caffe/test/test_solver.cpp b/src/caffe/test/test_solver.cpp
index ceabc9c..b181642 100644
--- a/src/caffe/test/test_solver.cpp
+++ b/src/caffe/test/test_solver.cpp
@@ -7,6 +7,7 @@
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/sgd_solvers.hpp"
 #include "caffe/solver.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
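With this split, code that previously pulled the concrete solver classes out of caffe/solver.hpp now includes caffe/sgd_solvers.hpp, as the two test hunks above show. A minimal usage sketch, assuming a valid solver definition at a hypothetical path:

    #include "caffe/sgd_solvers.hpp"

    int main() {
      // "solver.prototxt" is a hypothetical path to a SolverParameter text file.
      caffe::SGDSolver<float> solver("solver.prototxt");
      solver.Solve();  // run the optimization loop
      return 0;
    }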