From a464df45a782e2cd45412744de9c0abbd671df6a Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 26 Aug 2014 12:01:26 -0700 Subject: [PATCH] Re-added solver switch into the new caffe main excutable; fixed AdaGrad MNIST example --- examples/mnist/mnist_autoencoder_solver_adagrad.prototxt | 1 - examples/mnist/train_mnist_autoencoder_adagrad.sh | 2 +- tools/caffe.cpp | 10 +++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt index 6193351..fa7d65c 100644 --- a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt @@ -9,7 +9,6 @@ max_iter: 4000000 weight_decay: 0.0005 snapshot: 10000 snapshot_prefix: "mnist_autoencoder_train" -momentum: 0.9 # solver mode: CPU or GPU solver_mode: GPU solver_type: ADAGRAD diff --git a/examples/mnist/train_mnist_autoencoder_adagrad.sh b/examples/mnist/train_mnist_autoencoder_adagrad.sh index 628c74b..25a48c3 100755 --- a/examples/mnist/train_mnist_autoencoder_adagrad.sh +++ b/examples/mnist/train_mnist_autoencoder_adagrad.sh @@ -1,4 +1,4 @@ #!/bin/bash TOOLS=../../build/tools -$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver.prototxt +$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver_adagrad.prototxt diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 5b3ad0b..9958ac3 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -88,7 +88,7 @@ int train() { caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); LOG(INFO) << "Starting Optimization"; - caffe::SGDSolver solver(solver_param); + shared_ptr> solver(caffe::GetSolver(solver_param)); // Set device id and mode if (FLAGS_gpu >= 0) { @@ -102,13 +102,13 @@ int train() { if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; - solver.Solve(FLAGS_snapshot); + solver->Solve(FLAGS_snapshot); } else if (FLAGS_weights.size()) { LOG(INFO) << "Finetuning from " << FLAGS_weights; - solver.net()->CopyTrainedLayersFrom(FLAGS_weights); - solver.Solve(); + solver->net()->CopyTrainedLayersFrom(FLAGS_weights); + solver->Solve(); } else { - solver.Solve(); + solver->Solve(); } LOG(INFO) << "Optimization Done."; return 0; -- 2.7.4