misc update
authorYangqing Jia <jiayq84@gmail.com>
Mon, 7 Oct 2013 19:35:13 +0000 (12:35 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 7 Oct 2013 19:35:13 +0000 (12:35 -0700)
src/caffe/layer.hpp
src/caffe/net.cpp
src/caffe/optimization/solver.cpp
src/caffe/optimization/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/test/test_gradient_check_util.hpp
src/programs/convert_dataset.cpp
src/programs/demo_mnist.cpp
src/programs/train_alexnet.cpp

index 9898cbc..2b73daf 100644 (file)
@@ -34,11 +34,13 @@ class Layer {
       const bool propagate_down,
       vector<Blob<Dtype>*>* bottom);
 
-  // Returns the vector of parameters.
-  vector<shared_ptr<Blob<Dtype> > >& params() {
+  // Returns the vector of blobs.
+  vector<shared_ptr<Blob<Dtype> > >& blobs() {
     return blobs_;
   }
 
+  // Returns the layer parameter
+  const LayerParameter& layer_param() { return layer_param_; }
   // Writes the layer parameter to a protocol buffer
   virtual void ToProto(LayerParameter* param, bool write_diff = false);
 
index c0ccbb1..6b5e4af 100644 (file)
@@ -89,9 +89,14 @@ Net<Dtype>::Net(const NetParameter& param,
   for (int i = 0; i < layers_.size(); ++i) {
     LOG(INFO) << "Setting up " << layer_names_[i];
     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
-    vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i]->params();
-    for (int j = 0; j < layer_params.size(); ++j) {
-      params_.push_back(layer_params[j]);
+    vector<shared_ptr<Blob<Dtype> > >& layer_blobs = layers_[i]->blobs();
+    for (int j = 0; j < layer_blobs.size(); ++j) {
+      params_.push_back(layer_blobs[j]);
+    }
+    for (int topid = 0; topid < top_vecs_[i].size(); ++topid) {
+      LOG(INFO) << "Top shape: " << top_vecs_[i][topid]->channels() << " "
+          << top_vecs_[i][topid]->height() << " "
+          << top_vecs_[i][topid]->width();
     }
   }
 
@@ -106,7 +111,7 @@ const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
     blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
   }
   for (int i = 0; i < layers_.size(); ++i) {
-    //LOG(ERROR) << "Forwarding " << layer_names_[i];
+    // LOG(ERROR) << "Forwarding " << layer_names_[i];
     layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
   }
   return net_output_blobs_;
@@ -141,7 +146,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
     }
     LOG(INFO) << "Loading source layer " << source_layer_name;
     vector<shared_ptr<Blob<Dtype> > >& target_blobs =
-        layers_[target_layer_id]->params();
+        layers_[target_layer_id]->blobs();
     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
         << "Incompatible number of blobs for layer " << source_layer_name;
     for (int j = 0; j < target_blobs.size(); ++j) {
index a48408c..cb288b3 100644 (file)
@@ -19,7 +19,8 @@ namespace caffe {
 template <typename Dtype>
 void Solver<Dtype>::Solve(Net<Dtype>* net) {
   net_ = net;
-  LOG(INFO) << "Solving net " << net_->name();
+  LOG(INFO) << "Solving " << net_->name();
+  PreSolve();
   iter_ = 0;
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
@@ -79,10 +80,11 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
 }
 
 template <typename Dtype>
-void SGDSolver<Dtype>::ComputeUpdateValue() {
+void SGDSolver<Dtype>::PreSolve() {
   // First of all, see if we need to initialize the history
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  if (history_.size() == 0 && this->param_.momentum() > 0) {
+  history_.clear();
+  if (this->param_.momentum() > 0) {
     for (int i = 0; i < net_params.size(); ++i) {
       const Blob<Dtype>* net_param = net_params[i].get();
       history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
@@ -90,45 +92,54 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
           net_param->width())));
     }
   }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue() {
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   // get the learning rate
   Dtype rate = GetLearningRate();
-  if (this->param_.momentum() == 0) {
-    for (int i = 0; i < net_params.size(); ++i) {
-      switch (Caffe::mode()) {
-      case Caffe::CPU:
-        caffe_scal(net_params[i]->count(), rate,
-            net_params[i]->mutable_cpu_diff());
-        break;
-      case Caffe::GPU:
-        caffe_gpu_scal(net_params[i]->count(), rate,
-            net_params[i]->mutable_gpu_diff());
-        break;
-      default:
-        LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  Dtype momentum = this->param_.momentum();
+  Dtype weight_decay = this->param_.weight_decay();
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      // Compute the value to history, and then copy them to the blob's diff.
+      caffe_axpby(net_params[param_id]->count(), rate,
+          net_params[param_id]->cpu_diff(), momentum,
+          history_[param_id]->mutable_cpu_data());
+      if (weight_decay) {
+        // add weight decay
+        caffe_axpy(net_params[param_id]->count(), weight_decay * rate,
+            net_params[param_id]->cpu_data(),
+            history_[param_id]->mutable_cpu_data());
       }
+      // copy
+      caffe_copy(net_params[param_id]->count(),
+          history_[param_id]->cpu_data(),
+          net_params[param_id]->mutable_cpu_diff());
     }
-  } else {
-    // Need to maintain momentum
-    for (int i = 0; i < net_params.size(); ++i) {
-      switch (Caffe::mode()) {
-      case Caffe::CPU:
-        caffe_axpby(net_params[i]->count(), rate,
-            net_params[i]->cpu_diff(), Dtype(this->param_.momentum()),
-            history_[i]->mutable_cpu_data());
-        caffe_copy(net_params[i]->count(), history_[i]->cpu_data(),
-            net_params[i]->mutable_cpu_diff());
-        break;
-      case Caffe::GPU:
-        caffe_gpu_axpby(net_params[i]->count(), rate,
-            net_params[i]->gpu_diff(), Dtype(this->param_.momentum()),
-            history_[i]->mutable_gpu_data());
-        caffe_gpu_copy(net_params[i]->count(), history_[i]->gpu_data(),
-            net_params[i]->mutable_gpu_diff());
-        break;
-      default:
-        LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+    break;
+  case Caffe::GPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      // Compute the value to history, and then copy them to the blob's diff.
+      caffe_gpu_axpby(net_params[param_id]->count(), rate,
+          net_params[param_id]->gpu_diff(), momentum,
+          history_[param_id]->mutable_gpu_data());
+      if (weight_decay) {
+        // add weight decay
+        caffe_gpu_axpy(net_params[param_id]->count(), weight_decay * rate,
+            net_params[param_id]->gpu_data(),
+            history_[param_id]->mutable_gpu_data());
       }
+      // copy
+      caffe_gpu_copy(net_params[param_id]->count(),
+          history_[param_id]->gpu_data(),
+          net_params[param_id]->mutable_gpu_diff());
     }
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
   }
 }
 
index 0a78d88..f20a06e 100644 (file)
@@ -12,6 +12,9 @@ class Solver {
   void Solve(Net<Dtype>* net);
 
  protected:
+  // PreSolve is run before any solving iteration starts, allowing one to
+  // put up some scaffold.
+  virtual void PreSolve() {};
   // Get the update value for the current iteration.
   virtual void ComputeUpdateValue() = 0;
   void Snapshot(bool is_final = false);
@@ -29,6 +32,7 @@ class SGDSolver : public Solver<Dtype> {
       : Solver<Dtype>(param) {}
 
  protected:
+  virtual void PreSolve();
   Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
   // history maintains the historical momentum data.
index eef6058..0231ad9 100644 (file)
@@ -95,6 +95,7 @@ message SolverParameter {
   optional float gamma = 8; // The parameter to compute the learning rate.
   optional float power = 9; // The parameter to compute the learning rate.
   optional float momentum = 10; // The momentum value.
+  optional float weight_decay = 11; // The weight decay.
 
-  optional string snapshot_prefix = 11; // The prefix for the snapshot.
+  optional string snapshot_prefix = 12; // The prefix for the snapshot.
 }
index c540549..55a5b95 100644 (file)
@@ -65,8 +65,8 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
     int check_bottom, int top_id, int top_data_id) {
   // First, figure out what blobs we need to check against.
   vector<Blob<Dtype>*> blobs_to_check;
-  for (int i = 0; i < layer.params().size(); ++i) {
-    blobs_to_check.push_back(layer.params()[i].get());
+  for (int i = 0; i < layer.blobs().size(); ++i) {
+    blobs_to_check.push_back(layer.blobs()[i].get());
   }
   if (check_bottom < 0) {
     for (int i = 0; i < bottom.size(); ++i) {
index 53a1e29..3bf7794 100644 (file)
@@ -57,8 +57,13 @@ int main(int argc, char** argv) {
   leveldb::WriteBatch* batch = new leveldb::WriteBatch();
   while (infile >> filename >> label) {
     ReadImageToDatum(root_folder + filename, label, &datum);
+    // sequential
     sprintf(key_cstr, "%08d_%s", count, filename.c_str());
     string key(key_cstr);
+    // random
+    // string key;
+    // GenerateRandomPrefix(8, &key);
+    // key += filename;
     string value;
     // get the value
     datum.SerializeToString(&value);
index 7c0937b..e5712a8 100644 (file)
@@ -40,6 +40,7 @@ int main(int argc, char** argv) {
   solver_param.set_gamma(0.0001);
   solver_param.set_power(0.75);
   solver_param.set_momentum(0.9);
+  solver_param.set_weight_decay(0.0005);
 
   LOG(ERROR) << "Starting Optimization";
   SGDSolver<float> solver(solver_param);
index c86a946..d6a4ca5 100644 (file)
@@ -32,21 +32,15 @@ int main(int argc, char** argv) {
   LOG(ERROR) << "Performing Backward";
   LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
 
-  // Run the network without training.
-  LOG(ERROR) << "Multiple Passes";
-  for (int i = 0; i < 100; ++i) {
-    caffe_net.ForwardBackward(bottom_vec);
-  }
-  LOG(ERROR) << "Multiple passes done.";
-/*
   SolverParameter solver_param;
-  solver_param.set_base_lr(0.01);
-  solver_param.set_display(0);
-  solver_param.set_max_iter(6000);
-  solver_param.set_lr_policy("inv");
-  solver_param.set_gamma(0.0001);
-  solver_param.set_power(0.75);
+  solver_param.set_base_lr(0.001);
+  solver_param.set_display(1);
+  solver_param.set_max_iter(600000);
+  solver_param.set_lr_policy("fixed");
+  //solver_param.set_gamma(0.0001);
+  //solver_param.set_power(0.75);
   solver_param.set_momentum(0.9);
+  solver_param.set_weight_decay(0.0005);
 
   LOG(ERROR) << "Starting Optimization";
   SGDSolver<float> solver(solver_param);
@@ -60,41 +54,5 @@ int main(int argc, char** argv) {
   float loss = caffe_net.Backward();
   LOG(ERROR) << "Final loss: " << loss;
 
-  NetParameter trained_net_param;
-  caffe_net.ToProto(&trained_net_param);
-
-  NetParameter traintest_net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
-      &traintest_net_param);
-  Net<float> caffe_traintest_net(traintest_net_param, bottom_vec);
-  caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
-
-  // Test run
-  double train_accuracy = 0;
-  int batch_size = traintest_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 60000 / batch_size; ++i) {
-    const vector<Blob<float>*>& result =
-        caffe_traintest_net.Forward(bottom_vec);
-    train_accuracy += result[0]->cpu_data()[0];
-  }
-  train_accuracy /= 60000 / batch_size;
-  LOG(ERROR) << "Train accuracy:" << train_accuracy;
-
-  NetParameter test_net_param;
-  ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
-  Net<float> caffe_test_net(test_net_param, bottom_vec);
-  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
-
-  // Test run
-  double test_accuracy = 0;
-  batch_size = test_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 10000 / batch_size; ++i) {
-    const vector<Blob<float>*>& result =
-        caffe_test_net.Forward(bottom_vec);
-    test_accuracy += result[0]->cpu_data()[0];
-  }
-  test_accuracy /= 10000 / batch_size;
-  LOG(ERROR) << "Test accuracy:" << test_accuracy;
-*/
   return 0;
 }