Set device_id at the begining of Solver.Init() to avoid using memory in the default GPU
authorSergio <sguada@gmail.com>
Thu, 19 Jun 2014 19:55:48 +0000 (12:55 -0700)
committerSergio <sguada@gmail.com>
Thu, 19 Jun 2014 19:55:48 +0000 (12:55 -0700)
src/caffe/solver.cpp

index 6a8f18f..7696181 100644 (file)
@@ -36,6 +36,11 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
   LOG(INFO) << "Initializing solver from parameters: " << std::endl
             << param.DebugString();
   param_ = param;
+  if (param_.solver_mode() == SolverParameter_SolverMode_GPU &&
+      param_.has_device_id()) {
+    Caffe::SetDevice(param_.device_id());
+  }
+  Caffe::set_mode(Caffe::Brew(param_.solver_mode()));
   if (param_.random_seed() >= 0) {
     Caffe::set_random_seed(param_.random_seed());
   }
@@ -74,14 +79,8 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
   LOG(INFO) << "Solver scaffolding done.";
 }
 
-
 template <typename Dtype>
 void Solver<Dtype>::Solve(const char* resume_file) {
-  Caffe::set_mode(Caffe::Brew(param_.solver_mode()));
-  if (param_.solver_mode() == SolverParameter_SolverMode_GPU &&
-      param_.has_device_id()) {
-    Caffe::SetDevice(param_.device_id());
-  }
   Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();