added unit test for solvers and fixed solver bugs
author    qipeng <pengrobertqi@163.com>
Thu, 21 Aug 2014 03:07:53 +0000 (20:07 -0700)
committer Jeff Donahue <jeff.donahue@gmail.com>
Mon, 1 Sep 2014 18:33:41 +0000 (11:33 -0700)
src/caffe/solver.cpp
src/caffe/test/test_adagrad_solver.cpp [new file with mode: 0644]
src/caffe/test/test_nesterov_solver.cpp [new file with mode: 0644]
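The solver fix below reorders weight decay and the momentum/learning-rate step: regularization is now added to the blob's diff first (so the decay is no longer pre-scaled by the learning rate and, in AdaGrad's case, no longer mixed into the squared-gradient history), and only then is the diff blended into the history and copied back. A minimal standalone sketch of the corrected per-parameter SGD step, in plain C++ with hypothetical names rather than the Caffe math wrappers:

#include <vector>

// Sketch only: mirrors the order of operations in the solver.cpp hunks below
// (caffe_axpy for decay, caffe_cpu_axpby into the history, caffe_copy back to the diff).
void sgd_step(std::vector<float>& weights, std::vector<float>& diff,
              std::vector<float>& history,
              float local_rate, float momentum, float local_decay) {
  for (std::size_t i = 0; i < weights.size(); ++i) {
    diff[i] += local_decay * weights[i];                         // L2 weight decay on the gradient
    history[i] = local_rate * diff[i] + momentum * history[i];   // momentum history update
    diff[i] = history[i];                                        // Blob::Update() then applies weights -= diff
  }
}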

diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 52fd652..dcac4c1 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -413,28 +413,30 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
       // Compute the update value into the history, then copy it to the blob's diff.
       Dtype local_rate = rate * net_params_lr[param_id];
       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-          net_params[param_id]->cpu_diff(), momentum,
-          history_[param_id]->mutable_cpu_data());
+
       if (local_decay) {
         if (regularization_type == "L2") {
           // add weight decay
           caffe_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               net_params[param_id]->cpu_data(),
-              history_[param_id]->mutable_cpu_data());
+              net_params[param_id]->mutable_cpu_diff());
         } else if (regularization_type == "L1") {
           caffe_cpu_sign(net_params[param_id]->count(),
               net_params[param_id]->cpu_data(),
               temp_[param_id]->mutable_cpu_data());
           caffe_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               temp_[param_id]->cpu_data(),
-              history_[param_id]->mutable_cpu_data());
+              net_params[param_id]->mutable_cpu_diff());
         } else {
           LOG(FATAL) << "Unknown regularization type: " << regularization_type;
         }
       }
+
+      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+                net_params[param_id]->cpu_diff(), momentum,
+                history_[param_id]->mutable_cpu_data());
       // copy
       caffe_copy(net_params[param_id]->count(),
           history_[param_id]->cpu_data(),
@@ -447,28 +449,30 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
       // Compute the update value into the history, then copy it to the blob's diff.
       Dtype local_rate = rate * net_params_lr[param_id];
       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-          net_params[param_id]->gpu_diff(), momentum,
-          history_[param_id]->mutable_gpu_data());
+
       if (local_decay) {
         if (regularization_type == "L2") {
           // add weight decay
           caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               net_params[param_id]->gpu_data(),
-              history_[param_id]->mutable_gpu_data());
+              net_params[param_id]->mutable_gpu_diff());
         } else if (regularization_type == "L1") {
           caffe_gpu_sign(net_params[param_id]->count(),
               net_params[param_id]->gpu_data(),
               temp_[param_id]->mutable_gpu_data());
           caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               temp_[param_id]->gpu_data(),
-              history_[param_id]->mutable_gpu_data());
+              net_params[param_id]->mutable_gpu_diff());
         } else {
           LOG(FATAL) << "Unknown regularization type: " << regularization_type;
         }
       }
+
+      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+                net_params[param_id]->gpu_diff(), momentum,
+                history_[param_id]->mutable_gpu_data());
       // copy
       caffe_copy(net_params[param_id]->count(),
           history_[param_id]->gpu_data(),
@@ -526,28 +530,32 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
 
       Dtype local_rate = rate * net_params_lr[param_id];
       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
-          net_params[param_id]->cpu_diff(), momentum,
-          this->history_[param_id]->mutable_cpu_data());
+
       if (local_decay) {
         if (regularization_type == "L2") {
           // add weight decay
           caffe_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               net_params[param_id]->cpu_data(),
-              this->history_[param_id]->mutable_cpu_data());
+              net_params[param_id]->mutable_cpu_diff());
         } else if (regularization_type == "L1") {
           caffe_cpu_sign(net_params[param_id]->count(),
               net_params[param_id]->cpu_data(),
               this->temp_[param_id]->mutable_cpu_data());
           caffe_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               this->temp_[param_id]->cpu_data(),
-              this->history_[param_id]->mutable_cpu_data());
+              net_params[param_id]->mutable_cpu_diff());
         } else {
           LOG(FATAL) << "Unknown regularization type: " << regularization_type;
         }
       }
+
+      // update history
+      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+                net_params[param_id]->cpu_diff(), momentum,
+                this->history_[param_id]->mutable_cpu_data());
+
+      // compute update: step back then over-step
       caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
           this->history_[param_id]->cpu_data(), -momentum,
@@ -569,28 +577,32 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
 
       Dtype local_rate = rate * net_params_lr[param_id];
       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
-      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
-          net_params[param_id]->gpu_diff(), momentum,
-          this->history_[param_id]->mutable_gpu_data());
+
       if (local_decay) {
         if (regularization_type == "L2") {
           // add weight decay
           caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               net_params[param_id]->gpu_data(),
-              this->history_[param_id]->mutable_gpu_data());
+              net_params[param_id]->mutable_gpu_diff());
         } else if (regularization_type == "L1") {
           caffe_gpu_sign(net_params[param_id]->count(),
               net_params[param_id]->gpu_data(),
               this->temp_[param_id]->mutable_gpu_data());
           caffe_gpu_axpy(net_params[param_id]->count(),
-              local_decay * local_rate,
+              local_decay,
               this->temp_[param_id]->gpu_data(),
-              this->history_[param_id]->mutable_gpu_data());
+              net_params[param_id]->mutable_gpu_diff());
         } else {
           LOG(FATAL) << "Unknown regularization type: " << regularization_type;
         }
       }
+
+      // update history
+      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+                net_params[param_id]->gpu_diff(), momentum,
+                this->history_[param_id]->mutable_gpu_data());
+
+      // compute update: step back then over-step
       caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
           this->history_[param_id]->gpu_data(), -momentum,
@@ -635,7 +647,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
           caffe_axpy(net_params[param_id]->count(),
               local_decay,
               net_params[param_id]->cpu_data(),
-              this->history_[param_id]->mutable_cpu_data());
+              net_params[param_id]->mutable_cpu_diff());
         } else if (regularization_type == "L1") {
           caffe_cpu_sign(net_params[param_id]->count(),
               net_params[param_id]->cpu_data(),
@@ -643,7 +655,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
           caffe_axpy(net_params[param_id]->count(),
               local_decay,
               this->temp_[param_id]->cpu_data(),
-              this->history_[param_id]->mutable_cpu_data());
+              net_params[param_id]->mutable_cpu_diff());
         } else {
           LOG(FATAL) << "Unknown regularization type: " << regularization_type;
         }
@@ -691,7 +703,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
           caffe_gpu_axpy(net_params[param_id]->count(),
               local_decay,
               net_params[param_id]->gpu_data(),
-              this->history_[param_id]->mutable_gpu_data());
+              net_params[param_id]->mutable_gpu_diff());
         } else if (regularization_type == "L1") {
           caffe_gpu_sign(net_params[param_id]->count(),
               net_params[param_id]->gpu_data(),
@@ -699,7 +711,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
           caffe_gpu_axpy(net_params[param_id]->count(),
               local_decay,
               this->temp_[param_id]->gpu_data(),
-              this->history_[param_id]->mutable_gpu_data());
+              net_params[param_id]->mutable_gpu_diff());
         } else {
           LOG(FATAL) << "Unknown regularization type: " << regularization_type;
         }
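For reference, the scalar AdaGrad rule that the new test below re-derives in ComputeLeastSquaresUpdate (with the weight-decay term already folded into the gradient g_i, as in the fixed solver) is, in LaTeX form:

\begin{aligned}
h_i &\leftarrow h_i + g_i^2 \\
\Delta w_i &= \frac{\alpha \, g_i}{\sqrt{h_i} + \delta} \\
w_i &\leftarrow w_i - \Delta w_i
\end{aligned}

where \alpha is the learning rate, \delta is the solver's delta parameter, and h_i is the accumulated history for the i-th parameter.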
diff --git a/src/caffe/test/test_adagrad_solver.cpp b/src/caffe/test/test_adagrad_solver.cpp
new file mode 100644
index 0000000..45cf200
--- /dev/null
+++ b/src/caffe/test/test_adagrad_solver.cpp
@@ -0,0 +1,351 @@
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
+
+#include "gtest/gtest.h"
+
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+using std::ostringstream;
+
+namespace caffe {
+
+template <typename TypeParam>
+class AdaGradSolverTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  AdaGradSolverTest() :
+      seed_(1701), num_(5), channels_(3), height_(10), width_(10) {}
+
+  // MockAdaGradSolver: an AdaGradSolver with public history.
+  class MockAdaGradSolver : public AdaGradSolver<Dtype> {
+   public:
+    explicit MockAdaGradSolver(const SolverParameter& param) :
+        AdaGradSolver<Dtype>(param) {}
+    vector<shared_ptr<Blob<Dtype> > >& history() { return this->history_; }
+    Dtype delta() { return this->param_.delta(); }
+  };
+
+  shared_ptr<MockAdaGradSolver> solver_;
+  int seed_;
+  int num_, channels_, height_, width_;
+
+  virtual void InitSolverFromProtoString(const string& proto) {
+    SolverParameter param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
+    // Disable saving a final snapshot so the tests don't pollute the user's
+    // working directory with useless snapshots.
+    param.set_snapshot_after_train(false);
+    // Set the solver_mode according to current Caffe::mode.
+    switch (Caffe::mode()) {
+      case Caffe::CPU:
+        param.set_solver_mode(SolverParameter_SolverMode_CPU);
+        break;
+      case Caffe::GPU:
+        param.set_solver_mode(SolverParameter_SolverMode_GPU);
+        break;
+      default:
+        LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
+    }
+    solver_.reset(new MockAdaGradSolver(param));
+  }
+
+  void RunLeastSquaresSolver(const Dtype learning_rate,
+      const Dtype weight_decay, const Dtype momentum, const int num_iters) {
+    ostringstream proto;
+    proto <<
+       "max_iter: " << num_iters << " "
+       "base_lr: " << learning_rate << " "
+       "lr_policy: 'fixed' "
+       "net_param { "
+       "  name: 'TestNetwork' "
+       "  layers: { "
+       "    name: 'data' "
+       "    type: DUMMY_DATA "
+       "    dummy_data_param { "
+       "      num: " << num_ << " "
+       "      channels: " << channels_ << " "
+       "      height: " << height_ << " "
+       "      width: " << width_ << " "
+       "      channels: 1 "
+       "      height: 1 "
+       "      width: 1 "
+       "      data_filler { "
+       "        type: 'gaussian' "
+       "        std: 1.0 "
+       "      } "
+       "    } "
+       "    top: 'data' "
+       "    top: 'targets' "
+       "  } "
+       "  layers: { "
+       "    name: 'innerprod' "
+       "    type: INNER_PRODUCT "
+       "    inner_product_param { "
+       "      num_output: 1 "
+       "      weight_filler { "
+       "        type: 'gaussian' "
+       "        std: 1.0 "
+       "      } "
+       "      bias_filler { "
+       "        type: 'gaussian' "
+       "        std: 1.0 "
+       "      } "
+       "    } "
+       "    bottom: 'data' "
+       "    top: 'innerprod' "
+       "  } "
+       "  layers: { "
+       "    name: 'loss' "
+       "    type: EUCLIDEAN_LOSS "
+       "    bottom: 'innerprod' "
+       "    bottom: 'targets' "
+       "  } "
+       "} ";
+    if (weight_decay != 0) {
+      proto << "weight_decay: " << weight_decay << " ";
+    }
+    if (momentum != 0) {
+      proto << "momentum: " << momentum << " ";
+    }
+    Caffe::set_random_seed(this->seed_);
+    this->InitSolverFromProtoString(proto.str());
+    this->solver_->Solve();
+  }
+
+  // Compute an update value given the current state of the train net,
+  // using the analytical formula for the least squares gradient.
+  // updated_params will store the updated weight and bias results,
+  // using the blobs' diffs to hold the update values themselves.
+  void ComputeLeastSquaresUpdate(const Dtype learning_rate,
+      const Dtype weight_decay, const Dtype momentum,
+      vector<shared_ptr<Blob<Dtype> > >* updated_params) {
+    const int N = num_;
+    const int D = channels_ * height_ * width_;
+
+    // Run a forward pass, and manually compute the update values from the
+    // result.
+    Net<Dtype>& net = *this->solver_->net();
+    vector<Blob<Dtype>*> empty_bottom_vec;
+    net.Forward(empty_bottom_vec);
+    ASSERT_TRUE(net.has_blob("data"));
+    const Blob<Dtype>& data = *net.blob_by_name("data");
+    ASSERT_TRUE(net.has_blob("targets"));
+    const Blob<Dtype>& targets = *net.blob_by_name("targets");
+    ASSERT_TRUE(net.has_layer("innerprod"));
+    const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
+        net.layer_by_name("innerprod")->blobs();
+    const int num_param_blobs = 2;
+    ASSERT_EQ(num_param_blobs, param_blobs.size());
+    const Blob<Dtype>& weights = *param_blobs[0];
+    const Blob<Dtype>& bias = *param_blobs[1];
+    ASSERT_EQ(D * N, data.count());
+    ASSERT_EQ(N, targets.count());
+    ASSERT_EQ(D, weights.count());
+    ASSERT_EQ(1, bias.count());
+
+    updated_params->clear();
+    updated_params->resize(num_param_blobs);
+    for (int i = 0; i < num_param_blobs; ++i) {
+      (*updated_params)[i].reset(new Blob<Dtype>());
+    }
+    Blob<Dtype>& updated_weights = *(*updated_params)[0];
+    updated_weights.ReshapeLike(weights);
+    Blob<Dtype>& updated_bias = *(*updated_params)[1];
+    updated_bias.ReshapeLike(bias);
+
+    for (int i = 0; i <= D; ++i) {
+      // Compute the derivative with respect to the ith weight (i.e., the ith
+      // element of the gradient).
+      Dtype grad = 0;
+      for (int j = 0; j <= D; ++j) {
+        // Compute element (i, j) of X^T * X.
+        Dtype element = 0;
+        for (int k = 0; k < N; ++k) {
+          // (i, k) in X^T (== (k, i) in X) times (k, j) in X.
+          const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
+          const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
+          element += element_i * element_j;
+        }
+        if (j == D) {
+          grad += element * bias.cpu_data()[0];
+        } else {
+          grad += element * weights.cpu_data()[j];
+        }
+      }
+      for (int k = 0; k < N; ++k) {
+        const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
+        grad -= element_i * targets.cpu_data()[k];
+      }
+      // Scale the gradient over the N samples.
+      grad /= N;
+      // Add the weight decay to the gradient.
+      grad += weight_decay *
+          ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
+      // Finally, compute update
+      const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
+      Dtype delta = solver_->delta();
+      ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
+      Dtype update_value, temp;
+      if (i == D) {
+        temp = history[1]->cpu_data()[0];
+        temp += grad * grad;
+        update_value = learning_rate * grad / (std::sqrt(temp) + delta);
+        updated_bias.mutable_cpu_diff()[0] = update_value;
+        updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
+      } else {
+        temp = history[0]->cpu_data()[i];
+        temp += grad * grad;
+        update_value = learning_rate * grad / (std::sqrt(temp) + delta);
+        updated_weights.mutable_cpu_diff()[i] = update_value;
+        updated_weights.mutable_cpu_data()[i] =
+            weights.cpu_data()[i] - update_value;
+      }
+    }
+  }
+
+  void CheckLeastSquaresUpdate(
+      const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
+    const int D = channels_ * height_ * width_;
+
+    const Blob<Dtype>& updated_weights = *updated_params[0];
+    const Blob<Dtype>& updated_bias = *updated_params[1];
+
+    Net<Dtype>& net = *this->solver_->net();
+    ASSERT_TRUE(net.has_layer("innerprod"));
+    const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
+        net.layer_by_name("innerprod")->blobs();
+    ASSERT_EQ(2, param_blobs.size());
+    const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
+    ASSERT_EQ(D, solver_updated_weights.count());
+    const double kPrecision = 1e-3;
+    const double kMinPrecision = 1e-7;
+    for (int i = 0; i < D; ++i) {
+      const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
+      const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
+      const Dtype error_margin = std::max(kMinPrecision, kPrecision *
+          std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
+      EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
+    }
+    const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
+    ASSERT_EQ(1, solver_updated_bias_blob.count());
+    const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
+    const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
+    const Dtype error_margin = std::max(kMinPrecision, kPrecision *
+          std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
+    EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
+
+    // Check the solver's history -- should contain the previous update value.
+//    vector<shared_ptr<Blob<Dtype> > >& history = this->solver_->history();
+//    ASSERT_EQ(2, history.size());
+//    for (int i = 0; i < D; ++i) {
+//      const Dtype expected_history = updated_weights.cpu_diff()[i];
+//      const Dtype solver_history = history[0]->cpu_data()[i];
+//      const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
+//          std::min(fabs(expected_history), fabs(solver_history)));
+//      EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
+//    }
+//    const Dtype expected_history = updated_bias.cpu_diff()[0];
+//    const Dtype solver_history = history[1]->cpu_data()[0];
+//    const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
+//        std::min(fabs(expected_history), fabs(solver_history)));
+//    EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
+  }
+
+  // Test that the correct update is computed for a regularized least squares
+  // problem:
+  //
+  //            E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
+  //   \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
+  //
+  // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
+  // w \in R^{(d+1) x 1} ((d+1)th element is the bias)
+  // y \in R^{n x 1}
+  // lambda is weight_decay
+  //
+  // TestLeastSquaresUpdate works "inductively", assuming that the solver
+  // correctly updates the net K (= iter_to_check) times, then given the history
+  // from the Kth update, we compute the (K+1)th update and check that it
+  // matches the solver's (K+1)th update.
+  void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
+      const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
+      const int iter_to_check = 0) {
+    // Initialize the solver and run K (= iter_to_check) solver iterations.
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
+
+    // Compute the (K+1)th update using the analytic least squares gradient.
+    vector<shared_ptr<Blob<Dtype> > > updated_params;
+    ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
+                              &updated_params);
+
+    // Reinitialize the solver and run K+1 solver iterations.
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+                          iter_to_check + 1);
+
+    // Check that the solver's solution matches ours.
+    CheckLeastSquaresUpdate(updated_params);
+  }
+};
+
+TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->TestLeastSquaresUpdate();
+}
+
+TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneTenth) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.1;
+  this->TestLeastSquaresUpdate(kLearningRate);
+}
+
+TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.5;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
+}
+/*
+TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithMomentum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.5;
+  const int kNumIters = 1;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.5;
+  const int kNumIters = 5;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+*/
+TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.0;
+  const int kNumIters = 5;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+}  // namespace caffe
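Similarly, the Nesterov update that the next test re-derives ("step back then over-step" in the code comments) is, per parameter:

\begin{aligned}
v' &= \mu v + \alpha g \\
\Delta w &= (1 + \mu)\, v' - \mu v \\
w &\leftarrow w - \Delta w
\end{aligned}

with momentum \mu, learning rate \alpha, regularized gradient g, and momentum history v (v' replaces v after the step).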
diff --git a/src/caffe/test/test_nesterov_solver.cpp b/src/caffe/test/test_nesterov_solver.cpp
new file mode 100644
index 0000000..f2fcba3
--- /dev/null
+++ b/src/caffe/test/test_nesterov_solver.cpp
@@ -0,0 +1,351 @@
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
+
+#include "gtest/gtest.h"
+
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+using std::ostringstream;
+
+namespace caffe {
+
+template <typename TypeParam>
+class NesterovSolverTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  NesterovSolverTest() :
+      seed_(1701), num_(5), channels_(3), height_(10), width_(10) {}
+
+  // MockNesterovSolver: a NesterovSolver with public history.
+  class MockNesterovSolver : public NesterovSolver<Dtype> {
+   public:
+    explicit MockNesterovSolver(const SolverParameter& param) :
+        NesterovSolver<Dtype>(param) {}
+    vector<shared_ptr<Blob<Dtype> > >& history() { return this->history_; }
+  };
+
+  shared_ptr<MockNesterovSolver> solver_;
+  int seed_;
+  int num_, channels_, height_, width_;
+
+  virtual void InitSolverFromProtoString(const string& proto) {
+    SolverParameter param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
+    // Disable saving a final snapshot so the tests don't pollute the user's
+    // working directory with useless snapshots.
+    param.set_snapshot_after_train(false);
+    // Set the solver_mode according to current Caffe::mode.
+    switch (Caffe::mode()) {
+      case Caffe::CPU:
+        param.set_solver_mode(SolverParameter_SolverMode_CPU);
+        break;
+      case Caffe::GPU:
+        param.set_solver_mode(SolverParameter_SolverMode_GPU);
+        break;
+      default:
+        LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
+    }
+    solver_.reset(new MockNesterovSolver(param));
+  }
+
+  void RunLeastSquaresSolver(const Dtype learning_rate,
+      const Dtype weight_decay, const Dtype momentum, const int num_iters) {
+    ostringstream proto;
+    proto <<
+       "max_iter: " << num_iters << " "
+       "base_lr: " << learning_rate << " "
+       "lr_policy: 'fixed' "
+       "net_param { "
+       "  name: 'TestNetwork' "
+       "  layers: { "
+       "    name: 'data' "
+       "    type: DUMMY_DATA "
+       "    dummy_data_param { "
+       "      num: " << num_ << " "
+       "      channels: " << channels_ << " "
+       "      height: " << height_ << " "
+       "      width: " << width_ << " "
+       "      channels: 1 "
+       "      height: 1 "
+       "      width: 1 "
+       "      data_filler { "
+       "        type: 'gaussian' "
+       "        std: 1.0 "
+       "      } "
+       "    } "
+       "    top: 'data' "
+       "    top: 'targets' "
+       "  } "
+       "  layers: { "
+       "    name: 'innerprod' "
+       "    type: INNER_PRODUCT "
+       "    inner_product_param { "
+       "      num_output: 1 "
+       "      weight_filler { "
+       "        type: 'gaussian' "
+       "        std: 1.0 "
+       "      } "
+       "      bias_filler { "
+       "        type: 'gaussian' "
+       "        std: 1.0 "
+       "      } "
+       "    } "
+       "    bottom: 'data' "
+       "    top: 'innerprod' "
+       "  } "
+       "  layers: { "
+       "    name: 'loss' "
+       "    type: EUCLIDEAN_LOSS "
+       "    bottom: 'innerprod' "
+       "    bottom: 'targets' "
+       "  } "
+       "} ";
+    if (weight_decay != 0) {
+      proto << "weight_decay: " << weight_decay << " ";
+    }
+    if (momentum != 0) {
+      proto << "momentum: " << momentum << " ";
+    }
+    Caffe::set_random_seed(this->seed_);
+    this->InitSolverFromProtoString(proto.str());
+    this->solver_->Solve();
+  }
+
+  // Compute an update value given the current state of the train net,
+  // using the analytical formula for the least squares gradient.
+  // updated_params will store the updated weight and bias results,
+  // using the blobs' diffs to hold the update values themselves.
+  void ComputeLeastSquaresUpdate(const Dtype learning_rate,
+      const Dtype weight_decay, const Dtype momentum,
+      vector<shared_ptr<Blob<Dtype> > >* updated_params) {
+    const int N = num_;
+    const int D = channels_ * height_ * width_;
+
+    // Run a forward pass, and manually compute the update values from the
+    // result.
+    Net<Dtype>& net = *this->solver_->net();
+    vector<Blob<Dtype>*> empty_bottom_vec;
+    net.Forward(empty_bottom_vec);
+    ASSERT_TRUE(net.has_blob("data"));
+    const Blob<Dtype>& data = *net.blob_by_name("data");
+    ASSERT_TRUE(net.has_blob("targets"));
+    const Blob<Dtype>& targets = *net.blob_by_name("targets");
+    ASSERT_TRUE(net.has_layer("innerprod"));
+    const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
+        net.layer_by_name("innerprod")->blobs();
+    const int num_param_blobs = 2;
+    ASSERT_EQ(num_param_blobs, param_blobs.size());
+    const Blob<Dtype>& weights = *param_blobs[0];
+    const Blob<Dtype>& bias = *param_blobs[1];
+    ASSERT_EQ(D * N, data.count());
+    ASSERT_EQ(N, targets.count());
+    ASSERT_EQ(D, weights.count());
+    ASSERT_EQ(1, bias.count());
+
+    updated_params->clear();
+    updated_params->resize(num_param_blobs);
+    for (int i = 0; i < num_param_blobs; ++i) {
+      (*updated_params)[i].reset(new Blob<Dtype>());
+    }
+    Blob<Dtype>& updated_weights = *(*updated_params)[0];
+    updated_weights.ReshapeLike(weights);
+    Blob<Dtype>& updated_bias = *(*updated_params)[1];
+    updated_bias.ReshapeLike(bias);
+
+    for (int i = 0; i <= D; ++i) {
+      // Compute the derivative with respect to the ith weight (i.e., the ith
+      // element of the gradient).
+      Dtype grad = 0;
+      for (int j = 0; j <= D; ++j) {
+        // Compute element (i, j) of X^T * X.
+        Dtype element = 0;
+        for (int k = 0; k < N; ++k) {
+          // (i, k) in X^T (== (k, i) in X) times (k, j) in X.
+          const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
+          const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
+          element += element_i * element_j;
+        }
+        if (j == D) {
+          grad += element * bias.cpu_data()[0];
+        } else {
+          grad += element * weights.cpu_data()[j];
+        }
+      }
+      for (int k = 0; k < N; ++k) {
+        const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
+        grad -= element_i * targets.cpu_data()[k];
+      }
+      // Scale the gradient over the N samples.
+      grad /= N;
+      // Add the weight decay to the gradient.
+      grad += weight_decay *
+          ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
+      // Finally, add any momentum.
+      const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
+      ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
+      Dtype update_value = learning_rate * grad, temp;
+      if (i == D) {
+        temp = history[1]->cpu_data()[0] * momentum;
+        update_value += temp;  // update history
+        // step back then over-step
+        update_value = (1 + momentum) * update_value - temp;
+        updated_bias.mutable_cpu_diff()[0] = update_value;
+        updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
+      } else {
+        temp = history[0]->cpu_data()[i] * momentum;
+        update_value += temp;  // update history
+        // step back then over-step
+        update_value = (1 + momentum) * update_value - temp;
+        updated_weights.mutable_cpu_diff()[i] = update_value;
+        updated_weights.mutable_cpu_data()[i] =
+            weights.cpu_data()[i] - update_value;
+      }
+    }
+  }
+
+  void CheckLeastSquaresUpdate(
+      const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
+    const int D = channels_ * height_ * width_;
+
+    const Blob<Dtype>& updated_weights = *updated_params[0];
+    const Blob<Dtype>& updated_bias = *updated_params[1];
+
+    Net<Dtype>& net = *this->solver_->net();
+    ASSERT_TRUE(net.has_layer("innerprod"));
+    const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
+        net.layer_by_name("innerprod")->blobs();
+    ASSERT_EQ(2, param_blobs.size());
+    const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
+    ASSERT_EQ(D, solver_updated_weights.count());
+    const double kPrecision = 1e-3;
+    const double kMinPrecision = 1e-7;
+    for (int i = 0; i < D; ++i) {
+      const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
+      const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
+      const Dtype error_margin = std::max(kMinPrecision, kPrecision *
+          std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
+      EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
+    }
+    const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
+    ASSERT_EQ(1, solver_updated_bias_blob.count());
+    const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
+    const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
+    const Dtype error_margin = std::max(kMinPrecision, kPrecision *
+          std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
+    EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);
+
+    // Check the solver's history -- should contain the previous update value.
+//    vector<shared_ptr<Blob<Dtype> > >& history = this->solver_->history();
+//    ASSERT_EQ(2, history.size());
+//    for (int i = 0; i < D; ++i) {
+//      const Dtype expected_history = updated_weights.cpu_diff()[i];
+//      const Dtype solver_history = history[0]->cpu_data()[i];
+//      const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
+//          std::min(fabs(expected_history), fabs(solver_history)));
+//      EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
+//    }
+//    const Dtype expected_history = updated_bias.cpu_diff()[0];
+//    const Dtype solver_history = history[1]->cpu_data()[0];
+//    const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
+//        std::min(fabs(expected_history), fabs(solver_history)));
+//    EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
+  }
+
+  // Test that the correct update is computed for a regularized least squares
+  // problem:
+  //
+  //            E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
+  //   \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
+  //
+  // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
+  // w \in R^{(d+1) x 1} ((d+1)th element is the bias)
+  // y \in R^{n x 1}
+  // lambda is weight_decay
+  //
+  // TestLeastSquaresUpdate works "inductively", assuming that the solver
+  // correctly updates the net K (= iter_to_check) times, then given the history
+  // from the Kth update, we compute the (K+1)th update and check that it
+  // matches the solver's (K+1)th update.
+  void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
+      const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
+      const int iter_to_check = 0) {
+    // Initialize the solver and run K (= iter_to_check) solver iterations.
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
+
+    // Compute the (K+1)th update using the analytic least squares gradient.
+    vector<shared_ptr<Blob<Dtype> > > updated_params;
+    ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
+                              &updated_params);
+
+    // Reinitialize the solver and run K+1 solver iterations.
+    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+                          iter_to_check + 1);
+
+    // Check that the solver's solution matches ours.
+    CheckLeastSquaresUpdate(updated_params);
+  }
+};
+
+TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->TestLeastSquaresUpdate();
+}
+
+TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneTenth) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.1;
+  this->TestLeastSquaresUpdate(kLearningRate);
+}
+
+TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.5;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
+}
+
+TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.5;
+  const int kNumIters = 1;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 1.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.5;
+  const int kNumIters = 5;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.01;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.9;
+  const int kNumIters = 5;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+}  // namespace caffe
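Assuming a standard Caffe Makefile build, the new cases can be run in isolation via gtest's filter flag on the test binary (test_all.testbin under the build's test directory in typical setups), e.g.:

  test_all.testbin --gtest_filter='*AdaGradSolverTest*:*NesterovSolverTest*'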