Add gradient clipping -- limit L2 norm of parameter gradients
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 7 Oct 2014 06:46:48 +0000 (23:46 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 14 Feb 2015 01:28:10 +0000 (17:28 -0800)
include/caffe/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index fde6620..2510de7 100644 (file)
@@ -81,6 +81,7 @@ class SGDSolver : public Solver<Dtype> {
   void PreSolve();
   Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
+  virtual void ClipGradients();
   virtual void SnapshotSolverState(SolverState * state);
   virtual void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
index c2a39a5..8d93742 100644 (file)
@@ -75,7 +75,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 35 (last added: stepvalue)
+// SolverParameter next available ID: 36 (last added: clip_gradients)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -140,6 +140,11 @@ message SolverParameter {
   optional int32 stepsize = 13;
   // the stepsize for learning rate policy "multistep"
   repeated int32 stepvalue = 34;
+
+  // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
+  // whenever their actual L2 norm is larger.
+  optional float clip_gradients = 35 [default = -1];
+
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
   // whether to snapshot diff in the results or not. Snapshotting diff will help
index 100cb35..9866d7c 100644 (file)
@@ -437,6 +437,30 @@ void SGDSolver<Dtype>::PreSolve() {
   }
 }
 
+template <typename Dtype>
+void SGDSolver<Dtype>::ClipGradients() {
+  const Dtype clip_gradients = this->param_.clip_gradients();
+  if (clip_gradients < 0) { return; }
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  Dtype sumsq_diff = 0;
+  for (int i = 0; i < net_params.size(); ++i) {
+    if (this->net_->param_owners()[i] < 0) {
+      sumsq_diff += net_params[i]->sumsq_diff();
+    }
+  }
+  const Dtype l2norm_diff = std::sqrt(sumsq_diff);
+  if (l2norm_diff > clip_gradients) {
+    Dtype scale_factor = clip_gradients / l2norm_diff;
+    LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
+        << l2norm_diff << " > " << clip_gradients << ") "
+        << "by scale factor " << scale_factor;
+    for (int i = 0; i < net_params.size(); ++i) {
+      if (this->net_->param_owners()[i] < 0) {
+        net_params[i]->scale_diff(scale_factor);
+      }
+    }
+  }
+}
 
 template <typename Dtype>
 void SGDSolver<Dtype>::ComputeUpdateValue() {
@@ -449,6 +473,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
+  ClipGradients();
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
   string regularization_type = this->param_.regularization_type();
@@ -563,6 +588,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
+  SGDSolver<Dtype>::ClipGradients();
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
   string regularization_type = this->param_.regularization_type();
@@ -680,6 +706,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
+  SGDSolver<Dtype>::ClipGradients();
   Dtype weight_decay = this->param_.weight_decay();
   string regularization_type = this->param_.regularization_type();
   switch (Caffe::mode()) {