Multi-GPU
authorCyprien Noel <cyprien.noel@gmail.com>
Tue, 19 May 2015 18:11:05 +0000 (11:11 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 9 Aug 2015 22:16:00 +0000 (15:16 -0700)
- Parallelize batches among GPUs and tree-reduce the gradients
- The effective batch size scales with the number of devices
- Batch size is multiplied by the number of devices
- Split batches between GPUs, and tree-reduce the gradients
- Detect machine topology (twin-GPU boards, P2P connectivity)
- Track device in syncedmem (thanks @thatguymike)
- Insert a callback in the solver for minimal code change
- Accept list for gpu flag of caffe tool, e.g. '-gpu 0,1' or '-gpu all'.
  Run on default GPU if no ID given.
- Add multi-GPU solver test
- Deterministic architecture for reproducible runs

18 files changed:
include/caffe/caffe.hpp
include/caffe/common.hpp
include/caffe/internal_thread.hpp
include/caffe/layer_factory.hpp
include/caffe/parallel.hpp [new file with mode: 0644]
include/caffe/solver.hpp
include/caffe/syncedmem.hpp
src/caffe/common.cpp
src/caffe/data_reader.cpp
src/caffe/data_transformer.cpp
src/caffe/internal_thread.cpp
src/caffe/net.cpp
src/caffe/parallel.cpp [new file with mode: 0644]
src/caffe/solver.cpp
src/caffe/syncedmem.cpp
src/caffe/test/test_gradient_based_solver.cpp
src/caffe/util/blocking_queue.cpp
tools/caffe.cpp

index 3c829f2..68a5e1d 100644 (file)
@@ -10,6 +10,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/layer_factory.hpp"
 #include "caffe/net.hpp"
+#include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
 #include "caffe/util/benchmark.hpp"
index 3fa8143..1df6b9a 100644 (file)
@@ -149,6 +149,11 @@ class Caffe {
   static void SetDevice(const int device_id);
   // Prints the current GPU status.
   static void DeviceQuery();
+  // Parallel training info
+  inline static int solver_count() { return Get().solver_count_; }
+  inline static void set_solver_count(int val) { Get().solver_count_ = val; }
+  inline static bool root_solver() { return Get().root_solver_; }
+  inline static void set_root_solver(bool val) { Get().root_solver_ = val; }
 
  protected:
 #ifndef CPU_ONLY
@@ -158,6 +163,8 @@ class Caffe {
   shared_ptr<RNG> random_generator_;
 
   Brew mode_;
+  int solver_count_;
+  bool root_solver_;
 
  private:
   // The private constructor to avoid duplicate instantiation.
index be6ff7f..6a8c5a0 100644 (file)
@@ -42,7 +42,8 @@ class InternalThread {
   bool must_stop();
 
  private:
-  void entry(int device, Caffe::Brew mode, int rand_seed);
+  void entry(int device, Caffe::Brew mode, int rand_seed, int solver_count,
+      bool root_solver);
 
   shared_ptr<boost::thread> thread_;
 };
index 2fcd938..32e849d 100644 (file)
@@ -71,7 +71,9 @@ class LayerRegistry {
 
   // Get a layer using a LayerParameter.
   static shared_ptr<Layer<Dtype> > CreateLayer(const LayerParameter& param) {
-    LOG(INFO) << "Creating layer " << param.name();
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "Creating layer " << param.name();
+    }
     const string& type = param.type();
     CreatorRegistry& registry = Registry();
     CHECK_EQ(registry.count(type), 1) << "Unknown layer type: " << type
diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp
new file mode 100644 (file)
index 0000000..85fc2b5
--- /dev/null
@@ -0,0 +1,118 @@
+#ifndef CAFFE_PARALLEL_HPP_
+#define CAFFE_PARALLEL_HPP_
+
+#include <boost/date_time/posix_time/posix_time.hpp>
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/internal_thread.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/util/blocking_queue.hpp"
+
+namespace caffe {
+
+// Represents a net parameters. Once a net is created, its parameter buffers can
+// be replaced by ones from Params, to allow parallelization. Params ensures
+// parameters are allocated in one consecutive array.
+template<typename Dtype>
+class Params {
+ public:
+  explicit Params(shared_ptr<Solver<Dtype> > root_solver);
+  virtual ~Params() {
+  }
+
+  inline size_t size() const {
+    return size_;
+  }
+  inline Dtype* data() const {
+    return data_;
+  }
+  inline Dtype* diff() const {
+    return diff_;
+  }
+
+ protected:
+  const size_t size_;           // Size of buffers
+  Dtype* data_;                 // Network parameters
+  Dtype* diff_;                 // Gradient
+
+DISABLE_COPY_AND_ASSIGN(Params);
+};
+
+// Params stored in GPU memory.
+template<typename Dtype>
+class GPUParams : public Params<Dtype> {
+ public:
+  GPUParams(shared_ptr<Solver<Dtype> > root_solver, int device);
+  virtual ~GPUParams();
+
+  void configure(Solver<Dtype>* solver) const;
+
+ protected:
+  using Params<Dtype>::size_;
+  using Params<Dtype>::data_;
+  using Params<Dtype>::diff_;
+};
+
+class DevicePair {
+ public:
+  DevicePair(int parent, int device)
+      : parent_(parent),
+        device_(device) {
+  }
+  inline int parent() {
+    return parent_;
+  }
+  inline int device() {
+    return device_;
+  }
+
+  // Group GPUs in pairs, by proximity depending on machine's topology
+  static void compute(const vector<int> devices, vector<DevicePair>* pairs);
+
+ protected:
+  int parent_;
+  int device_;
+};
+
+// Synchronous data parallelism using map-reduce between local GPUs.
+template<typename Dtype>
+class P2PSync : public GPUParams<Dtype>, public Solver<Dtype>::Callback,
+    public InternalThread {
+ public:
+  explicit P2PSync(shared_ptr<Solver<Dtype> > root_solver,
+                   P2PSync<Dtype>* parent, const SolverParameter& param);
+  virtual ~P2PSync();
+
+  inline const shared_ptr<Solver<Dtype> >& solver() const {
+    return solver_;
+  }
+
+  void run(const vector<int>& gpus);
+
+ protected:
+  void on_start();
+  void on_gradients_ready();
+
+  void InternalThreadEntry();
+
+  P2PSync<Dtype>* parent_;
+  vector<P2PSync<Dtype>*> children_;
+  BlockingQueue<P2PSync<Dtype>*> queue_;
+  const int initial_iter_;
+  Dtype* parent_grads_;
+  shared_ptr<Solver<Dtype> > solver_;
+
+  using Params<Dtype>::size_;
+  using Params<Dtype>::data_;
+  using Params<Dtype>::diff_;
+};
+
+}  // namespace caffe
+
+#endif
index fbade93..89a6c76 100644 (file)
@@ -32,12 +32,27 @@ class Solver {
   // methods to restore the state from the appropriate snapshot type.
   void Restore(const char* resume_file);
   virtual ~Solver() {}
+  inline const SolverParameter& param() const { return param_; }
   inline shared_ptr<Net<Dtype> > net() { return net_; }
   inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
     return test_nets_;
   }
   int iter() { return iter_; }
 
+  // Invoked at specific points during an iteration
+  class Callback {
+   protected:
+    virtual void on_start() = 0;
+    virtual void on_gradients_ready() = 0;
+
+    template <typename T>
+    friend class Solver;
+  };
+  const vector<Callback*>& callbacks() const { return callbacks_; }
+  void add_callback(Callback* value) {
+    callbacks_.push_back(value);
+  }
+
  protected:
   // Make and apply the update value for the current iteration.
   virtual void ApplyUpdate() = 0;
@@ -62,10 +77,33 @@ class Solver {
   int current_step_;
   shared_ptr<Net<Dtype> > net_;
   vector<shared_ptr<Net<Dtype> > > test_nets_;
+  vector<Callback*> callbacks_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };
 
+/**
+ * @brief Solver that only computes gradients, used as worker
+ *        for multi-GPU training.
+ */
+template <typename Dtype>
+class WorkerSolver : public Solver<Dtype> {
+ public:
+  explicit WorkerSolver(const SolverParameter& param)
+      : Solver<Dtype>(param) {}
+
+ protected:
+  void ApplyUpdate() {}
+  void SnapshotSolverState(const string& model_filename) {
+    LOG(FATAL) << "Should not be called on worker solver.";
+  }
+  void RestoreSolverStateFromBinaryProto(const string& state_file) {
+    LOG(FATAL) << "Should not be called on worker solver.";
+  }
+  void RestoreSolverStateFromHDF5(const string& state_file) {
+    LOG(FATAL) << "Should not be called on worker solver.";
+  }
+};
 
 /**
  * @brief Optimizes the parameters of a Net using
index 4a1a2f3..62aadef 100644 (file)
@@ -45,14 +45,15 @@ class SyncedMemory {
  public:
   SyncedMemory()
       : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
-        own_cpu_data_(false) {}
+        own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {}
   explicit SyncedMemory(size_t size)
       : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
-        own_cpu_data_(false) {}
+        own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {}
   ~SyncedMemory();
   const void* cpu_data();
   void set_cpu_data(void* data);
   const void* gpu_data();
+  void set_gpu_data(void* data);
   void* mutable_cpu_data();
   void* mutable_gpu_data();
   enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
@@ -71,6 +72,8 @@ class SyncedMemory {
   size_t size_;
   SyncedHead head_;
   bool own_cpu_data_;
+  bool own_gpu_data_;
+  int gpu_device_;
 
   DISABLE_COPY_AND_ASSIGN(SyncedMemory);
 };  // class SyncedMemory
index 0215c76..7077f37 100644 (file)
@@ -51,7 +51,8 @@ void GlobalInit(int* pargc, char*** pargv) {
 #ifdef CPU_ONLY  // CPU-only Caffe.
 
 Caffe::Caffe()
-    : random_generator_(), mode_(Caffe::CPU) { }
+    : random_generator_(), mode_(Caffe::CPU),
+      solver_count_(1), root_solver_(true) { }
 
 Caffe::~Caffe() { }
 
@@ -95,7 +96,7 @@ void* Caffe::RNG::generator() {
 
 Caffe::Caffe()
     : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
-    mode_(Caffe::CPU) {
+    mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
   // Try to create a cublas handler, and report an error if failed (but we will
   // keep the program running as one might just want to run CPU code).
   if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
index 60606f0..1637820 100644 (file)
@@ -76,9 +76,7 @@ void DataReader::Body::InternalThreadEntry() {
   shared_ptr<db::Cursor> cursor(db->NewCursor());
   vector<shared_ptr<QueuePair> > qps;
   try {
-    // int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
-    // TODO single solver until multi-gpu merge
-    int solver_count = 1;
+    int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
 
     // To ensure deterministic runs, only start running once all solvers
     // are ready. But solvers need to peek on one item during initialization,
index 2263392..4666d9b 100644 (file)
@@ -19,7 +19,9 @@ DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
     CHECK_EQ(param_.mean_value_size(), 0) <<
       "Cannot specify mean_file and mean_value at the same time";
     const string& mean_file = param.mean_file();
-    LOG(INFO) << "Loading mean file from: " << mean_file;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "Loading mean file from: " << mean_file;
+    }
     BlobProto blob_proto;
     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
     data_mean_.FromProto(blob_proto);
index b193826..104884e 100644 (file)
@@ -27,21 +27,26 @@ void InternalThread::StartInternalThread() {
 #endif
   Caffe::Brew mode = Caffe::mode();
   int rand_seed = caffe_rng_rand();
+  int solver_count = Caffe::solver_count();
+  bool root_solver = Caffe::root_solver();
 
   try {
     thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode,
-          rand_seed));
+          rand_seed, solver_count, root_solver));
   } catch (std::exception& e) {
     LOG(FATAL) << "Thread exception: " << e.what();
   }
 }
 
-void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed) {
+void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed,
+    int solver_count, bool root_solver) {
 #ifndef CPU_ONLY
   CUDA_CHECK(cudaSetDevice(device));
 #endif
   Caffe::set_mode(mode);
   Caffe::set_random_seed(rand_seed);
+  Caffe::set_solver_count(solver_count);
+  Caffe::set_root_solver(root_solver);
 
   InternalThreadEntry();
 }
index 0e5ed80..5d0f432 100644 (file)
@@ -10,6 +10,7 @@
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/net.hpp"
+#include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/hdf5.hpp"
 #include "caffe/util/insert_splits.hpp"
@@ -41,8 +42,10 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   // the current NetState.
   NetParameter filtered_param;
   FilterNet(in_param, &filtered_param);
-  LOG(INFO) << "Initializing net from parameters: " << std::endl
-            << filtered_param.DebugString();
+  if (Caffe::root_solver()) {
+    LOG(INFO) << "Initializing net from parameters: " << std::endl
+              << filtered_param.DebugString();
+  }
   // Create a copy of filtered_param with splits added where necessary.
   NetParameter param;
   InsertSplits(filtered_param, &param);
@@ -66,7 +69,8 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     const int layer_id = -1;  // inputs have fake layer ID -1
     AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx);
   }
-  DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+  DLOG_IF(INFO, Caffe::root_solver())
+      << "Memory required for data: " << memory_used_ * sizeof(Dtype);
   // For each layer, set up its input and output
   bottom_vecs_.resize(param.layer_size());
   top_vecs_.resize(param.layer_size());
@@ -89,7 +93,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     }
     layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
     layer_names_.push_back(layer_param.name());
-    LOG(INFO) << "Creating Layer " << layer_param.name();
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "Creating Layer " << layer_param.name();
+    }
     bool need_backward = false;
 
     // Figure out this layer's input and output
@@ -119,20 +125,30 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
       }
     }
     // After this layer is connected, set it up.
-    LOG(INFO) << "Setting up " << layer_names_[layer_id];
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "Setting up " << layer_names_[layer_id];
+    }
     layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
     for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
       if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) {
         blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0));
       }
       blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id);
-      LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string();
+      if (Caffe::root_solver()) {
+        LOG(INFO) << "Top shape: "
+                  << top_vecs_[layer_id][top_id]->shape_string();
+      }
       if (layer->loss(top_id)) {
-        LOG(INFO) << "    with loss weight " << layer->loss(top_id);
+        if (Caffe::root_solver()) {
+          LOG(INFO) << "    with loss weight " << layer->loss(top_id);
+        }
       }
       memory_used_ += top_vecs_[layer_id][top_id]->count();
     }
-    DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+    if (Caffe::root_solver()) {
+      DLOG(INFO) << "Memory required for data: "
+                 << memory_used_ * sizeof(Dtype);
+    }
     const int param_size = layer_param.param_size();
     const int num_param_blobs = layers_[layer_id]->blobs().size();
     CHECK_LE(param_size, num_param_blobs)
@@ -191,10 +207,14 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     }
     if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
     if (layer_need_backward_[layer_id]) {
-      LOG(INFO) << layer_names_[layer_id] << " needs backward computation.";
+      if (Caffe::root_solver()) {
+        LOG(INFO) << layer_names_[layer_id] << " needs backward computation.";
+      }
     } else {
-      LOG(INFO) << layer_names_[layer_id]
-                << " does not need backward computation.";
+      if (Caffe::root_solver()) {
+        LOG(INFO) << layer_names_[layer_id]
+                  << " does not need backward computation.";
+      }
     }
     for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
          ++bottom_id) {
@@ -234,7 +254,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   // In the end, all remaining blobs are considered output blobs.
   for (set<string>::iterator it = available_blobs.begin();
       it != available_blobs.end(); ++it) {
-    LOG(INFO) << "This network produces output " << *it;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "This network produces output " << *it;
+    }
     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
     net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
   }
@@ -246,8 +268,10 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   }
   ShareWeights();
   debug_info_ = param.debug_info();
-  LOG(INFO) << "Network initialization done.";
-  LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+  if (Caffe::root_solver()) {
+    LOG(INFO) << "Network initialization done.";
+    LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+  }
 }
 
 template <typename Dtype>
@@ -286,27 +310,33 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
   // Check whether the rule is broken due to phase.
   if (rule.has_phase()) {
       if (rule.phase() != state.phase()) {
-        LOG(INFO) << "The NetState phase (" << state.phase()
-          << ") differed from the phase (" << rule.phase()
-          << ") specified by a rule in layer " << layer_name;
+        if (Caffe::root_solver()) {
+          LOG(INFO) << "The NetState phase (" << state.phase()
+                    << ") differed from the phase (" << rule.phase()
+                    << ") specified by a rule in layer " << layer_name;
+        }
         return false;
       }
   }
   // Check whether the rule is broken due to min level.
   if (rule.has_min_level()) {
     if (state.level() < rule.min_level()) {
-      LOG(INFO) << "The NetState level (" << state.level()
-          << ") is above the min_level (" << rule.min_level()
-          << ") specified by a rule in layer " << layer_name;
+      if (Caffe::root_solver()) {
+        LOG(INFO) << "The NetState level (" << state.level()
+                  << ") is above the min_level (" << rule.min_level()
+                  << ") specified by a rule in layer " << layer_name;
+      }
       return false;
     }
   }
   // Check whether the rule is broken due to max level.
   if (rule.has_max_level()) {
     if (state.level() > rule.max_level()) {
-      LOG(INFO) << "The NetState level (" << state.level()
-          << ") is above the max_level (" << rule.max_level()
-          << ") specified by a rule in layer " << layer_name;
+      if (Caffe::root_solver()) {
+        LOG(INFO) << "The NetState level (" << state.level()
+                  << ") is above the max_level (" << rule.max_level()
+                  << ") specified by a rule in layer " << layer_name;
+      }
       return false;
     }
   }
@@ -319,8 +349,10 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
       if (rule.stage(i) == state.stage(j)) { has_stage = true; }
     }
     if (!has_stage) {
-      LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i)
-                << "' specified by a rule in layer " << layer_name;
+      if (Caffe::root_solver()) {
+        LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i)
+                  << "' specified by a rule in layer " << layer_name;
+      }
       return false;
     }
   }
@@ -333,8 +365,10 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
       if (rule.not_stage(i) == state.stage(j)) { has_stage = true; }
     }
     if (has_stage) {
-      LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i)
-                << "' specified by a rule in layer " << layer_name;
+      if (Caffe::root_solver()) {
+        LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i)
+                  << "' specified by a rule in layer " << layer_name;
+      }
       return false;
     }
   }
@@ -356,7 +390,9 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
   if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id &&
       blob_name == layer_param->bottom(top_id)) {
     // In-place computation
-    LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)";
+    if (Caffe::root_solver()) {
+      LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)";
+    }
     top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get());
     top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]);
   } else if (blob_name_to_idx &&
@@ -366,10 +402,12 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
     LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
   } else {
     // Normal output.
-    if (layer_param) {
-      LOG(INFO) << layer_param->name() << " -> " << blob_name;
-    } else {
-      LOG(INFO) << "Input " << top_id << " -> " << blob_name;
+    if (Caffe::root_solver()) {
+      if (layer_param) {
+        LOG(INFO) << layer_param->name() << " -> " << blob_name;
+      } else {
+        LOG(INFO) << "Input " << top_id << " -> " << blob_name;
+      }
     }
     shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
     const int blob_id = blobs_.size();
@@ -409,7 +447,9 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
                << " (at index " << bottom_id << ") to layer " << layer_id;
   }
   const int blob_id = (*blob_name_to_idx)[blob_name];
-  LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name;
+  if (Caffe::root_solver()) {
+    LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name;
+  }
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
@@ -468,9 +508,10 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
         param_layer_indices_[owner_net_param_id];
     const int owner_layer_id = owner_index.first;
     const int owner_param_id = owner_index.second;
-    LOG(INFO) << "Sharing parameters '" << param_name << "' owned by "
-              << "layer '" << layer_names_[owner_layer_id] << "', param "
-              << "index " << owner_param_id;
+    LOG_IF(INFO, Caffe::root_solver()) << "Sharing parameters '" << param_name
+        << "' owned by "
+        << "layer '" << layer_names_[owner_layer_id] << "', param "
+        << "index " << owner_param_id;
     Blob<Dtype>* this_blob = layers_[layer_id]->blobs()[param_id].get();
     Blob<Dtype>* owner_blob =
         layers_[owner_layer_id]->blobs()[owner_param_id].get();
@@ -595,8 +636,10 @@ void Net<Dtype>::InputDebugInfo(const int input_id) {
   const Blob<Dtype>& blob = *net_input_blobs_[input_id];
   const string& blob_name = blob_names_[net_input_blob_indices_[input_id]];
   const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
-  LOG(INFO) << "    [Forward] "
-     << "Input " << blob_name << " data: " << data_abs_val_mean;
+  if (Caffe::root_solver()) {
+    LOG(INFO) << "    [Forward] "
+              << "Input " << blob_name << " data: " << data_abs_val_mean;
+  }
 }
 
 template <typename Dtype>
@@ -605,9 +648,12 @@ void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
     const Blob<Dtype>& blob = *top_vecs_[layer_id][top_id];
     const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
     const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
-    LOG(INFO) << "    [Forward] "
-       << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name
-       << " data: " << data_abs_val_mean;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "    [Forward] "
+                << "Layer " << layer_names_[layer_id]
+                << ", top blob " << blob_name
+                << " data: " << data_abs_val_mean;
+    }
   }
   for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
        ++param_id) {
@@ -615,9 +661,12 @@ void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
     const int net_param_id = param_id_vecs_[layer_id][param_id];
     const string& blob_name = param_display_names_[net_param_id];
     const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
-    LOG(INFO) << "    [Forward] "
-       << "Layer " << layer_names_[layer_id] << ", param blob " << blob_name
-       << " data: " << data_abs_val_mean;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "    [Forward] "
+                << "Layer " << layer_names_[layer_id]
+                << ", param blob " << blob_name
+                << " data: " << data_abs_val_mean;
+    }
   }
 }
 
@@ -629,18 +678,24 @@ void Net<Dtype>::BackwardDebugInfo(const int layer_id) {
     const Blob<Dtype>& blob = *bottom_vec[bottom_id];
     const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
     const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count();
-    LOG(INFO) << "    [Backward] "
-        << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name
-        << " diff: " << diff_abs_val_mean;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "    [Backward] "
+                << "Layer " << layer_names_[layer_id]
+                << ", bottom blob " << blob_name
+                << " diff: " << diff_abs_val_mean;
+    }
   }
   for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
        ++param_id) {
     if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; }
     const Blob<Dtype>& blob = *layers_[layer_id]->blobs()[param_id];
     const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count();
-    LOG(INFO) << "    [Backward] "
-        << "Layer " << layer_names_[layer_id] << ", param blob " << param_id
-        << " diff: " << diff_abs_val_mean;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "    [Backward] "
+                << "Layer " << layer_names_[layer_id]
+                << ", param blob " << param_id
+                << " diff: " << diff_abs_val_mean;
+    }
   }
 }
 
@@ -653,17 +708,22 @@ void Net<Dtype>::UpdateDebugInfo(const int param_id) {
   const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count();
   if (param_owner < 0) {
     const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
-    LOG(INFO) << "    [Update] Layer " << layer_name
-        << ", param " << param_display_name
-        << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "    [Update] Layer " << layer_name
+                << ", param " << param_display_name
+                << " data: " << data_abs_val_mean
+                << "; diff: " << diff_abs_val_mean;
+    }
   } else {
     const string& owner_layer_name =
         layer_names_[param_layer_indices_[param_owner].first];
-    LOG(INFO) << "    [Update] Layer " << layer_name
-        << ", param blob " << param_display_name
-        << " (owned by layer " << owner_layer_name << ", "
-        << "param " << param_display_names_[param_owners_[param_id]] << ")"
-        << " diff: " << diff_abs_val_mean;
+    if (Caffe::root_solver()) {
+      LOG(INFO) << "    [Update] Layer " << layer_name
+                << ", param blob " << param_display_name
+                << " (owned by layer " << owner_layer_name << ", " << "param "
+                << param_display_names_[param_owners_[param_id]] << ")"
+                << " diff: " << diff_abs_val_mean;
+    }
   }
 }
 
@@ -720,8 +780,8 @@ void Net<Dtype>::Backward() {
     const Dtype l2norm_data = std::sqrt(sumsq_data);
     const Dtype l2norm_diff = std::sqrt(sumsq_diff);
     LOG(ERROR) << "    [Backward] All net params (data, diff): "
-        << "L1 norm = (" << asum_data << ", " << asum_diff << "); "
-        << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")";
+               << "L1 norm = (" << asum_data << ", " << asum_diff << "); "
+               << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")";
   }
 }
 
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
new file mode 100644 (file)
index 0000000..3fef8cf
--- /dev/null
@@ -0,0 +1,430 @@
+#ifndef CPU_ONLY
+#include <cuda_runtime.h>
+#endif
+#include <glog/logging.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+
+#include <cstdlib>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "boost/thread.hpp"
+#include "caffe/caffe.hpp"
+#include "caffe/parallel.hpp"
+
+namespace caffe {
+
+enum Op {
+  copy,
+  replace_cpu,
+  replace_gpu,
+  replace_cpu_diff,
+  replace_gpu_diff
+};
+
+template<typename Dtype>
+static void apply_buffers(const vector<Blob<Dtype>*>& blobs,
+                          Dtype* buffer, size_t total_size, Op op) {
+  Dtype* ptr = buffer;
+  for (int i = 0; i < blobs.size(); ++i) {
+    int size = blobs[i]->count();
+    switch (op) {
+      case copy: {
+        // Init buffer to current values of blobs
+        caffe_copy(size,
+                   reinterpret_cast<const Dtype*>(blobs[i]->data()->cpu_data()),
+                   ptr);
+        break;
+      }
+      case replace_cpu:
+        blobs[i]->data()->set_cpu_data(ptr);
+        break;
+      case replace_gpu:
+        blobs[i]->data()->set_gpu_data(ptr);
+        break;
+      case replace_cpu_diff:
+        blobs[i]->diff()->set_cpu_data(ptr);
+        break;
+      case replace_gpu_diff:
+        blobs[i]->diff()->set_gpu_data(ptr);
+        break;
+    }
+    ptr += size;
+  }
+  CHECK_EQ(total_size, ptr - buffer);
+}
+
+// Buffer size necessary to store given blobs
+template<typename Dtype>
+static size_t total_size(const vector<Blob<Dtype>*>& params) {
+  size_t size = 0;
+  for (int i = 0; i < params.size(); ++i)
+    size += params[i]->count();
+  return size;
+}
+
+template<typename Dtype>
+Params<Dtype>::Params(shared_ptr<Solver<Dtype> > root_solver)
+    : size_(total_size<Dtype>(root_solver->net()->learnable_params())),
+      data_(),
+      diff_() {
+}
+
+template<typename Dtype>
+GPUParams<Dtype>::GPUParams(shared_ptr<Solver<Dtype> > root_solver, int device)
+    : Params<Dtype>(root_solver) {
+#ifndef CPU_ONLY
+  int initial_device;
+  CUDA_CHECK(cudaGetDevice(&initial_device));
+
+  // Allocate device buffers
+  CUDA_CHECK(cudaSetDevice(device));
+  CUDA_CHECK(cudaMalloc(&data_, size_ * sizeof(Dtype)));
+
+  // Copy blob values
+  const vector<Blob<Dtype>*>& net =
+      root_solver->net()->learnable_params();
+  apply_buffers(net, data_, size_, copy);
+
+  CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype)));
+  caffe_gpu_set(size_, Dtype(0), diff_);
+
+  CUDA_CHECK(cudaSetDevice(initial_device));
+#else
+  NO_GPU;
+#endif
+}
+
+template<typename Dtype>
+GPUParams<Dtype>::~GPUParams() {
+#ifndef CPU_ONLY
+  CUDA_CHECK(cudaFree(data_));
+  CUDA_CHECK(cudaFree(diff_));
+#endif
+}
+
+template<typename Dtype>
+void GPUParams<Dtype>::configure(Solver<Dtype>* solver) const {
+  const vector<Blob<Dtype>*>& net =
+      solver->net()->learnable_params();
+  apply_buffers(net, data_, size_, replace_gpu);
+  apply_buffers(net, diff_, size_, replace_gpu_diff);
+}
+
+void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) {
+#ifndef CPU_ONLY
+  vector<int> remaining(devices);
+
+  // Group GPUs by board
+  for (int i = 0; i < remaining.size(); ++i) {
+    for (int j = i + 1; j < remaining.size(); ++j) {
+      cudaDeviceProp a, b;
+      CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i]));
+      CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j]));
+      if (a.isMultiGpuBoard && b.isMultiGpuBoard) {
+        if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) {
+          pairs->push_back(DevicePair(remaining[i], remaining[j]));
+          DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j];
+          remaining.erase(remaining.begin() + j);
+          break;
+        }
+      }
+    }
+  }
+  ostringstream s;
+  for (int i = 0; i < remaining.size(); ++i) {
+    s << (i ? ", " : "") << remaining[i];
+  }
+  DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str();
+
+  // Group by P2P accessibility
+  for (int i = 0; i < remaining.size(); ++i) {
+    for (int j = i + 1; j < remaining.size(); ++j) {
+      int access;
+      CUDA_CHECK(cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j]));
+      if (access) {
+        pairs->push_back(DevicePair(remaining[i], remaining[j]));
+        DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j];
+        remaining.erase(remaining.begin() + j);
+        break;
+      }
+    }
+  }
+  s.str("");
+  for (int i = 0; i < remaining.size(); ++i) {
+    s << (i ? ", " : "") << remaining[i];
+  }
+  DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str();
+
+  // Group remaining
+  for (int i = 0; i < remaining.size(); ++i) {
+    for (int j = i + 1; j < remaining.size(); ++j) {
+      pairs->push_back(DevicePair(remaining[i], remaining[j]));
+      DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" << remaining[j];
+      remaining.erase(remaining.begin() + j);
+      break;
+    }
+  }
+  CHECK_EQ(remaining.size(), 1);
+  pairs->insert(pairs->begin(), DevicePair(-1, remaining[0]));
+
+  CHECK(pairs->size() == devices.size());
+  for (int i = 0; i < pairs->size(); ++i) {
+    CHECK((*pairs)[i].parent() != (*pairs)[i].device());
+    for (int j = i + 1; j < pairs->size(); ++j) {
+      CHECK((*pairs)[i].device() != (*pairs)[j].device());
+    }
+  }
+#else
+  NO_GPU;
+#endif
+}
+
+//
+
+template<typename Dtype>
+P2PSync<Dtype>::P2PSync(shared_ptr<Solver<Dtype> > root_solver,
+                        P2PSync<Dtype>* parent, const SolverParameter& param)
+    : GPUParams<Dtype>(root_solver, param.device_id()),
+      parent_(parent),
+      children_(),
+      queue_(),
+      initial_iter_(root_solver->iter()),
+      solver_() {
+#ifndef CPU_ONLY
+  int initial_device;
+  CUDA_CHECK(cudaGetDevice(&initial_device));
+  const int self = param.device_id();
+  CUDA_CHECK(cudaSetDevice(self));
+
+  if (parent == NULL) {
+    solver_ = root_solver;
+  } else {
+    Caffe::set_root_solver(false);
+    solver_.reset(new WorkerSolver<Dtype>(param));
+    Caffe::set_root_solver(true);
+  }
+  this->configure(solver_.get());
+  solver_->add_callback(this);
+
+  if (parent) {
+    // Enable p2p access between devices
+    const int peer = parent->solver_->param().device_id();
+    int access;
+    CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer));
+    if (access) {
+      CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0));
+    } else {
+      LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer;
+    }
+    // Allocate receiving buffer on parent
+    CUDA_CHECK(cudaSetDevice(peer));
+    CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype)));
+    CUDA_CHECK(cudaSetDevice(self));
+  }
+
+  CUDA_CHECK(cudaSetDevice(initial_device));
+#else
+  NO_GPU;
+#endif
+}
+
+template<typename Dtype>
+P2PSync<Dtype>::~P2PSync() {
+#ifndef CPU_ONLY
+  int initial_device;
+  CUDA_CHECK(cudaGetDevice(&initial_device));
+  const int self = solver_->param().device_id();
+  CUDA_CHECK(cudaSetDevice(self));
+
+  if (parent_) {
+    CUDA_CHECK(cudaFree(parent_grads_));
+    const int peer = parent_->solver_->param().device_id();
+    int access;
+    CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer));
+    if (access) {
+      CUDA_CHECK(cudaDeviceDisablePeerAccess(peer));
+    }
+  }
+
+  CUDA_CHECK(cudaSetDevice(initial_device));
+#endif
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::InternalThreadEntry() {
+  Caffe::SetDevice(solver_->param().device_id());
+  CHECK(Caffe::root_solver());
+  Caffe::set_root_solver(false);
+  // See if there is a defined seed and reset random state if so
+  if (solver_->param().random_seed() >= 0) {
+    // Fetch random seed and modulate by device ID to make sure
+    // everyone doesn't have the same seed.  We seem to have some
+    // solver instability if we have everyone with the same seed
+    Caffe::set_random_seed(
+        solver_->param().random_seed() + solver_->param().device_id());
+  }
+  solver_->Step(solver_->param().max_iter() - initial_iter_);
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::on_start() {
+#ifndef CPU_ONLY
+#ifdef DEBUG
+  int device;
+  CUDA_CHECK(cudaGetDevice(&device));
+  CHECK(device == solver_->param().device_id());
+#else
+//  CHECK(false);
+#endif
+
+  // Wait for update from parent
+  if (parent_) {
+    P2PSync<Dtype> *parent = queue_.pop();
+    CHECK(parent == parent_);
+  }
+
+  // Update children
+  for (int i = 0; i < children_.size(); ++i) {
+    Dtype* src = data_;
+    Dtype* dst = children_[i]->data_;
+
+#ifdef DEBUG
+    cudaPointerAttributes attributes;
+    CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
+    CHECK(attributes.device == device);
+    CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
+    CHECK(attributes.device == children_[i]->solver_->param().device_id());
+#endif
+
+    CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype),  //
+        cudaMemcpyDeviceToDevice, cudaStreamDefault));
+  }
+  if (children_.size()) {
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
+  }
+  for (int i = 0; i < children_.size(); ++i) {
+    children_[i]->queue_.push(this);
+  }
+#endif
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::on_gradients_ready() {
+#ifndef CPU_ONLY
+#ifdef DEBUG
+  int device;
+  CUDA_CHECK(cudaGetDevice(&device));
+  CHECK(device == solver_->param().device_id());
+#endif
+
+  // Sum children gradients as they appear in the queue
+  for (int i = 0; i < children_.size(); ++i) {
+    P2PSync<Dtype> *child = queue_.pop();
+    Dtype* src = child->parent_grads_;
+    Dtype* dst = diff_;
+
+#ifdef DEBUG
+    bool ok = false;
+    for (int j = 0; j < children_.size(); ++j) {
+      if (child == children_[j]) {
+        ok = true;
+      }
+    }
+    CHECK(ok);
+    cudaPointerAttributes attributes;
+    CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
+    CHECK(attributes.device == device);
+    CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
+    CHECK(attributes.device == device);
+#endif
+
+    caffe_gpu_add(size_, src, dst, dst);
+  }
+
+  // Send gradients to parent
+  if (parent_) {
+    Dtype* src = diff_;
+    Dtype* dst = parent_grads_;
+
+#ifdef DEBUG
+    cudaPointerAttributes attributes;
+    CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
+    CHECK(attributes.device == device);
+    CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
+    CHECK(attributes.device == parent_->solver_->param().device_id());
+#endif
+
+    CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype),  //
+        cudaMemcpyDeviceToDevice, cudaStreamDefault));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
+    parent_->queue_.push(this);
+  } else {
+    // Loss functions divide gradients by the batch size, so to compensate
+    // for split batch, the root solver divides by number of solvers.
+    caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_);
+  }
+#endif
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::run(const vector<int>& gpus) {
+  // Pair devices for map-reduce synchronization
+  vector<DevicePair> pairs;
+  DevicePair::compute(gpus, &pairs);
+  ostringstream s;
+  for (int i = 1; i < pairs.size(); ++i) {
+    s << (i == 1 ? "" : ", ") << pairs[i].parent() << ":" << pairs[i].device();
+  }
+  LOG(INFO)<< "GPUs pairs " << s.str();
+
+  SolverParameter param(solver_->param());
+  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
+
+  // Build the GPU tree by finding the parent for each solver
+  for (int attempts = 0; attempts < pairs.size(); ++attempts) {
+    for (int i = 1; i < pairs.size(); ++i) {
+      if (!syncs[i].get()) {
+        P2PSync<Dtype>* parent = NULL;
+        for (int j = 0; j < syncs.size(); ++j) {
+          P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get();
+          if (sync) {
+            const SolverParameter& p = sync->solver()->param();
+            if (p.device_id() == pairs[i].parent()) {
+              parent = sync;
+            }
+          }
+        }
+        if (parent) {
+          param.set_device_id(pairs[i].device());
+          syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param));
+          parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get());
+        }
+      }
+    }
+  }
+
+  LOG(INFO)<< "Starting Optimization";
+
+  for (int i = 1; i < syncs.size(); ++i) {
+    syncs[i]->StartInternalThread();
+  }
+
+  // Run root solver on current thread
+  solver_->Solve();
+
+  for (int i = 1; i < syncs.size(); ++i) {
+    syncs[i]->StopInternalThread();
+  }
+}
+
+INSTANTIATE_CLASS(Params);
+INSTANTIATE_CLASS(GPUParams);
+INSTANTIATE_CLASS(P2PSync);
+
+}  // namespace caffe
+
index 54e085a..b6fd6b6 100644 (file)
@@ -19,13 +19,13 @@ namespace caffe {
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const SolverParameter& param)
-    : net_() {
+    : net_(), callbacks_() {
   Init(param);
 }
 
 template <typename Dtype>
 Solver<Dtype>::Solver(const string& param_file)
-    : net_() {
+    : net_(), callbacks_() {
   SolverParameter param;
   ReadProtoFromTextFileOrDie(param_file, &param);
   Init(param);
@@ -33,17 +33,19 @@ Solver<Dtype>::Solver(const string& param_file)
 
 template <typename Dtype>
 void Solver<Dtype>::Init(const SolverParameter& param) {
-  LOG(INFO) << "Initializing solver from parameters: " << std::endl
-            << param.DebugString();
+  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
+    << std::endl << param.DebugString();
   param_ = param;
   CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
-  if (param_.random_seed() >= 0) {
+  if (Caffe::root_solver() && param_.random_seed() >= 0) {
     Caffe::set_random_seed(param_.random_seed());
   }
   // Scaffolding code
   InitTrainNet();
-  InitTestNets();
-  LOG(INFO) << "Solver scaffolding done.";
+  if (Caffe::root_solver()) {
+    InitTestNets();
+    LOG(INFO) << "Solver scaffolding done.";
+  }
   iter_ = 0;
   current_step_ = 0;
 }
@@ -59,19 +61,22 @@ void Solver<Dtype>::InitTrainNet() {
       << "one of these fields specifying a train_net: " << field_names;
   NetParameter net_param;
   if (param_.has_train_net_param()) {
-    LOG(INFO) << "Creating training net specified in train_net_param.";
+    LOG_IF(INFO, Caffe::root_solver())
+        << "Creating training net specified in train_net_param.";
     net_param.CopyFrom(param_.train_net_param());
   } else if (param_.has_train_net()) {
-    LOG(INFO) << "Creating training net from train_net file: "
-              << param_.train_net();
+    LOG_IF(INFO, Caffe::root_solver())
+        << "Creating training net from train_net file: " << param_.train_net();
     ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
   }
   if (param_.has_net_param()) {
-    LOG(INFO) << "Creating training net specified in net_param.";
+    LOG_IF(INFO, Caffe::root_solver())
+        << "Creating training net specified in net_param.";
     net_param.CopyFrom(param_.net_param());
   }
   if (param_.has_net()) {
-    LOG(INFO) << "Creating training net from net file: " << param_.net();
+    LOG_IF(INFO, Caffe::root_solver())
+        << "Creating training net from net file: " << param_.net();
     ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
   }
   // Set the correct NetState.  We start with the solver defaults (lowest
@@ -88,6 +93,7 @@ void Solver<Dtype>::InitTrainNet() {
 
 template <typename Dtype>
 void Solver<Dtype>::InitTestNets() {
+  CHECK(Caffe::root_solver());
   const bool has_net_param = param_.has_net_param();
   const bool has_net_file = param_.has_net();
   const int num_generic_nets = has_net_param + has_net_file;
@@ -175,10 +181,14 @@ void Solver<Dtype>::Step(int iters) {
     // zero-init the params
     net_->ClearParamDiffs();
     if (param_.test_interval() && iter_ % param_.test_interval() == 0
-        && (iter_ > 0 || param_.test_initialization())) {
+        && (iter_ > 0 || param_.test_initialization())
+        && Caffe::root_solver()) {
       TestAll();
     }
 
+    for (int i = 0; i < callbacks_.size(); ++i) {
+      callbacks_[i]->on_start();
+    }
     const bool display = param_.display() && iter_ % param_.display() == 0;
     net_->set_debug_info(display && param_.debug_info());
     // accumulate the loss and gradient
@@ -198,7 +208,8 @@ void Solver<Dtype>::Step(int iters) {
       losses[idx] = loss;
     }
     if (display) {
-      LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss;
+      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
+          << ", loss = " << smoothed_loss;
       const vector<Blob<Dtype>*>& result = net_->output_blobs();
       int score_index = 0;
       for (int j = 0; j < result.size(); ++j) {
@@ -213,12 +224,15 @@ void Solver<Dtype>::Step(int iters) {
             loss_msg_stream << " (* " << loss_weight
                             << " = " << loss_weight * result_vec[k] << " loss)";
           }
-          LOG(INFO) << "    Train net output #"
+          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
               << score_index++ << ": " << output_name << " = "
               << result_vec[k] << loss_msg_stream.str();
         }
       }
     }
+    for (int i = 0; i < callbacks_.size(); ++i) {
+      callbacks_[i]->on_gradients_ready();
+    }
     ApplyUpdate();
 
     // Increment the internal iter_ counter -- its value should always indicate
@@ -226,7 +240,9 @@ void Solver<Dtype>::Step(int iters) {
     ++iter_;
 
     // Save a snapshot if needed.
-    if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
+    if (param_.snapshot()
+        && iter_ % param_.snapshot() == 0
+        && Caffe::root_solver()) {
       Snapshot();
     }
   }
@@ -234,6 +250,7 @@ void Solver<Dtype>::Step(int iters) {
 
 template <typename Dtype>
 void Solver<Dtype>::Solve(const char* resume_file) {
+  CHECK(Caffe::root_solver());
   LOG(INFO) << "Solving " << net_->name();
   LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
 
@@ -278,6 +295,7 @@ void Solver<Dtype>::TestAll() {
 
 template <typename Dtype>
 void Solver<Dtype>::Test(const int test_net_id) {
+  CHECK(Caffe::root_solver());
   LOG(INFO) << "Iteration " << iter_
             << ", Testing net (#" << test_net_id << ")";
   CHECK_NOTNULL(test_nets_[test_net_id].get())->
@@ -328,13 +346,14 @@ void Solver<Dtype>::Test(const int test_net_id) {
                       << " = " << loss_weight * mean_score << " loss)";
     }
     LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
-        << mean_score << loss_msg_stream.str();
+              << mean_score << loss_msg_stream.str();
   }
 }
 
 
 template <typename Dtype>
 void Solver<Dtype>::Snapshot() {
+  CHECK(Caffe::root_solver());
   string model_filename;
   switch (param_.snapshot_format()) {
     case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
@@ -379,6 +398,7 @@ string Solver<Dtype>::SnapshotToHDF5() {
 
 template <typename Dtype>
 void Solver<Dtype>::Restore(const char* state_file) {
+  CHECK(Caffe::root_solver());
   string state_filename(state_file);
   if (state_filename.size() >= 3 &&
       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
@@ -480,6 +500,7 @@ void SGDSolver<Dtype>::ClipGradients() {
 
 template <typename Dtype>
 void SGDSolver<Dtype>::ApplyUpdate() {
+  CHECK(Caffe::root_solver());
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
@@ -723,6 +744,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
 
 template <typename Dtype>
 void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  CHECK(Caffe::root_solver());
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype momentum = this->param_.momentum();
@@ -783,6 +805,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 
 template <typename Dtype>
 void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+  CHECK(Caffe::root_solver());
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype delta = this->param_.delta();
index 0da7a3b..a667a86 100644 (file)
@@ -12,8 +12,14 @@ SyncedMemory::~SyncedMemory() {
   }
 
 #ifndef CPU_ONLY
-  if (gpu_ptr_) {
+  if (gpu_ptr_ && own_gpu_data_) {
+    int initial_device;
+    cudaGetDevice(&initial_device);
+    if (gpu_device_ != -1) {
+      CUDA_CHECK(cudaSetDevice(gpu_device_));
+    }
     CUDA_CHECK(cudaFree(gpu_ptr_));
+    cudaSetDevice(initial_device);
   }
 #endif  // CPU_ONLY
 }
@@ -48,13 +54,17 @@ inline void SyncedMemory::to_gpu() {
 #ifndef CPU_ONLY
   switch (head_) {
   case UNINITIALIZED:
+    CUDA_CHECK(cudaGetDevice(&gpu_device_));
     CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
     caffe_gpu_memset(size_, 0, gpu_ptr_);
     head_ = HEAD_AT_GPU;
+    own_gpu_data_ = true;
     break;
   case HEAD_AT_CPU:
     if (gpu_ptr_ == NULL) {
+      CUDA_CHECK(cudaGetDevice(&gpu_device_));
       CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+      own_gpu_data_ = true;
     }
     caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_);
     head_ = SYNCED;
@@ -92,6 +102,26 @@ const void* SyncedMemory::gpu_data() {
 #endif
 }
 
+void SyncedMemory::set_gpu_data(void* data) {
+#ifndef CPU_ONLY
+  CHECK(data);
+  if (own_gpu_data_) {
+    int initial_device;
+    cudaGetDevice(&initial_device);
+    if (gpu_device_ != -1) {
+      CUDA_CHECK(cudaSetDevice(gpu_device_));
+    }
+    CUDA_CHECK(cudaFree(gpu_ptr_));
+    cudaSetDevice(initial_device);
+  }
+  gpu_ptr_ = data;
+  head_ = HEAD_AT_GPU;
+  own_gpu_data_ = false;
+#else
+  NO_GPU;
+#endif
+}
+
 void* SyncedMemory::mutable_cpu_data() {
   to_cpu();
   head_ = HEAD_AT_CPU;
@@ -112,7 +142,9 @@ void* SyncedMemory::mutable_gpu_data() {
 void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
   CHECK(head_ == HEAD_AT_CPU);
   if (gpu_ptr_ == NULL) {
+    CUDA_CHECK(cudaGetDevice(&gpu_device_));
     CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+    own_gpu_data_ = true;
   }
   const cudaMemcpyKind put = cudaMemcpyHostToDevice;
   CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream));
index eaa7a75..1cede07 100644 (file)
@@ -8,6 +8,7 @@
 #include "gtest/gtest.h"
 
 #include "caffe/common.hpp"
+#include "caffe/parallel.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
 #include "caffe/util/io.hpp"
@@ -35,6 +36,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
   string snapshot_prefix_;
   shared_ptr<SGDSolver<Dtype> > solver_;
+  shared_ptr<P2PSync<Dtype> > sync_;
   int seed_;
   // Dimensions are determined by generate_sample_data.py
   // TODO this is brittle and the hdf5 file should be checked instead.
@@ -70,8 +72,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
   string RunLeastSquaresSolver(const Dtype learning_rate,
       const Dtype weight_decay, const Dtype momentum, const int num_iters,
-      const int iter_size = 1, const bool snapshot = false,
-      const char* from_snapshot = NULL) {
+      const int iter_size = 1, const int devices = 1,
+      const bool snapshot = false, const char* from_snapshot = NULL) {
     ostringstream proto;
     proto <<
        "snapshot_after_train: " << snapshot << " "
@@ -184,7 +186,20 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         this->solver_->net()->Forward(empty_bottom_vec);
       }
     }
-    this->solver_->Solve();
+    if (devices == 1) {
+      this->solver_->Solve();
+    } else {
+      LOG(INFO) << "Multi-GPU test on " << devices << " devices";
+      vector<int> gpus;
+      for (int i = 0; i < devices; ++i) {
+        gpus.push_back(i);
+      }
+      Caffe::set_solver_count(gpus.size());
+      this->sync_.reset(new P2PSync<Dtype>(
+          this->solver_, NULL, this->solver_->param()));
+      this->sync_->run(gpus);
+      Caffe::set_solver_count(1);
+    }
     if (snapshot) {
       ostringstream resume_file;
       resume_file << snapshot_prefix_ << "/_iter_" << num_iters
@@ -410,20 +425,38 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
   void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
       const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
       const int iter_to_check = 0) {
-    // Initialize the solver and run K (= iter_to_check) solver iterations.
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
-
-    // Compute the (K+1)th update using the analytic least squares gradient.
-    vector<shared_ptr<Blob<Dtype> > > updated_params;
-    ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
-        &updated_params);
-
-    // Reinitialize the solver and run K+1 solver iterations.
-    RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-        iter_to_check + 1);
-
-    // Check that the solver's solution matches ours.
-    CheckLeastSquaresUpdate(updated_params);
+    const int kNum = num_;
+    const int kIterSize = 1;
+    // Test over all numbers of devices.
+    int available_devices = 1;
+#ifndef CPU_ONLY
+    if (Caffe::mode() == Caffe::GPU) {
+      CUDA_CHECK(cudaGetDeviceCount(&available_devices));
+    }
+#endif
+    for (int devices = 1; devices <= available_devices; ++devices) {
+      // Configure batch size for single / multi device equivalence.
+      // Constant data is needed for multi device as for accumulation.
+      num_ = kNum * devices;
+
+      // Initialize the solver and run K (= iter_to_check) solver iterations
+      // (on single device).
+      RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+                            iter_to_check, kIterSize, 1);
+
+      // Compute the (K+1)th update using the analytic least squares gradient.
+      vector<shared_ptr<Blob<Dtype> > > updated_params;
+      ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
+          &updated_params);
+
+      // Reinitialize the solver and run K+1 solver iterations.
+      num_ = kNum;
+      RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+          iter_to_check + 1, kIterSize, devices);
+
+      // Check that the solver's solution matches ours.
+      CheckLeastSquaresUpdate(updated_params);
+    }
   }
 
   void TestSnapshot(const Dtype learning_rate = 1.0,
@@ -433,8 +466,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     const int total_num_iters = num_iters * 2;
     bool snapshot = false;
     const int kIterSize = 1;
+    const int kDevices = 1;
     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-        total_num_iters, kIterSize, snapshot);
+        total_num_iters, kIterSize, kDevices, snapshot);
 
     // Save the resulting param values.
     vector<shared_ptr<Blob<Dtype> > > param_copies;
@@ -464,12 +498,13 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     // Run the solver for num_iters iterations and snapshot.
     snapshot = true;
     string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
-        momentum, num_iters, kIterSize, snapshot);
+        momentum, num_iters, kIterSize, kDevices, snapshot);
 
     // Reinitialize the solver and run for num_iters more iterations.
     snapshot = false;
     RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
-        total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+        total_num_iters, kIterSize, kDevices,
+        snapshot, snapshot_name.c_str());
 
     // Check that params now match.
     const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
index f7c53f2..d1d1fa8 100644 (file)
@@ -3,6 +3,7 @@
 
 #include "caffe/data_layers.hpp"
 #include "caffe/data_reader.hpp"
+#include "caffe/parallel.hpp"
 #include "caffe/util/blocking_queue.hpp"
 
 namespace caffe {
@@ -89,5 +90,7 @@ template class BlockingQueue<Batch<float>*>;
 template class BlockingQueue<Batch<double>*>;
 template class BlockingQueue<Datum*>;
 template class BlockingQueue<shared_ptr<DataReader::QueuePair> >;
+template class BlockingQueue<P2PSync<float>*>;
+template class BlockingQueue<P2PSync<double>*>;
 
 }  // namespace caffe
index 46f9959..9f31b37 100644 (file)
@@ -17,13 +17,17 @@ using caffe::Blob;
 using caffe::Caffe;
 using caffe::Net;
 using caffe::Layer;
+using caffe::Solver;
 using caffe::shared_ptr;
+using caffe::string;
 using caffe::Timer;
 using caffe::vector;
+using std::ostringstream;
 
-
-DEFINE_int32(gpu, -1,
-    "Run in GPU mode on given device ID.");
+DEFINE_string(gpu, "",
+    "Optional; run in GPU mode on given device IDs separated by ','."
+    "Use '-gpu all' to run on all available GPUs. The effective training "
+    "batch size is multiplied by the number of devices.");
 DEFINE_string(solver, "",
     "The solver definition protocol buffer text file.");
 DEFINE_string(model, "",
@@ -31,8 +35,8 @@ DEFINE_string(model, "",
 DEFINE_string(snapshot, "",
     "Optional; the snapshot solver state to resume training.");
 DEFINE_string(weights, "",
-    "Optional; the pretrained weights to initialize finetuning. "
-    "Cannot be set simultaneously with snapshot.");
+    "Optional; the pretrained weights to initialize finetuning, "
+    "separated by ','. Cannot be set simultaneously with snapshot.");
 DEFINE_int32(iterations, 50,
     "The number of iterations to run.");
 
@@ -66,6 +70,29 @@ static BrewFunction GetBrewFunction(const caffe::string& name) {
   }
 }
 
+// Parse GPU ids or use all available devices
+static void get_gpus(vector<int>* gpus) {
+  if (FLAGS_gpu == "all") {
+    int count = 0;
+#ifndef CPU_ONLY
+    CUDA_CHECK(cudaGetDeviceCount(&count));
+#else
+    NO_GPU;
+#endif
+    for (int i = 0; i < count; ++i) {
+      gpus->push_back(i);
+    }
+  } else if (FLAGS_gpu.size()) {
+    vector<string> strings;
+    boost::split(strings, FLAGS_gpu, boost::is_any_of(","));
+    for (int i = 0; i < strings.size(); ++i) {
+      gpus->push_back(boost::lexical_cast<int>(strings[i]));
+    }
+  } else {
+    CHECK_EQ(gpus->size(), 0);
+  }
+}
+
 // caffe commands to call by
 //     caffe <command> <args>
 //
@@ -74,10 +101,13 @@ static BrewFunction GetBrewFunction(const caffe::string& name) {
 
 // Device Query: show diagnostic information for a GPU device.
 int device_query() {
-  CHECK_GT(FLAGS_gpu, -1) << "Need a device ID to query.";
-  LOG(INFO) << "Querying device ID = " << FLAGS_gpu;
-  caffe::Caffe::SetDevice(FLAGS_gpu);
-  caffe::Caffe::DeviceQuery();
+  LOG(INFO) << "Querying GPUs " << FLAGS_gpu;
+  vector<int> gpus;
+  get_gpus(&gpus);
+  for (int i = 0; i < gpus.size(); ++i) {
+    caffe::Caffe::SetDevice(gpus[i]);
+    caffe::Caffe::DeviceQuery();
+  }
   return 0;
 }
 RegisterBrewFunction(device_query);
@@ -106,34 +136,49 @@ int train() {
   caffe::SolverParameter solver_param;
   caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);
 
-  // If the gpu flag is not provided, allow the mode and device to be set
+  // If the gpus flag is not provided, allow the mode and device to be set
   // in the solver prototxt.
-  if (FLAGS_gpu < 0
+  if (FLAGS_gpu.size() == 0
       && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
-    FLAGS_gpu = solver_param.device_id();
+      if (solver_param.has_device_id()) {
+          FLAGS_gpu = ""  +
+              boost::lexical_cast<string>(solver_param.device_id());
+      } else {  // Set default GPU if unspecified
+          FLAGS_gpu = "" + boost::lexical_cast<string>(0);
+      }
   }
 
-  // Set device id and mode
-  if (FLAGS_gpu >= 0) {
-    LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu;
-    Caffe::SetDevice(FLAGS_gpu);
-    Caffe::set_mode(Caffe::GPU);
-  } else {
-    LOG(INFO) << "Use CPU.";
+  vector<int> gpus;
+  get_gpus(&gpus);
+  if (gpus.size() == 0) {
     Caffe::set_mode(Caffe::CPU);
+  } else {
+    ostringstream s;
+    for (int i = 0; i < gpus.size(); ++i) {
+      s << (i ? ", " : "") << gpus[i];
+    }
+    LOG(INFO) << "Using GPUs " << s.str();
+
+    solver_param.set_device_id(gpus[0]);
+    Caffe::SetDevice(gpus[0]);
+    Caffe::set_mode(Caffe::GPU);
+    Caffe::set_solver_count(gpus.size());
   }
 
-  LOG(INFO) << "Starting Optimization";
-  shared_ptr<caffe::Solver<float> >
-    solver(caffe::GetSolver<float>(solver_param));
+  shared_ptr<Solver<float> > solver(caffe::GetSolver<float>(solver_param));
 
   if (FLAGS_snapshot.size()) {
     LOG(INFO) << "Resuming from " << FLAGS_snapshot;
-    solver->Solve(FLAGS_snapshot);
+    solver->Restore(FLAGS_snapshot.c_str());
   } else if (FLAGS_weights.size()) {
-    CopyLayers(&*solver, FLAGS_weights);
-    solver->Solve();
+    CopyLayers(solver.get(), FLAGS_weights);
+  }
+
+  if (gpus.size() > 1) {
+    caffe::P2PSync<float> sync(solver, NULL, solver->param());
+    sync.run(gpus);
   } else {
+    LOG(INFO) << "Starting Optimization";
     solver->Solve();
   }
   LOG(INFO) << "Optimization Done.";
@@ -148,9 +193,11 @@ int test() {
   CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score.";
 
   // Set device id and mode
-  if (FLAGS_gpu >= 0) {
-    LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu;
-    Caffe::SetDevice(FLAGS_gpu);
+  vector<int> gpus;
+  get_gpus(&gpus);
+  if (gpus.size() != 0) {
+    LOG(INFO) << "Use GPU with device ID " << gpus[0];
+    Caffe::SetDevice(gpus[0]);
     Caffe::set_mode(Caffe::GPU);
   } else {
     LOG(INFO) << "Use CPU.";
@@ -213,9 +260,11 @@ int time() {
   CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time.";
 
   // Set device id and mode
-  if (FLAGS_gpu >= 0) {
-    LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu;
-    Caffe::SetDevice(FLAGS_gpu);
+  vector<int> gpus;
+  get_gpus(&gpus);
+  if (gpus.size() != 0) {
+    LOG(INFO) << "Use GPU with device ID " << gpus[0];
+    Caffe::SetDevice(gpus[0]);
     Caffe::set_mode(Caffe::GPU);
   } else {
     LOG(INFO) << "Use CPU.";