solver
authorYangqing Jia <jiayq84@gmail.com>
Mon, 30 Sep 2013 21:08:39 +0000 (14:08 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 30 Sep 2013 21:08:39 +0000 (14:08 -0700)
src/caffe/blob.cpp
src/caffe/layers/data_layer.cpp
src/caffe/optimization/solver.cpp
src/caffe/proto/caffe.proto
src/caffe/pyutil/convert.py
src/caffe/test/data/simple_linear_regression_data.py [new file with mode: 0644]
src/caffe/test/test_solver_linear_regression.cpp [new file with mode: 0644]
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp

index d31ba72..6838036 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
 
 namespace caffe {
 
@@ -86,9 +87,24 @@ Dtype* Blob<Dtype>::mutable_gpu_diff() {
 
 template <typename Dtype>
 void Blob<Dtype>::Update() {
-  // not implemented yet.
-  LOG(FATAL) << "not implemented";
   // We will perform update based on where the data is located.
+  switch (data_->head()) {
+  case SyncedMemory::HEAD_AT_CPU:
+    // perform computation on CPU
+    caffe_axpy<Dtype>(count_, Dtype(-1),
+        reinterpret_cast<const Dtype*>(diff_->cpu_data()),
+        reinterpret_cast<Dtype*>(data_->mutable_cpu_data()));
+    break;
+  case SyncedMemory::HEAD_AT_GPU:
+  case SyncedMemory::SYNCED:
+    // perform computation on GPU
+    caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
+        reinterpret_cast<const Dtype*>(diff_->gpu_data()),
+        reinterpret_cast<Dtype*>(data_->mutable_gpu_data()));
+    break;
+  default:
+    LOG(FATAL) << "Syncedmem not initialized.";
+  }
 }
 
 template <typename Dtype>
index e527bd3..d42a810 100644 (file)
@@ -40,7 +40,7 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   // label
   (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
   // datum size
-  datum_size_ = datum.data().size();
+    datum_size_ = datum.channels() * datum.height() * datum.width();
 }
 
 template <typename Dtype>
@@ -51,13 +51,25 @@ void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* top_label = (*top)[1]->mutable_cpu_data();
   const Dtype scale = this->layer_param_.scale();
   const Dtype subtraction = this->layer_param_.subtraction();
+  // LOG(ERROR) << "Debug code on";
+  // if (true) {
+  //   iter_->SeekToFirst();
+  // }
   for (int i = 0; i < this->layer_param_.batchsize(); ++i) {
     // get a blob
     datum.ParseFromString(iter_->value().ToString());
     const string& data = datum.data();
-    for (int j = 0; j < datum_size_; ++j) {
-      top_data[i * datum_size_ + j] =
-          (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+    // we will prefer to use data() first, and then try float_data()
+    if (data.size()) {
+      for (int j = 0; j < datum_size_; ++j) {
+        top_data[i * datum_size_ + j] =
+            (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+      }
+    } else {
+      for (int j = 0; j < datum_size_; ++j) {
+        top_data[i * datum_size_ + j] =
+            (datum.float_data(j) * scale) - subtraction;
+      }
     }
     top_label[i] = datum.label();
     // go to the next iter
index 4f343c8..2df872e 100644 (file)
@@ -1,12 +1,16 @@
 // Copyright Yangqing Jia 2013
 
+#include <algorithm>
 #include <fstream>
 #include <string>
 
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/net.hpp"
+#include "caffe/util/math_functions.hpp"
 #include "caffe/optimization/solver.hpp"
 
+using std::max;
+using std::min;
 using std::stringstream;
 using std::ofstream;
 
@@ -23,13 +27,16 @@ void Solver<Dtype>::Solve(Net<Dtype>* net) {
   while (iter_++ < param_.max_iter()) {
     Dtype loss = net_->ForwardBackWard(bottom_vec);
     ComputeUpdateValue();
-    net->Update();
+    net_->Update();
 
     // Check if we need to do snapshot
-    if (iter_ % param_.snapshot()) {
+    if (param_.snapshot() > 0 && iter_ % param_.snapshot()) {
       // TODO(Yangqing): snapshot
+      NOT_IMPLEMENTED;
+    }
+    if (param_.display()) {
+      LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
     }
-    LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
   }
   LOG(INFO) << "Optimization Done.";
 }
@@ -62,20 +69,20 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   } else if (lr_policy == "inv") {
     rate = this->param_.base_lr() *
         pow(Dtype(1) + this->param_.gamma() * this->iter_,
-            this->param_.power());
+            this->param_.power());
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
-  rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
+  rate = min(max(rate, Dtype(this->param_.min_lr())),
+      Dtype(this->param_.max_lr()));
   return rate;
 }
 
 template <typename Dtype>
 void SGDSolver<Dtype>::ComputeUpdateValue() {
   // First of all, see if we need to initialize the history
-  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
-  if (this->iter_ == 1 && this->param_.momentum() > 0) {
-    LOG(INFO) << "Using momentum " << this->param_.momentum();
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  if (history_.size() == 0 && this->param_.momentum() > 0) {
     for (int i = 0; i < net_params.size(); ++i) {
       const Blob<Dtype>* net_param = net_params[i].get();
       history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
@@ -85,28 +92,47 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   }
   // get the learning rate
   Dtype rate = GetLearningRate();
-  if (this->param_.momentum == 0) {
+  if (this->param_.momentum() == 0) {
     for (int i = 0; i < net_params.size(); ++i) {
       switch (Caffe::mode()) {
       case Caffe::CPU:
         caffe_scal(net_params[i]->count(), rate,
-            net_params[i]->mutable_cpu_data());
+            net_params[i]->mutable_cpu_diff());
         break;
       case Caffe::GPU:
         caffe_gpu_scal(net_params[i]->count(), rate,
-            net_params[i]->mutable_gpu_data());
+            net_params[i]->mutable_gpu_diff());
         break;
       default:
         LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
       }
     }
   } else {
-    NOT_IMPLEMENTED;
+    // Need to maintain momentum
+    for (int i = 0; i < net_params.size(); ++i) {
+      switch (Caffe::mode()) {
+      case Caffe::CPU:
+        caffe_axpby(net_params[i]->count(), rate,
+            net_params[i]->cpu_diff(), Dtype(this->param_.momentum()),
+            history_[i]->mutable_cpu_data());
+        caffe_copy(net_params[i]->count(), history_[i]->cpu_data(),
+            net_params[i]->mutable_cpu_diff());
+        break;
+      case Caffe::GPU:
+        caffe_gpu_axpby(net_params[i]->count(), rate,
+            net_params[i]->gpu_diff(), Dtype(this->param_.momentum()),
+            history_[i]->mutable_gpu_data());
+        caffe_gpu_copy(net_params[i]->count(), history_[i]->gpu_data(),
+            net_params[i]->mutable_gpu_diff());
+        break;
+      default:
+        LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+      }
+    }
   }
 }
 
-
-
 INSTANTIATE_CLASS(Solver);
+INSTANTIATE_CLASS(SGDSolver);
 
 }  // namespace caffe
\ No newline at end of file
index 012ad56..c3632c1 100644 (file)
@@ -18,6 +18,8 @@ message Datum {
   // the actual image data, in bytes
   optional bytes data = 4;
   optional int32 label = 5;
+  // Optionally, the datum could also hold float data.
+  repeated float float_data = 6;
 }
 
 message FillerParameter {
@@ -84,7 +86,7 @@ message SolverParameter {
   optional float base_lr = 1; // The base learning rate
   optional int32 display = 2; // display options. 0 = no display
   optional int32 max_iter = 3; // the maximum number of iterations
-  optional int32 snapshot = 4; // The snapshot interval
+  optional int32 snapshot = 4 [default = 0]; // The snapshot interval
   optional string lr_policy = 5; // The learning rate decay policy.
   optional float min_lr = 6 [default = 0]; // The mininum learning rate
   optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
index efcea42..8a76a50 100644 (file)
@@ -20,9 +20,10 @@ def array_to_blobproto(arr):
 def array_to_datum(arr):
   if arr.ndim != 3:
     raise ValueError('Incorrect array shape.')
-  if arr.dtype != np.uint8:
-    raise TypeError('Input array has to be of type uint8.')
   datum = caffe_pb2.Datum()
   datum.channels, datum.height, datum.width = arr.shape
-  datum.data = arr.tostring()
+  if arr.dtype == np.uint8:
+    datum.data = arr.tostring()
+  else:
+    datum.float_data.extend(arr.flat)
   return datum
diff --git a/src/caffe/test/data/simple_linear_regression_data.py b/src/caffe/test/data/simple_linear_regression_data.py
new file mode 100644 (file)
index 0000000..e8fe840
--- /dev/null
@@ -0,0 +1,16 @@
+"""This script generates the mnist train and test leveldbs used in the
+test.
+"""
+from caffe.pyutil import convert
+import numpy as np
+import leveldb
+
+db = leveldb.LevelDB('simple-linear-regression-leveldb')
+
+for i in range(1000):
+  label = np.random.randint(2) * 2 - 1
+  arr = np.random.randn(2,1,1) + label
+  datum = convert.array_to_datum(arr)
+  datum.label = label
+  db.Put('%d' % (i), datum.SerializeToString())
+del db
diff --git a/src/caffe/test/test_solver_linear_regression.cpp b/src/caffe/test/test_solver_linear_regression.cpp
new file mode 100644 (file)
index 0000000..f368908
--- /dev/null
@@ -0,0 +1,76 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <gtest/gtest.h>
+
+#include <cstring>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class SolverTest : public ::testing::Test {};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(SolverTest, Dtypes);
+
+TYPED_TEST(SolverTest, TestSolve) {
+  Caffe::set_mode(Caffe::GPU);
+
+  NetParameter net_param;
+  ReadProtoFromTextFile("caffe/test/data/linear_regression.prototxt",
+      &net_param);
+  // check if things are right
+  EXPECT_EQ(net_param.layers_size(), 3);
+  EXPECT_EQ(net_param.input_size(), 0);
+  vector<Blob<TypeParam>*> bottom_vec;
+  Net<TypeParam> caffe_net(net_param, bottom_vec);
+  EXPECT_EQ(caffe_net.layer_names().size(), 3);
+  EXPECT_EQ(caffe_net.blob_names().size(), 3);
+
+  // Run the network without training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+
+  SolverParameter solver_param;
+  solver_param.set_base_lr(0.1);
+  solver_param.set_display(0);
+  solver_param.set_max_iter(100);
+  solver_param.set_lr_policy("inv");
+  solver_param.set_gamma(1.);
+  solver_param.set_power(0.75);
+  solver_param.set_momentum(0.9);
+
+  LOG(ERROR) << "Starting Optimization";
+  SGDSolver<TypeParam> solver(solver_param);
+  solver.Solve(&caffe_net);
+  LOG(ERROR) << "Optimization Done.";
+  LOG(ERROR) << "Weight: " << caffe_net.params()[0]->cpu_data()[0] << ", "
+      << caffe_net.params()[0]->cpu_data()[1];
+  LOG(ERROR) << "Bias: " << caffe_net.params()[1]->cpu_data()[0];
+
+  EXPECT_GE(caffe_net.params()[0]->cpu_data()[0], 0.3);
+  EXPECT_LE(caffe_net.params()[0]->cpu_data()[0], 0.35);
+
+  EXPECT_GE(caffe_net.params()[0]->cpu_data()[1], 0.3);
+  EXPECT_LE(caffe_net.params()[0]->cpu_data()[1], 0.35);
+
+  EXPECT_GE(caffe_net.params()[1]->cpu_data()[0], -0.01);
+  EXPECT_LE(caffe_net.params()[1]->cpu_data()[0], 0.01);
+}
+
+}  // namespace caffe
index a3a94cd..a074545 100644 (file)
@@ -103,6 +103,19 @@ template <>
 void caffe_axpy<double>(const int N, const double alpha, const double* X,
     double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
 
+
+template <>
+void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
+    float* Y) {
+  CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
+}
+
+template <>
+void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
+    double* Y) {
+  CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
+}
+
 template <>
 void caffe_axpby<float>(const int N, const float alpha, const float* X,
     const float beta, float* Y) {
@@ -126,6 +139,16 @@ void caffe_copy<double>(const int N, const double* X, double* Y) {
 }
 
 template <>
+void caffe_gpu_copy<float>(const int N, const float* X, float* Y) {
+  CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), N, X, 1, Y, 1));
+}
+
+template <>
+void caffe_gpu_copy<double>(const int N, const double* X, double* Y) {
+  CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), N, X, 1, Y, 1));
+}
+
+template <>
 void caffe_scal<float>(const int N, const float alpha, float *X) {
   cblas_sscal(N, alpha, X, 1);
 }
@@ -146,6 +169,20 @@ void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
 }
 
 template <>
+void caffe_gpu_axpby<float>(const int N, const float alpha, const float* X,
+    const float beta, float* Y) {
+  caffe_gpu_scal<float>(N, beta, Y);
+  caffe_gpu_axpy<float>(N, alpha, X, Y);
+}
+
+template <>
+void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
+    const double beta, double* Y) {
+  caffe_gpu_scal<double>(N, beta, Y);
+  caffe_gpu_axpy<double>(N, alpha, X, Y);
+}
+
+template <>
 void caffe_sqr<float>(const int n, const float* a, float* y) {
   vsSqr(n, a, y);
 }
index e3ace98..a71f28e 100644 (file)
@@ -40,13 +40,24 @@ void caffe_axpy(const int N, const Dtype alpha, const Dtype* X,
     Dtype* Y);
 
 template <typename Dtype>
+void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
+    Dtype* Y);
+
+template <typename Dtype>
 void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
     const Dtype beta, Dtype* Y);
 
 template <typename Dtype>
+void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
+    const Dtype beta, Dtype* Y);
+
+template <typename Dtype>
 void caffe_copy(const int N, const Dtype *X, Dtype *Y);
 
 template <typename Dtype>
+void caffe_gpu_copy(const int N, const Dtype *X, Dtype *Y);
+
+template <typename Dtype>
 void caffe_scal(const int N, const Dtype alpha, Dtype *X);
 
 template <typename Dtype>