Added L1 regularization support for the weights
authorqipeng <pengrobertqi@163.com>
Thu, 24 Jul 2014 20:09:28 +0000 (13:09 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Mon, 1 Sep 2014 18:33:41 +0000 (11:33 -0700)
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index 9d5481c..4bf50d4 100644 (file)
@@ -73,7 +73,9 @@ class SGDSolver : public Solver<Dtype> {
   virtual void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
   // update maintains update related data and is not needed in snapshots.
-  vector<shared_ptr<Blob<Dtype> > > history_, update_;
+  // temp maintains other information that might be needed in computation
+  //   of gradients/updates and is not needed in snapshots
+  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
 
   DISABLE_COPY_AND_ASSIGN(SGDSolver);
 };
index 49a6e14..0bb5d11 100644 (file)
@@ -116,6 +116,9 @@ message SolverParameter {
   optional float power = 10; // The parameter to compute the learning rate.
   optional float momentum = 11; // The momentum value.
   optional float weight_decay = 12; // The weight decay.
+  // regularization types supported: L1 and L2
+  // controled by weight_decay
+  optional string regularization_type = 25 [default = "L2"];
   optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
index 8928c7b..223194b 100644 (file)
@@ -378,6 +378,7 @@ void SGDSolver<Dtype>::PreSolve() {
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   history_.clear();
   update_.clear();
+  temp_.clear();
   for (int i = 0; i < net_params.size(); ++i) {
     const Blob<Dtype>* net_param = net_params[i].get();
     history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
@@ -386,6 +387,9 @@ void SGDSolver<Dtype>::PreSolve() {
     update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
         net_param->num(), net_param->channels(), net_param->height(),
         net_param->width())));
+    temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
   }
 }
 
@@ -402,6 +406,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
+  string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {
   case Caffe::CPU:
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
@@ -412,11 +417,23 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
           net_params[param_id]->cpu_diff(), momentum,
           history_[param_id]->mutable_cpu_data());
       if (local_decay) {
-        // add weight decay
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay * local_rate,
-            net_params[param_id]->cpu_data(),
-            history_[param_id]->mutable_cpu_data());
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              net_params[param_id]->cpu_data(),
+              history_[param_id]->mutable_cpu_data());
+        } else if (regularization_type == "L1") {
+          caffe_cpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->cpu_data(),
+              temp_[param_id]->mutable_cpu_data());
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              temp_[param_id]->cpu_data(),
+              history_[param_id]->mutable_cpu_data());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
       }
       // copy
       caffe_copy(net_params[param_id]->count(),
@@ -434,11 +451,23 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
           net_params[param_id]->gpu_diff(), momentum,
           history_[param_id]->mutable_gpu_data());
       if (local_decay) {
-        // add weight decay
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay * local_rate,
-            net_params[param_id]->gpu_data(),
-            history_[param_id]->mutable_gpu_data());
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              net_params[param_id]->gpu_data(),
+              history_[param_id]->mutable_gpu_data());
+        } else if (regularization_type == "L1") {
+          caffe_gpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->gpu_data(),
+              temp_[param_id]->mutable_gpu_data());
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              temp_[param_id]->gpu_data(),
+              history_[param_id]->mutable_gpu_data());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
       }
       // copy
       caffe_copy(net_params[param_id]->count(),
@@ -487,6 +516,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
+  string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {
   case Caffe::CPU:
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
@@ -501,11 +531,23 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
           net_params[param_id]->cpu_diff(), momentum,
           this->history_[param_id]->mutable_cpu_data());
       if (local_decay) {
-        // add weight decay
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay * local_rate,
-            net_params[param_id]->cpu_data(),
-            this->history_[param_id]->mutable_cpu_data());
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              net_params[param_id]->cpu_data(),
+              this->history_[param_id]->mutable_cpu_data());
+        } else if (regularization_type == "L1") {
+          caffe_cpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->cpu_data(),
+              this->temp_[param_id]->mutable_cpu_data());
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              this->temp_[param_id]->cpu_data(),
+              this->history_[param_id]->mutable_cpu_data());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
       }
       // compute udpate: step back then over step
       caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
@@ -532,11 +574,23 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
           net_params[param_id]->gpu_diff(), momentum,
           this->history_[param_id]->mutable_gpu_data());
       if (local_decay) {
-        // add weight decay
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay * local_rate,
-            net_params[param_id]->gpu_data(),
-            this->history_[param_id]->mutable_gpu_data());
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              net_params[param_id]->gpu_data(),
+              this->history_[param_id]->mutable_gpu_data());
+        } else if (regularization_type == "L1") {
+          caffe_gpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->gpu_data(),
+              this->temp_[param_id]->mutable_gpu_data());
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay * local_rate,
+              this->temp_[param_id]->gpu_data(),
+              this->history_[param_id]->mutable_gpu_data());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
       }
       // compute udpate: step back then over step
       caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
@@ -568,6 +622,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
   Dtype weight_decay = this->param_.weight_decay();
+  string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {
   case Caffe::CPU:
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
@@ -575,11 +630,23 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
 
       if (local_decay) {
-        // add weight decay
-        caffe_axpy(net_params[param_id]->count(),
-            local_decay,
-            net_params[param_id]->cpu_data(),
-            net_params[param_id]->mutable_cpu_diff());
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay,
+              net_params[param_id]->cpu_data(),
+              this->history_[param_id]->mutable_cpu_data());
+        } else if (regularization_type == "L1") {
+          caffe_cpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->cpu_data(),
+              this->temp_[param_id]->mutable_cpu_data());
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay,
+              this->temp_[param_id]->cpu_data(),
+              this->history_[param_id]->mutable_cpu_data());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
       }
 
       // compute square of gradient in update
@@ -619,11 +686,23 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
 
       if (local_decay) {
-        // add weight decay
-        caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay,
-            net_params[param_id]->gpu_data(),
-            net_params[param_id]->mutable_gpu_diff());
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay,
+              net_params[param_id]->gpu_data(),
+              this->history_[param_id]->mutable_gpu_data());
+        } else if (regularization_type == "L1") {
+          caffe_gpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->gpu_data(),
+              this->temp_[param_id]->mutable_gpu_data());
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay,
+              this->temp_[param_id]->gpu_data(),
+              this->history_[param_id]->mutable_gpu_data());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
       }
 
       // compute square of gradient in update