updated a bunch of things, ready to test if it breaks things
authorYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 23:59:43 +0000 (16:59 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 23:59:43 +0000 (16:59 -0700)
src/caffe/blob.cpp
src/caffe/blob.hpp
src/caffe/net.cpp
src/caffe/net.hpp
src/caffe/optimization/solver.cpp [new file with mode: 0644]
src/caffe/optimization/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp

index 6162740..35e5b04 100644 (file)
@@ -1,5 +1,6 @@
 // Copyright 2013 Yangqing Jia
 
+#include <cuda_runtime.h>
 #include <cublas_v2.h>
 
 #include "caffe/blob.hpp"
@@ -11,6 +12,7 @@ namespace caffe {
 template <typename Dtype>
 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
     const int width) {
+  int old_count = count_;
   CHECK_GE(num, 0);
   CHECK_GE(channels, 0);
   CHECK_GE(height, 0);
@@ -21,8 +23,10 @@ void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
   width_ = width;
   count_ = num_ * channels_ * height_ * width_;
   if (count_) {
-    data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
-    diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    if (old_count != count_) {
+      data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+      diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    }
   } else {
     data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
     diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
@@ -91,6 +95,40 @@ void Blob<Dtype>::Update() {
 }
 
 template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
+  if (num_ != source.num() || channels_ != source.channels() ||
+      height_ != source.height() || width_ != source.width()) {
+    if (reshape) {
+      Reshape(source.num(), source.channels(), source.height(), source.width());
+    } else {
+      LOG(FATAL) << "Trying to copy blobs of different sizes.";
+    }
+  }
+  switch (Caffe::mode()) {
+  case Caffe::GPU:
+    if (copy_diff) {
+      CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(), source.gpu_diff(),
+          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+    } else {
+      CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(), source.gpu_data(),
+          sizeof(Dtype) * count_, cudaMemcpyDeviceToDevice));
+    }
+    break;
+  case Caffe::CPU:
+    if (copy_diff) {
+      memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+          sizeof(Dtype) * count_);
+    } else {
+      memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+        sizeof(Dtype) * count_);
+    }
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode.";
+  }
+}
+
+template <typename Dtype>
 void Blob<Dtype>::FromProto(const BlobProto& proto) {
   Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
   // copy data
index f0e19c2..f31d3b0 100644 (file)
@@ -29,6 +29,10 @@ class Blob {
       const int w = 0) const {
     return ((n * channels_ + c) * height_ + h) * width_ + w;
   }
+  // Copy from source. If copy_diff is false, we copy the data; if copy_diff
+  // is true, we copy the diff.
+  void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
+      bool reshape = false);
 
   inline Dtype data_at(const int n, const int c, const int h,
       const int w) const {
index c6dfce1..22d2743 100644 (file)
@@ -41,6 +41,8 @@ Net<Dtype>::Net(const NetParameter& param,
   // For each layer, set up their input and output
   bottom_vecs_.resize(param.layers_size());
   top_vecs_.resize(param.layers_size());
+  bottom_id_vecs_.resize(param.layers_size());
+  top_id_vecs_.resize(param.layers_size());
   for (int i = 0; i < param.layers_size(); ++i) {
     const LayerConnection& layer_connection = param.layers(i);
     const LayerParameter& layer_param = layer_connection.layer();
@@ -57,6 +59,7 @@ Net<Dtype>::Net(const NetParameter& param,
       LOG(INFO) << layer_param.name() << " <- " << blob_name;
       bottom_vecs_[i].push_back(
           blobs_[blob_name_to_idx[blob_name]].get());
+      bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
       available_blobs.erase(blob_name);
     }
     for (int j = 0; j < layer_connection.top_size(); ++j) {
@@ -71,6 +74,7 @@ Net<Dtype>::Net(const NetParameter& param,
       blob_name_to_idx[blob_name] = blob_names_.size() - 1;
       available_blobs.insert(blob_name);
       top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
+      top_id_vecs_[i].push_back(blob_names_.size() - 1);
     }
   }
   LOG(INFO) << "Checking top blobs.";
@@ -95,7 +99,7 @@ Net<Dtype>::Net(const NetParameter& param,
   for (int i = 0; i < layers_.size(); ++i) {
     LOG(INFO) << "Setting up " << layer_names_[i];
     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
-    vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i].params();
+    vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i]->params();
     for (int j = 0; j < layer_params.size(); ++j) {
       params_.push_back(layer_params[j]);
     }
@@ -109,15 +113,14 @@ void Net<Dtype>::Forward(const vector<Blob<Dtype>*> & bottom,
     vector<Blob<Dtype>*>* top) {
   // Copy bottom to internal bottom
   for (int i = 0; i < bottom.size(); ++i) {
-    memcpy(blobs_[net_input_blob_indices_[i]]->mutable_cpu_data(),
-        bottom[i]->cpu_data(), sizeof(Dtype) * bottom[i]->count());
+    blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
   }
   for (int i = 0; i < layers_.size(); ++i) {
     layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
   }
   // Copy internal top to top
   for (int i = 0; i < (*top).size(); ++i) {
-    NOT_IMPLEMENTED;
+    (*top)[i]->CopyFrom(*blobs_[net_output_blob_indices_[i]]);
   }
 }
 
@@ -167,11 +170,26 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
     param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
   }
-  for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
-    param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+  for (int i = 0; i < net_output_blob_indices_.size(); ++i) {
+    param->add_top(blob_names_[net_output_blob_indices_[i]]);
   }
   for (int i = 0; i < layers_.size(); ++i) {
     LayerConnection* layer_connection = param->add_layers();
+    for (int j = 0; j < bottom_id_vecs_[i].size(); ++i) {
+      layer_connection->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
+    }
+    for (int j = 0; j < top_id_vecs_[i].size(); ++i) {
+      layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
+    }
+    LayerParameter* layer_parameter = layer_connection->mutable_layer();
+    layers_[i]->ToProto(layer_parameter);
+  }
+}
+
+template <typename Dtype>
+void Net<Dtype>::Update() {
+  for (int i = 0; i < params_.size(); ++i) {
+    params_[i]->Update();
   }
 }
 
index 719267c..1f1a803 100644 (file)
@@ -31,6 +31,12 @@ class Net {
   // been provided during the forward pass.
   Dtype Backward();
 
+  Dtype ForwardBackWard(const vector<Blob<Dtype>* > & bottom,
+      vector<Blob<Dtype>*>* top) {
+    Forward(bottom, top);
+    return Backward();
+  }
+
   // For an already initialized net, CopyTrainedLayersFrom() copies the already
   // trained layers from another net parameter instance.
   void CopyTrainedLayersFrom(const NetParameter& param);
@@ -49,6 +55,8 @@ class Net {
   inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
   // returns the parameters
   vector<shared_ptr<Blob<Dtype> > >& params() { return params_; };
+  // Updates the network
+  void Update();
 
  protected:
   // Individual layers in the net
@@ -61,9 +69,11 @@ class Net {
   // bottom_vecs stores the vectors containing the input for each layer, except
   // for the first layer whose bottom vec is provided by the network's input.
   vector<vector<Blob<Dtype>*> > bottom_vecs_;
+  vector<vector<int> > bottom_id_vecs_;
   // top_vecs stores the vectors containing the output for each layer, except
   // for the last layer (likewise)
   vector<vector<Blob<Dtype>*> > top_vecs_;
+  vector<vector<int> > top_id_vecs_;
   // blob indices for the input and the output of the net.
   vector<int> net_input_blob_indices_;
   vector<int> net_output_blob_indices_;
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
new file mode 100644 (file)
index 0000000..b9055d2
--- /dev/null
@@ -0,0 +1,113 @@
+// Copyright Yangqing Jia 2013
+
+#include <fstream>
+#include <string>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/net.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using std::stringstream;
+using std::ofstream;
+
+namespace caffe {
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(Net<Dtype>* net) {
+  net_ = net;
+  LOG(INFO) << "Solving net " << net_->name();
+  iter_ = 0;
+  // For a network that is trained by the solver, no bottom or top vecs
+  // should be given, and we will just provide dummy vecs.
+  vector<Blob<Dtype>*> bottom_vec;
+  vector<Blob<Dtype>*> top_vec;
+  while (iter_++ < param_.max_iter()) {
+    Dtype loss = net_->ForwardBackWard(bottom_vec, &top_vec);
+    ComputeUpdateValue();
+    net->Update();
+
+    // Check if we need to do snapshot
+    if (iter_ % param_.snapshot()) {
+      // TODO(Yangqing): snapshot
+    }
+    LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
+  }
+  LOG(INFO) << "Optimization Done.";
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Snapshot(bool is_final) {
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  stringstream ss;
+  ss << param_.snapshot_prefix();
+  if (is_final) {
+    ss << "_final";
+  } else {
+    ss << "_iter_" << iter_;
+  }
+  ofstream output_file;
+  output_file.open(ss.str().c_str());
+  CHECK(net_param.SerializeToOstream(&output_file));
+  output_file.close();
+}
+
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLearningRate() {
+  Dtype rate;
+  const string& lr_policy = this->param_.lr_policy();
+  if (lr_policy == "fixed") {
+    rate = this->param_.base_lr();
+  } else if (lr_policy == "exp") {
+    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
+  } else if (lr_policy == "inv") {
+    rate = this->param_.base_lr() *
+        pow(Dtype(1) + this->param_.gamma() * this->iter_,
+            this->param_.power());
+  } else {
+    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+  }
+  rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
+  return rate;
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue() {
+  // First of all, see if we need to initialize the history
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
+  if (this->iter_ == 1 && this->param_.momentum() > 0) {
+    LOG(INFO) << "Using momentum " << this->param_.momentum();
+    for (int i = 0; i < net_params.size(); ++i) {
+      const Blob<Dtype>* net_param = net_params[i].get();
+      history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+          net_param->num(), net_param->channels(), net_param->height(),
+          net_param->width())));
+    }
+  }
+  // get the learning rate
+  Dtype rate = GetLearningRate();
+  if (this->param_.momentum == 0) {
+    for (int i = 0; i < net_params.size(); ++i) {
+      switch (Caffe::mode()) {
+      case Caffe::CPU:
+        caffe_scal(net_params[i]->count(), rate,
+            net_params[i]->mutable_cpu_data());
+        break;
+      case Caffe::GPU:
+        caffe_gpu_scal(net_params[i]->count(), rate,
+            net_params[i]->mutable_gpu_data());
+        break;
+      default:
+        LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+      }
+    }
+  } else {
+    NOT_IMPLEMENTED;
+  }
+}
+
+
+
+INSTANTIATE_CLASS(Solver);
+
+}  // namespace caffe
\ No newline at end of file
index 0c680e3..0a78d88 100644 (file)
@@ -3,16 +3,39 @@
 
 namespace caffe {
 
+template <typename Dtype>
 class Solver {
  public:
   explicit Solver(const SolverParameter& param)
       : param_(param) {}
-  void Solve(Net* net);
+  // The main entry of the solver function.
+  void Solve(Net<Dtype>* net);
 
  protected:
+  // Get the update value for the current iteration.
+  virtual void ComputeUpdateValue() = 0;
+  void Snapshot(bool is_final = false);
   SolverParameter param_;
+  int iter_;
+  Net<Dtype>* net_;
+
+  DISABLE_COPY_AND_ASSIGN(Solver);
 };
 
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+  explicit SGDSolver(const SolverParameter& param)
+      : Solver<Dtype>(param) {}
+
+ protected:
+  Dtype GetLearningRate();
+  virtual void ComputeUpdateValue();
+  // history maintains the historical momentum data.
+  vector<shared_ptr<Blob<Dtype> > > history_;
+};
+
+
 }  // namspace caffe
 
 #endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
\ No newline at end of file
index 732c2ee..9d691d2 100644 (file)
@@ -89,4 +89,6 @@ message SolverParameter {
   optional float gamma = 8; // The parameter to compute the learning rate.
   optional float power = 9; // The parameter to compute the learning rate.
   optional float momentum = 10; // The momentum value.
+
+  optional string snapshot_prefix = 11; // The prefix for the snapshot.
 }
\ No newline at end of file
index 1949a70..7cd3b26 100644 (file)
@@ -124,6 +124,16 @@ void caffe_scal<double>(const int N, const double alpha, double *X) {
 }
 
 template <>
+void caffe_gpu_scal<float>(const int N, const float alpha, float *X) {
+  CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
+void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
+  CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+}
+
+template <>
 void caffe_sqr<float>(const int n, const float* a, float* y) {
   vsSqr(n, a, y);
 }
index 822ef31..f09afe3 100644 (file)
@@ -46,6 +46,9 @@ template <typename Dtype>
 void caffe_scal(const int N, const Dtype alpha, Dtype *X);
 
 template <typename Dtype>
+void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X);
+
+template <typename Dtype>
 void caffe_sqr(const int N, const Dtype* a, Dtype* y);
 
 template <typename Dtype>