From bbd166e07313a6b30c68f3a7557e81316adb3d14 Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Mon, 15 Sep 2014 14:15:58 -0700 Subject: [PATCH] fix caffe train GPU initialization Previously, the solver constructed nets before the caffe train tool read the --gpu flag, which can cause errors due to LayerSetUp executing on the wrong device (breaking cuDNN, for example). --- src/caffe/solver.cpp | 5 ----- tools/caffe.cpp | 15 +++++++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 6b55706..fc2576b 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -32,11 +32,6 @@ void Solver::Init(const SolverParameter& param) { LOG(INFO) << "Initializing solver from parameters: " << std::endl << param.DebugString(); param_ = param; - if (param_.solver_mode() == SolverParameter_SolverMode_GPU && - param_.has_device_id()) { - Caffe::SetDevice(param_.device_id()); - } - Caffe::set_mode(Caffe::Brew(param_.solver_mode())); if (param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 10ba70d..fa27fdf 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -87,20 +87,27 @@ int train() { caffe::SolverParameter solver_param; caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); - LOG(INFO) << "Starting Optimization"; - shared_ptr > - solver(caffe::GetSolver(solver_param)); + // If the gpu flag is not provided, allow the mode and device to be set + // in the solver prototxt. + if (FLAGS_gpu < 0 + && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { + FLAGS_gpu = solver_param.device_id(); + } // Set device id and mode if (FLAGS_gpu >= 0) { LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; Caffe::SetDevice(FLAGS_gpu); Caffe::set_mode(Caffe::GPU); - } else if (!solver_param.has_solver_mode()) { + } else { LOG(INFO) << "Use CPU."; Caffe::set_mode(Caffe::CPU); } + LOG(INFO) << "Starting Optimization"; + shared_ptr > + solver(caffe::GetSolver(solver_param)); + if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Solve(FLAGS_snapshot); -- 2.7.4