Updated AdaDelta for modern Caffe; reduced iterations on multi-iter tests
author Kevin Bache <kevin.bache@gmail.com>
Thu, 19 Mar 2015 22:56:51 +0000 (15:56 -0700)
committer Matthias Plappert <matthiasplappert@me.com>
Mon, 10 Aug 2015 09:16:53 +0000 (11:16 +0200)
examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
include/caffe/solver.hpp
src/caffe/solver.cpp
src/caffe/test/test_gradient_based_solver.cpp

index cc4f0bb..4e43468 100644 (file)
@@ -6,6 +6,7 @@ test_iter: 100
 test_interval: 500
 test_compute_loss: true
 momentum: 0.95
+delta: 1e-8
 display: 100
 max_iter: 65000
 weight_decay: 0.0005
@@ -14,4 +15,3 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
 # solver mode: CPU or GPU
 solver_mode: GPU
 solver_type: ADADELTA
-delta: 1e-8
index 4b40838..495cd4f 100644 (file)
@@ -82,12 +82,12 @@ class SGDSolver : public Solver<Dtype> {
   const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
 
  protected:
-  void PreSolve();
   Dtype GetLearningRate();
   virtual void ApplyUpdate();
   virtual void Normalize(int param_id);
   virtual void Regularize(int param_id);
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
+  virtual void PreSolve();
   virtual void ClipGradients();
   virtual void SnapshotSolverState(const string& model_filename);
   virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
@@ -162,9 +162,9 @@ template <typename Dtype>
 class AdaDeltaSolver : public SGDSolver<Dtype> {
  public:
   explicit AdaDeltaSolver(const SolverParameter& param)
-      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
   explicit AdaDeltaSolver(const string& param_file)
-      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+      : SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
 
  protected:
   virtual void PreSolve();
index d8749a1..34a290f 100644 (file)
@@ -936,35 +936,21 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::PreSolve() {
-  // Initialize the history
-  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  this->history_.clear();
-  this->update_.clear();
-  this->temp_.clear();
-  for (int i = 0; i < net_params.size(); ++i) {
-    const Blob<Dtype>* net_param = net_params[i].get();
-    this->history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
-    this->update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
-    this->temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
-  }
+  // Add the extra history entries for AdaDelta after those from
+  // SGDSolver::PreSolve
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   for (int i = 0; i < net_params.size(); ++i) {
-    const Blob<Dtype>* net_param = net_params[i].get();
-    this->history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
-        net_param->num(), net_param->channels(), net_param->height(),
-        net_param->width())));
+        const vector<int>& shape = net_params[i]->shape();
+        this->history_.push_back(
+                shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
   }
 }
 
 template <typename Dtype>
 void AdaDeltaSolver<Dtype>::ComputeUpdateValue() {
-  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
-  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  const vector<float>& net_params_weight_decay =
+          this->net_->params_weight_decay();
   Dtype delta = this->param_.delta();
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
index db89e28..277aa3a 100644 (file)
@@ -1060,7 +1060,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
   const Dtype kLearningRate = 0.0;
   const Dtype kWeightDecay = 0.0;
   const Dtype kMomentum = 0.95;
-  const int kNumIters = 500;
+  const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
       this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }
@@ -1071,7 +1071,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
   const Dtype kLearningRate = 0.0;
   const Dtype kWeightDecay = 0.1;
   const Dtype kMomentum = 0.95;
-  const int kNumIters = 500;
+  const int kNumIters = 4;
   for (int i = 0; i <= kNumIters; ++i) {
       this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
   }