From: qipeng Date: Thu, 24 Jul 2014 20:09:28 +0000 (-0700) Subject: Added L1 regularization support for the weights X-Git-Tag: submit/tizen/20180823.020014~620^2~62^2~16 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a683c40d63b79dcb4e2407d59335fb7f8129c82f;p=platform%2Fupstream%2Fcaffeonacl.git Added L1 regularization support for the weights --- diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 9d5481c..4bf50d4 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -73,7 +73,9 @@ class SGDSolver : public Solver { virtual void RestoreSolverState(const SolverState& state); // history maintains the historical momentum data. // update maintains update related data and is not needed in snapshots. - vector > > history_, update_; + // temp maintains other information that might be needed in computation + // of gradients/updates and is not needed in snapshots + vector > > history_, update_, temp_; DISABLE_COPY_AND_ASSIGN(SGDSolver); }; diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 49a6e14..0bb5d11 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -116,6 +116,9 @@ message SolverParameter { optional float power = 10; // The parameter to compute the learning rate. optional float momentum = 11; // The momentum value. optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controled by weight_decay + optional string regularization_type = 25 [default = "L2"]; optional int32 stepsize = 13; // the stepsize for learning rate policy "step" optional int32 snapshot = 14 [default = 0]; // The snapshot interval optional string snapshot_prefix = 15; // The prefix for the snapshot. diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 8928c7b..223194b 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -378,6 +378,7 @@ void SGDSolver::PreSolve() { vector > >& net_params = this->net_->params(); history_.clear(); update_.clear(); + temp_.clear(); for (int i = 0; i < net_params.size(); ++i) { const Blob* net_param = net_params[i].get(); history_.push_back(shared_ptr >(new Blob( @@ -386,6 +387,9 @@ void SGDSolver::PreSolve() { update_.push_back(shared_ptr >(new Blob( net_param->num(), net_param->channels(), net_param->height(), net_param->width()))); + temp_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); } } @@ -402,6 +406,7 @@ void SGDSolver::ComputeUpdateValue() { } Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { @@ -412,11 +417,23 @@ void SGDSolver::ComputeUpdateValue() { net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data()); if (local_decay) { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->cpu_data(), - history_[param_id]->mutable_cpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->cpu_data(), + history_[param_id]->mutable_cpu_data()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + temp_[param_id]->cpu_data(), + history_[param_id]->mutable_cpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // copy caffe_copy(net_params[param_id]->count(), @@ -434,11 +451,23 @@ void SGDSolver::ComputeUpdateValue() { net_params[param_id]->gpu_diff(), momentum, history_[param_id]->mutable_gpu_data()); if (local_decay) { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->gpu_data(), - history_[param_id]->mutable_gpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->gpu_data(), + history_[param_id]->mutable_gpu_data()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + temp_[param_id]->gpu_data(), + history_[param_id]->mutable_gpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // copy caffe_copy(net_params[param_id]->count(), @@ -487,6 +516,7 @@ void NesterovSolver::ComputeUpdateValue() { } Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { @@ -501,11 +531,23 @@ void NesterovSolver::ComputeUpdateValue() { net_params[param_id]->cpu_diff(), momentum, this->history_[param_id]->mutable_cpu_data()); if (local_decay) { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute udpate: step back then over step caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, @@ -532,11 +574,23 @@ void NesterovSolver::ComputeUpdateValue() { net_params[param_id]->gpu_diff(), momentum, this->history_[param_id]->mutable_gpu_data()); if (local_decay) { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute udpate: step back then over step caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, @@ -568,6 +622,7 @@ void AdaGradSolver::ComputeUpdateValue() { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; } Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { @@ -575,11 +630,23 @@ void AdaGradSolver::ComputeUpdateValue() { Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; if (local_decay) { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute square of gradient in update @@ -619,11 +686,23 @@ void AdaGradSolver::ComputeUpdateValue() { Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; if (local_decay) { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute square of gradient in update