caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);
LOG(INFO) << "Starting Optimization";
- caffe::SGDSolver<float> solver(solver_param);
+ shared_ptr<caffe::Solver<float>> solver(caffe::GetSolver<float>(solver_param));
// Set device id and mode
if (FLAGS_gpu >= 0) {
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
- solver.Solve(FLAGS_snapshot);
+ solver->Solve(FLAGS_snapshot);
} else if (FLAGS_weights.size()) {
LOG(INFO) << "Finetuning from " << FLAGS_weights;
- solver.net()->CopyTrainedLayersFrom(FLAGS_weights);
- solver.Solve();
+ solver->net()->CopyTrainedLayersFrom(FLAGS_weights);
+ solver->Solve();
} else {
- solver.Solve();
+ solver->Solve();
}
LOG(INFO) << "Optimization Done.";
return 0;