From e5575cf17a43a56e4ba9bc5465548ac0512197d8 Mon Sep 17 00:00:00 2001 From: Cyprien Noel Date: Tue, 19 May 2015 11:11:05 -0700 Subject: [PATCH] Multi-GPU - Parallelize batches among GPUs and tree-reduce the gradients - The effective batch size scales with the number of devices - Batch size is multiplied by the number of devices - Split batches between GPUs, and tree-reduce the gradients - Detect machine topology (twin-GPU boards, P2P connectivity) - Track device in syncedmem (thanks @thatguymike) - Insert a callback in the solver for minimal code change - Accept list for gpu flag of caffe tool, e.g. '-gpu 0,1' or '-gpu all'. Run on default GPU if no ID given. - Add multi-GPU solver test - Deterministic architecture for reproducible runs --- include/caffe/caffe.hpp | 1 + include/caffe/common.hpp | 7 + include/caffe/internal_thread.hpp | 3 +- include/caffe/layer_factory.hpp | 4 +- include/caffe/parallel.hpp | 118 +++++++ include/caffe/solver.hpp | 38 +++ include/caffe/syncedmem.hpp | 7 +- src/caffe/common.cpp | 5 +- src/caffe/data_reader.cpp | 4 +- src/caffe/data_transformer.cpp | 4 +- src/caffe/internal_thread.cpp | 9 +- src/caffe/net.cpp | 180 +++++++---- src/caffe/parallel.cpp | 430 ++++++++++++++++++++++++++ src/caffe/solver.cpp | 57 +++- src/caffe/syncedmem.cpp | 34 +- src/caffe/test/test_gradient_based_solver.cpp | 75 +++-- src/caffe/util/blocking_queue.cpp | 3 + tools/caffe.cpp | 111 +++++-- 18 files changed, 949 insertions(+), 141 deletions(-) create mode 100644 include/caffe/parallel.hpp create mode 100644 src/caffe/parallel.cpp diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp index 3c829f2..68a5e1d 100644 --- a/include/caffe/caffe.hpp +++ b/include/caffe/caffe.hpp @@ -10,6 +10,7 @@ #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/benchmark.hpp" diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 3fa8143..1df6b9a 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -149,6 +149,11 @@ class Caffe { static void SetDevice(const int device_id); // Prints the current GPU status. static void DeviceQuery(); + // Parallel training info + inline static int solver_count() { return Get().solver_count_; } + inline static void set_solver_count(int val) { Get().solver_count_ = val; } + inline static bool root_solver() { return Get().root_solver_; } + inline static void set_root_solver(bool val) { Get().root_solver_ = val; } protected: #ifndef CPU_ONLY @@ -158,6 +163,8 @@ class Caffe { shared_ptr random_generator_; Brew mode_; + int solver_count_; + bool root_solver_; private: // The private constructor to avoid duplicate instantiation. diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp index be6ff7f..6a8c5a0 100644 --- a/include/caffe/internal_thread.hpp +++ b/include/caffe/internal_thread.hpp @@ -42,7 +42,8 @@ class InternalThread { bool must_stop(); private: - void entry(int device, Caffe::Brew mode, int rand_seed); + void entry(int device, Caffe::Brew mode, int rand_seed, int solver_count, + bool root_solver); shared_ptr thread_; }; diff --git a/include/caffe/layer_factory.hpp b/include/caffe/layer_factory.hpp index 2fcd938..32e849d 100644 --- a/include/caffe/layer_factory.hpp +++ b/include/caffe/layer_factory.hpp @@ -71,7 +71,9 @@ class LayerRegistry { // Get a layer using a LayerParameter. 
static shared_ptr > CreateLayer(const LayerParameter& param) { - LOG(INFO) << "Creating layer " << param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating layer " << param.name(); + } const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown layer type: " << type diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp new file mode 100644 index 0000000..85fc2b5 --- /dev/null +++ b/include/caffe/parallel.hpp @@ -0,0 +1,118 @@ +#ifndef CAFFE_PARALLEL_HPP_ +#define CAFFE_PARALLEL_HPP_ + +#include + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/internal_thread.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/blocking_queue.hpp" + +namespace caffe { + +// Represents a net parameters. Once a net is created, its parameter buffers can +// be replaced by ones from Params, to allow parallelization. Params ensures +// parameters are allocated in one consecutive array. +template +class Params { + public: + explicit Params(shared_ptr > root_solver); + virtual ~Params() { + } + + inline size_t size() const { + return size_; + } + inline Dtype* data() const { + return data_; + } + inline Dtype* diff() const { + return diff_; + } + + protected: + const size_t size_; // Size of buffers + Dtype* data_; // Network parameters + Dtype* diff_; // Gradient + +DISABLE_COPY_AND_ASSIGN(Params); +}; + +// Params stored in GPU memory. +template +class GPUParams : public Params { + public: + GPUParams(shared_ptr > root_solver, int device); + virtual ~GPUParams(); + + void configure(Solver* solver) const; + + protected: + using Params::size_; + using Params::data_; + using Params::diff_; +}; + +class DevicePair { + public: + DevicePair(int parent, int device) + : parent_(parent), + device_(device) { + } + inline int parent() { + return parent_; + } + inline int device() { + return device_; + } + + // Group GPUs in pairs, by proximity depending on machine's topology + static void compute(const vector devices, vector* pairs); + + protected: + int parent_; + int device_; +}; + +// Synchronous data parallelism using map-reduce between local GPUs. +template +class P2PSync : public GPUParams, public Solver::Callback, + public InternalThread { + public: + explicit P2PSync(shared_ptr > root_solver, + P2PSync* parent, const SolverParameter& param); + virtual ~P2PSync(); + + inline const shared_ptr >& solver() const { + return solver_; + } + + void run(const vector& gpus); + + protected: + void on_start(); + void on_gradients_ready(); + + void InternalThreadEntry(); + + P2PSync* parent_; + vector*> children_; + BlockingQueue*> queue_; + const int initial_iter_; + Dtype* parent_grads_; + shared_ptr > solver_; + + using Params::size_; + using Params::data_; + using Params::diff_; +}; + +} // namespace caffe + +#endif diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index fbade93..89a6c76 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -32,12 +32,27 @@ class Solver { // methods to restore the state from the appropriate snapshot type. 
void Restore(const char* resume_file); virtual ~Solver() {} + inline const SolverParameter& param() const { return param_; } inline shared_ptr > net() { return net_; } inline const vector > >& test_nets() { return test_nets_; } int iter() { return iter_; } + // Invoked at specific points during an iteration + class Callback { + protected: + virtual void on_start() = 0; + virtual void on_gradients_ready() = 0; + + template + friend class Solver; + }; + const vector& callbacks() const { return callbacks_; } + void add_callback(Callback* value) { + callbacks_.push_back(value); + } + protected: // Make and apply the update value for the current iteration. virtual void ApplyUpdate() = 0; @@ -62,10 +77,33 @@ class Solver { int current_step_; shared_ptr > net_; vector > > test_nets_; + vector callbacks_; DISABLE_COPY_AND_ASSIGN(Solver); }; +/** + * @brief Solver that only computes gradients, used as worker + * for multi-GPU training. + */ +template +class WorkerSolver : public Solver { + public: + explicit WorkerSolver(const SolverParameter& param) + : Solver(param) {} + + protected: + void ApplyUpdate() {} + void SnapshotSolverState(const string& model_filename) { + LOG(FATAL) << "Should not be called on worker solver."; + } + void RestoreSolverStateFromBinaryProto(const string& state_file) { + LOG(FATAL) << "Should not be called on worker solver."; + } + void RestoreSolverStateFromHDF5(const string& state_file) { + LOG(FATAL) << "Should not be called on worker solver."; + } +}; /** * @brief Optimizes the parameters of a Net using diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp index 4a1a2f3..62aadef 100644 --- a/include/caffe/syncedmem.hpp +++ b/include/caffe/syncedmem.hpp @@ -45,14 +45,15 @@ class SyncedMemory { public: SyncedMemory() : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {} explicit SyncedMemory(size_t size) : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {} ~SyncedMemory(); const void* cpu_data(); void set_cpu_data(void* data); const void* gpu_data(); + void set_gpu_data(void* data); void* mutable_cpu_data(); void* mutable_gpu_data(); enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED }; @@ -71,6 +72,8 @@ class SyncedMemory { size_t size_; SyncedHead head_; bool own_cpu_data_; + bool own_gpu_data_; + int gpu_device_; DISABLE_COPY_AND_ASSIGN(SyncedMemory); }; // class SyncedMemory diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 0215c76..7077f37 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -51,7 +51,8 @@ void GlobalInit(int* pargc, char*** pargv) { #ifdef CPU_ONLY // CPU-only Caffe. Caffe::Caffe() - : random_generator_(), mode_(Caffe::CPU) { } + : random_generator_(), mode_(Caffe::CPU), + solver_count_(1), root_solver_(true) { } Caffe::~Caffe() { } @@ -95,7 +96,7 @@ void* Caffe::RNG::generator() { Caffe::Caffe() : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU) { + mode_(Caffe::CPU), solver_count_(1), root_solver_(true) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). 
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { diff --git a/src/caffe/data_reader.cpp b/src/caffe/data_reader.cpp index 60606f0..1637820 100644 --- a/src/caffe/data_reader.cpp +++ b/src/caffe/data_reader.cpp @@ -76,9 +76,7 @@ void DataReader::Body::InternalThreadEntry() { shared_ptr cursor(db->NewCursor()); vector > qps; try { - // int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1; - // TODO single solver until multi-gpu merge - int solver_count = 1; + int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1; // To ensure deterministic runs, only start running once all solvers // are ready. But solvers need to peek on one item during initialization, diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 2263392..4666d9b 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -19,7 +19,9 @@ DataTransformer::DataTransformer(const TransformationParameter& param, CHECK_EQ(param_.mean_value_size(), 0) << "Cannot specify mean_file and mean_value at the same time"; const string& mean_file = param.mean_file(); - LOG(INFO) << "Loading mean file from: " << mean_file; + if (Caffe::root_solver()) { + LOG(INFO) << "Loading mean file from: " << mean_file; + } BlobProto blob_proto; ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); data_mean_.FromProto(blob_proto); diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index b193826..104884e 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -27,21 +27,26 @@ void InternalThread::StartInternalThread() { #endif Caffe::Brew mode = Caffe::mode(); int rand_seed = caffe_rng_rand(); + int solver_count = Caffe::solver_count(); + bool root_solver = Caffe::root_solver(); try { thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode, - rand_seed)); + rand_seed, solver_count, root_solver)); } catch (std::exception& e) { LOG(FATAL) << "Thread exception: " << e.what(); } } -void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed) { +void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed, + int solver_count, bool root_solver) { #ifndef CPU_ONLY CUDA_CHECK(cudaSetDevice(device)); #endif Caffe::set_mode(mode); Caffe::set_random_seed(rand_seed); + Caffe::set_solver_count(solver_count); + Caffe::set_root_solver(root_solver); InternalThreadEntry(); } diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 0e5ed80..5d0f432 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -10,6 +10,7 @@ #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/hdf5.hpp" #include "caffe/util/insert_splits.hpp" @@ -41,8 +42,10 @@ void Net::Init(const NetParameter& in_param) { // the current NetState. NetParameter filtered_param; FilterNet(in_param, &filtered_param); - LOG(INFO) << "Initializing net from parameters: " << std::endl - << filtered_param.DebugString(); + if (Caffe::root_solver()) { + LOG(INFO) << "Initializing net from parameters: " << std::endl + << filtered_param.DebugString(); + } // Create a copy of filtered_param with splits added where necessary. 
NetParameter param; InsertSplits(filtered_param, ¶m); @@ -66,7 +69,8 @@ void Net::Init(const NetParameter& in_param) { const int layer_id = -1; // inputs have fake layer ID -1 AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + DLOG_IF(INFO, Caffe::root_solver()) + << "Memory required for data: " << memory_used_ * sizeof(Dtype); // For each layer, set up its input and output bottom_vecs_.resize(param.layer_size()); top_vecs_.resize(param.layer_size()); @@ -89,7 +93,9 @@ void Net::Init(const NetParameter& in_param) { } layers_.push_back(LayerRegistry::CreateLayer(layer_param)); layer_names_.push_back(layer_param.name()); - LOG(INFO) << "Creating Layer " << layer_param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating Layer " << layer_param.name(); + } bool need_backward = false; // Figure out this layer's input and output @@ -119,20 +125,30 @@ void Net::Init(const NetParameter& in_param) { } } // After this layer is connected, set it up. - LOG(INFO) << "Setting up " << layer_names_[layer_id]; + if (Caffe::root_solver()) { + LOG(INFO) << "Setting up " << layer_names_[layer_id]; + } layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) { blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0)); } blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id); - LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string(); + if (Caffe::root_solver()) { + LOG(INFO) << "Top shape: " + << top_vecs_[layer_id][top_id]->shape_string(); + } if (layer->loss(top_id)) { - LOG(INFO) << " with loss weight " << layer->loss(top_id); + if (Caffe::root_solver()) { + LOG(INFO) << " with loss weight " << layer->loss(top_id); + } } memory_used_ += top_vecs_[layer_id][top_id]->count(); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + DLOG(INFO) << "Memory required for data: " + << memory_used_ * sizeof(Dtype); + } const int param_size = layer_param.param_size(); const int num_param_blobs = layers_[layer_id]->blobs().size(); CHECK_LE(param_size, num_param_blobs) @@ -191,10 +207,14 @@ void Net::Init(const NetParameter& in_param) { } if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; } if (layer_need_backward_[layer_id]) { - LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + } } else { - LOG(INFO) << layer_names_[layer_id] - << " does not need backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] + << " does not need backward computation."; + } } for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size(); ++bottom_id) { @@ -234,7 +254,9 @@ void Net::Init(const NetParameter& in_param) { // In the end, all remaining blobs are considered output blobs. 
for (set::iterator it = available_blobs.begin(); it != available_blobs.end(); ++it) { - LOG(INFO) << "This network produces output " << *it; + if (Caffe::root_solver()) { + LOG(INFO) << "This network produces output " << *it; + } net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); net_output_blob_indices_.push_back(blob_name_to_idx[*it]); } @@ -246,8 +268,10 @@ void Net::Init(const NetParameter& in_param) { } ShareWeights(); debug_info_ = param.debug_info(); - LOG(INFO) << "Network initialization done."; - LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + LOG(INFO) << "Network initialization done."; + LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + } } template @@ -286,27 +310,33 @@ bool Net::StateMeetsRule(const NetState& state, // Check whether the rule is broken due to phase. if (rule.has_phase()) { if (rule.phase() != state.phase()) { - LOG(INFO) << "The NetState phase (" << state.phase() - << ") differed from the phase (" << rule.phase() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState phase (" << state.phase() + << ") differed from the phase (" << rule.phase() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to min level. if (rule.has_min_level()) { if (state.level() < rule.min_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the min_level (" << rule.min_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the min_level (" << rule.min_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to max level. 
if (rule.has_max_level()) { if (state.level() > rule.max_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the max_level (" << rule.max_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the max_level (" << rule.max_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } @@ -319,8 +349,10 @@ bool Net::StateMeetsRule(const NetState& state, if (rule.stage(i) == state.stage(j)) { has_stage = true; } } if (!has_stage) { - LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -333,8 +365,10 @@ bool Net::StateMeetsRule(const NetState& state, if (rule.not_stage(i) == state.stage(j)) { has_stage = true; } } if (has_stage) { - LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -356,7 +390,9 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id && blob_name == layer_param->bottom(top_id)) { // In-place computation - LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + if (Caffe::root_solver()) { + LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + } top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get()); top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]); } else if (blob_name_to_idx && @@ -366,10 +402,12 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, LOG(FATAL) << "Duplicate blobs produced by multiple sources."; } else { // Normal output. 
- if (layer_param) { - LOG(INFO) << layer_param->name() << " -> " << blob_name; - } else { - LOG(INFO) << "Input " << top_id << " -> " << blob_name; + if (Caffe::root_solver()) { + if (layer_param) { + LOG(INFO) << layer_param->name() << " -> " << blob_name; + } else { + LOG(INFO) << "Input " << top_id << " -> " << blob_name; + } } shared_ptr > blob_pointer(new Blob()); const int blob_id = blobs_.size(); @@ -409,7 +447,9 @@ int Net::AppendBottom(const NetParameter& param, const int layer_id, << " (at index " << bottom_id << ") to layer " << layer_id; } const int blob_id = (*blob_name_to_idx)[blob_name]; - LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + } bottom_vecs_[layer_id].push_back(blobs_[blob_id].get()); bottom_id_vecs_[layer_id].push_back(blob_id); available_blobs->erase(blob_name); @@ -468,9 +508,10 @@ void Net::AppendParam(const NetParameter& param, const int layer_id, param_layer_indices_[owner_net_param_id]; const int owner_layer_id = owner_index.first; const int owner_param_id = owner_index.second; - LOG(INFO) << "Sharing parameters '" << param_name << "' owned by " - << "layer '" << layer_names_[owner_layer_id] << "', param " - << "index " << owner_param_id; + LOG_IF(INFO, Caffe::root_solver()) << "Sharing parameters '" << param_name + << "' owned by " + << "layer '" << layer_names_[owner_layer_id] << "', param " + << "index " << owner_param_id; Blob* this_blob = layers_[layer_id]->blobs()[param_id].get(); Blob* owner_blob = layers_[owner_layer_id]->blobs()[owner_param_id].get(); @@ -595,8 +636,10 @@ void Net::InputDebugInfo(const int input_id) { const Blob& blob = *net_input_blobs_[input_id]; const string& blob_name = blob_names_[net_input_blob_indices_[input_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Input " << blob_name << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Input " << blob_name << " data: " << data_abs_val_mean; + } } template @@ -605,9 +648,12 @@ void Net::ForwardDebugInfo(const int layer_id) { const Blob& blob = *top_vecs_[layer_id][top_id]; const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", top blob " << blob_name + << " data: " << data_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { @@ -615,9 +661,12 @@ void Net::ForwardDebugInfo(const int layer_id) { const int net_param_id = param_id_vecs_[layer_id][param_id]; const string& blob_name = param_display_names_[net_param_id]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << blob_name + << " data: " << data_abs_val_mean; + } } } @@ -629,18 +678,24 @@ void Net::BackwardDebugInfo(const int layer_id) { const Blob& blob = *bottom_vec[bottom_id]; const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]]; const Dtype diff_abs_val_mean = 
blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", bottom blob " << blob_name + << " diff: " << diff_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; } const Blob& blob = *layers_[layer_id]->blobs()[param_id]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << param_id - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << param_id + << " diff: " << diff_abs_val_mean; + } } } @@ -653,17 +708,22 @@ void Net::UpdateDebugInfo(const int param_id) { const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); if (param_owner < 0) { const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Update] Layer " << layer_name - << ", param " << param_display_name - << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param " << param_display_name + << " data: " << data_abs_val_mean + << "; diff: " << diff_abs_val_mean; + } } else { const string& owner_layer_name = layer_names_[param_layer_indices_[param_owner].first]; - LOG(INFO) << " [Update] Layer " << layer_name - << ", param blob " << param_display_name - << " (owned by layer " << owner_layer_name << ", " - << "param " << param_display_names_[param_owners_[param_id]] << ")" - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param blob " << param_display_name + << " (owned by layer " << owner_layer_name << ", " << "param " + << param_display_names_[param_owners_[param_id]] << ")" + << " diff: " << diff_abs_val_mean; + } } } @@ -720,8 +780,8 @@ void Net::Backward() { const Dtype l2norm_data = std::sqrt(sumsq_data); const Dtype l2norm_diff = std::sqrt(sumsq_diff); LOG(ERROR) << " [Backward] All net params (data, diff): " - << "L1 norm = (" << asum_data << ", " << asum_diff << "); " - << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; + << "L1 norm = (" << asum_data << ", " << asum_diff << "); " + << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; } } diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp new file mode 100644 index 0000000..3fef8cf --- /dev/null +++ b/src/caffe/parallel.cpp @@ -0,0 +1,430 @@ +#ifndef CPU_ONLY +#include +#endif +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "boost/thread.hpp" +#include "caffe/caffe.hpp" +#include "caffe/parallel.hpp" + +namespace caffe { + +enum Op { + copy, + replace_cpu, + replace_gpu, + replace_cpu_diff, + replace_gpu_diff +}; + +template +static void apply_buffers(const vector*>& blobs, + Dtype* buffer, size_t total_size, Op op) { + Dtype* ptr = buffer; + for (int i = 0; i < blobs.size(); ++i) { + int size = blobs[i]->count(); + switch (op) { + case copy: { + // Init buffer to current values of blobs + caffe_copy(size, + reinterpret_cast(blobs[i]->data()->cpu_data()), + ptr); + break; + } + case replace_cpu: + blobs[i]->data()->set_cpu_data(ptr); + break; + 
case replace_gpu: + blobs[i]->data()->set_gpu_data(ptr); + break; + case replace_cpu_diff: + blobs[i]->diff()->set_cpu_data(ptr); + break; + case replace_gpu_diff: + blobs[i]->diff()->set_gpu_data(ptr); + break; + } + ptr += size; + } + CHECK_EQ(total_size, ptr - buffer); +} + +// Buffer size necessary to store given blobs +template +static size_t total_size(const vector*>& params) { + size_t size = 0; + for (int i = 0; i < params.size(); ++i) + size += params[i]->count(); + return size; +} + +template +Params::Params(shared_ptr > root_solver) + : size_(total_size(root_solver->net()->learnable_params())), + data_(), + diff_() { +} + +template +GPUParams::GPUParams(shared_ptr > root_solver, int device) + : Params(root_solver) { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + + // Allocate device buffers + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaMalloc(&data_, size_ * sizeof(Dtype))); + + // Copy blob values + const vector*>& net = + root_solver->net()->learnable_params(); + apply_buffers(net, data_, size_, copy); + + CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype))); + caffe_gpu_set(size_, Dtype(0), diff_); + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template +GPUParams::~GPUParams() { +#ifndef CPU_ONLY + CUDA_CHECK(cudaFree(data_)); + CUDA_CHECK(cudaFree(diff_)); +#endif +} + +template +void GPUParams::configure(Solver* solver) const { + const vector*>& net = + solver->net()->learnable_params(); + apply_buffers(net, data_, size_, replace_gpu); + apply_buffers(net, diff_, size_, replace_gpu_diff); +} + +void DevicePair::compute(const vector devices, vector* pairs) { +#ifndef CPU_ONLY + vector remaining(devices); + + // Group GPUs by board + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + cudaDeviceProp a, b; + CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); + CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); + if (a.isMultiGpuBoard && b.isMultiGpuBoard) { + if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + ostringstream s; + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? ", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); + + // Group by P2P accessibility + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); + if (access) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + s.str(""); + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? 
", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); + + // Group remaining + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + CHECK_EQ(remaining.size(), 1); + pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); + + CHECK(pairs->size() == devices.size()); + for (int i = 0; i < pairs->size(); ++i) { + CHECK((*pairs)[i].parent() != (*pairs)[i].device()); + for (int j = i + 1; j < pairs->size(); ++j) { + CHECK((*pairs)[i].device() != (*pairs)[j].device()); + } + } +#else + NO_GPU; +#endif +} + +// + +template +P2PSync::P2PSync(shared_ptr > root_solver, + P2PSync* parent, const SolverParameter& param) + : GPUParams(root_solver, param.device_id()), + parent_(parent), + children_(), + queue_(), + initial_iter_(root_solver->iter()), + solver_() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = param.device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent == NULL) { + solver_ = root_solver; + } else { + Caffe::set_root_solver(false); + solver_.reset(new WorkerSolver(param)); + Caffe::set_root_solver(true); + } + this->configure(solver_.get()); + solver_->add_callback(this); + + if (parent) { + // Enable p2p access between devices + const int peer = parent->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0)); + } else { + LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer; + } + // Allocate receiving buffer on parent + CUDA_CHECK(cudaSetDevice(peer)); + CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype))); + CUDA_CHECK(cudaSetDevice(self)); + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template +P2PSync::~P2PSync() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = solver_->param().device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent_) { + CUDA_CHECK(cudaFree(parent_grads_)); + const int peer = parent_->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceDisablePeerAccess(peer)); + } + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#endif +} + +template +void P2PSync::InternalThreadEntry() { + Caffe::SetDevice(solver_->param().device_id()); + CHECK(Caffe::root_solver()); + Caffe::set_root_solver(false); + // See if there is a defined seed and reset random state if so + if (solver_->param().random_seed() >= 0) { + // Fetch random seed and modulate by device ID to make sure + // everyone doesn't have the same seed. 
We seem to have some + // solver instability if we have everyone with the same seed + Caffe::set_random_seed( + solver_->param().random_seed() + solver_->param().device_id()); + } + solver_->Step(solver_->param().max_iter() - initial_iter_); +} + +template +void P2PSync::on_start() { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#else +// CHECK(false); +#endif + + // Wait for update from parent + if (parent_) { + P2PSync *parent = queue_.pop(); + CHECK(parent == parent_); + } + + // Update children + for (int i = 0; i < children_.size(); ++i) { + Dtype* src = data_; + Dtype* dst = children_[i]->data_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == children_[i]->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + } + if (children_.size()) { + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + } + for (int i = 0; i < children_.size(); ++i) { + children_[i]->queue_.push(this); + } +#endif +} + +template +void P2PSync::on_gradients_ready() { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#endif + + // Sum children gradients as they appear in the queue + for (int i = 0; i < children_.size(); ++i) { + P2PSync *child = queue_.pop(); + Dtype* src = child->parent_grads_; + Dtype* dst = diff_; + +#ifdef DEBUG + bool ok = false; + for (int j = 0; j < children_.size(); ++j) { + if (child == children_[j]) { + ok = true; + } + } + CHECK(ok); + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == device); +#endif + + caffe_gpu_add(size_, src, dst, dst); + } + + // Send gradients to parent + if (parent_) { + Dtype* src = diff_; + Dtype* dst = parent_grads_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == parent_->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + parent_->queue_.push(this); + } else { + // Loss functions divide gradients by the batch size, so to compensate + // for split batch, the root solver divides by number of solvers. + caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_); + } +#endif +} + +template +void P2PSync::run(const vector& gpus) { + // Pair devices for map-reduce synchronization + vector pairs; + DevicePair::compute(gpus, &pairs); + ostringstream s; + for (int i = 1; i < pairs.size(); ++i) { + s << (i == 1 ? 
"" : ", ") << pairs[i].parent() << ":" << pairs[i].device(); + } + LOG(INFO)<< "GPUs pairs " << s.str(); + + SolverParameter param(solver_->param()); + vector > > syncs(gpus.size()); + + // Build the GPU tree by finding the parent for each solver + for (int attempts = 0; attempts < pairs.size(); ++attempts) { + for (int i = 1; i < pairs.size(); ++i) { + if (!syncs[i].get()) { + P2PSync* parent = NULL; + for (int j = 0; j < syncs.size(); ++j) { + P2PSync* sync = j == 0 ? this : syncs[j].get(); + if (sync) { + const SolverParameter& p = sync->solver()->param(); + if (p.device_id() == pairs[i].parent()) { + parent = sync; + } + } + } + if (parent) { + param.set_device_id(pairs[i].device()); + syncs[i].reset(new P2PSync(solver_, parent, param)); + parent->children_.push_back((P2PSync*) syncs[i].get()); + } + } + } + } + + LOG(INFO)<< "Starting Optimization"; + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StartInternalThread(); + } + + // Run root solver on current thread + solver_->Solve(); + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StopInternalThread(); + } +} + +INSTANTIATE_CLASS(Params); +INSTANTIATE_CLASS(GPUParams); +INSTANTIATE_CLASS(P2PSync); + +} // namespace caffe + diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 54e085a..b6fd6b6 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -19,13 +19,13 @@ namespace caffe { template Solver::Solver(const SolverParameter& param) - : net_() { + : net_(), callbacks_() { Init(param); } template Solver::Solver(const string& param_file) - : net_() { + : net_(), callbacks_() { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -33,17 +33,19 @@ Solver::Solver(const string& param_file) template void Solver::Init(const SolverParameter& param) { - LOG(INFO) << "Initializing solver from parameters: " << std::endl - << param.DebugString(); + LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: " + << std::endl << param.DebugString(); param_ = param; CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; - if (param_.random_seed() >= 0) { + if (Caffe::root_solver() && param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } // Scaffolding code InitTrainNet(); - InitTestNets(); - LOG(INFO) << "Solver scaffolding done."; + if (Caffe::root_solver()) { + InitTestNets(); + LOG(INFO) << "Solver scaffolding done."; + } iter_ = 0; current_step_ = 0; } @@ -59,19 +61,22 @@ void Solver::InitTrainNet() { << "one of these fields specifying a train_net: " << field_names; NetParameter net_param; if (param_.has_train_net_param()) { - LOG(INFO) << "Creating training net specified in train_net_param."; + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net specified in train_net_param."; net_param.CopyFrom(param_.train_net_param()); } else if (param_.has_train_net()) { - LOG(INFO) << "Creating training net from train_net file: " - << param_.train_net(); + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net from train_net file: " << param_.train_net(); ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); } if (param_.has_net_param()) { - LOG(INFO) << "Creating training net specified in net_param."; + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net specified in net_param."; net_param.CopyFrom(param_.net_param()); } if (param_.has_net()) { - LOG(INFO) << "Creating training net from net file: " << param_.net(); + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net from net 
file: " << param_.net(); ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); } // Set the correct NetState. We start with the solver defaults (lowest @@ -88,6 +93,7 @@ void Solver::InitTrainNet() { template void Solver::InitTestNets() { + CHECK(Caffe::root_solver()); const bool has_net_param = param_.has_net_param(); const bool has_net_file = param_.has_net(); const int num_generic_nets = has_net_param + has_net_file; @@ -175,10 +181,14 @@ void Solver::Step(int iters) { // zero-init the params net_->ClearParamDiffs(); if (param_.test_interval() && iter_ % param_.test_interval() == 0 - && (iter_ > 0 || param_.test_initialization())) { + && (iter_ > 0 || param_.test_initialization()) + && Caffe::root_solver()) { TestAll(); } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_start(); + } const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); // accumulate the loss and gradient @@ -198,7 +208,8 @@ void Solver::Step(int iters) { losses[idx] = loss; } if (display) { - LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss; + LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ + << ", loss = " << smoothed_loss; const vector*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { @@ -213,12 +224,15 @@ void Solver::Step(int iters) { loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * result_vec[k] << " loss)"; } - LOG(INFO) << " Train net output #" + LOG_IF(INFO, Caffe::root_solver()) << " Train net output #" << score_index++ << ": " << output_name << " = " << result_vec[k] << loss_msg_stream.str(); } } } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_gradients_ready(); + } ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate @@ -226,7 +240,9 @@ void Solver::Step(int iters) { ++iter_; // Save a snapshot if needed. 
- if (param_.snapshot() && iter_ % param_.snapshot() == 0) { + if (param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) { Snapshot(); } } @@ -234,6 +250,7 @@ void Solver::Step(int iters) { template void Solver::Solve(const char* resume_file) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); @@ -278,6 +295,7 @@ void Solver::TestAll() { template void Solver::Test(const int test_net_id) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Iteration " << iter_ << ", Testing net (#" << test_net_id << ")"; CHECK_NOTNULL(test_nets_[test_net_id].get())-> @@ -328,13 +346,14 @@ void Solver::Test(const int test_net_id) { << " = " << loss_weight * mean_score << " loss)"; } LOG(INFO) << " Test net output #" << i << ": " << output_name << " = " - << mean_score << loss_msg_stream.str(); + << mean_score << loss_msg_stream.str(); } } template void Solver::Snapshot() { + CHECK(Caffe::root_solver()); string model_filename; switch (param_.snapshot_format()) { case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: @@ -379,6 +398,7 @@ string Solver::SnapshotToHDF5() { template void Solver::Restore(const char* state_file) { + CHECK(Caffe::root_solver()); string state_filename(state_file); if (state_filename.size() >= 3 && state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) { @@ -480,6 +500,7 @@ void SGDSolver::ClipGradients() { template void SGDSolver::ApplyUpdate() { + CHECK(Caffe::root_solver()); Dtype rate = GetLearningRate(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; @@ -723,6 +744,7 @@ void SGDSolver::RestoreSolverStateFromHDF5(const string& state_file) { template void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype momentum = this->param_.momentum(); @@ -783,6 +805,7 @@ void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { template void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype delta = this->param_.delta(); diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 0da7a3b..a667a86 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -12,8 +12,14 @@ SyncedMemory::~SyncedMemory() { } #ifndef CPU_ONLY - if (gpu_ptr_) { + if (gpu_ptr_ && own_gpu_data_) { + int initial_device; + cudaGetDevice(&initial_device); + if (gpu_device_ != -1) { + CUDA_CHECK(cudaSetDevice(gpu_device_)); + } CUDA_CHECK(cudaFree(gpu_ptr_)); + cudaSetDevice(initial_device); } #endif // CPU_ONLY } @@ -48,13 +54,17 @@ inline void SyncedMemory::to_gpu() { #ifndef CPU_ONLY switch (head_) { case UNINITIALIZED: + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); caffe_gpu_memset(size_, 0, gpu_ptr_); head_ = HEAD_AT_GPU; + own_gpu_data_ = true; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; } caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); head_ = SYNCED; @@ -92,6 +102,26 @@ const void* SyncedMemory::gpu_data() { #endif } +void SyncedMemory::set_gpu_data(void* data) { +#ifndef CPU_ONLY + CHECK(data); + if 
(own_gpu_data_) { + int initial_device; + cudaGetDevice(&initial_device); + if (gpu_device_ != -1) { + CUDA_CHECK(cudaSetDevice(gpu_device_)); + } + CUDA_CHECK(cudaFree(gpu_ptr_)); + cudaSetDevice(initial_device); + } + gpu_ptr_ = data; + head_ = HEAD_AT_GPU; + own_gpu_data_ = false; +#else + NO_GPU; +#endif +} + void* SyncedMemory::mutable_cpu_data() { to_cpu(); head_ = HEAD_AT_CPU; @@ -112,7 +142,9 @@ void* SyncedMemory::mutable_gpu_data() { void SyncedMemory::async_gpu_push(const cudaStream_t& stream) { CHECK(head_ == HEAD_AT_CPU); if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; } const cudaMemcpyKind put = cudaMemcpyHostToDevice; CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream)); diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index eaa7a75..1cede07 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -8,6 +8,7 @@ #include "gtest/gtest.h" #include "caffe/common.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/io.hpp" @@ -35,6 +36,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { string snapshot_prefix_; shared_ptr > solver_; + shared_ptr > sync_; int seed_; // Dimensions are determined by generate_sample_data.py // TODO this is brittle and the hdf5 file should be checked instead. @@ -70,8 +72,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { string RunLeastSquaresSolver(const Dtype learning_rate, const Dtype weight_decay, const Dtype momentum, const int num_iters, - const int iter_size = 1, const bool snapshot = false, - const char* from_snapshot = NULL) { + const int iter_size = 1, const int devices = 1, + const bool snapshot = false, const char* from_snapshot = NULL) { ostringstream proto; proto << "snapshot_after_train: " << snapshot << " " @@ -184,7 +186,20 @@ class GradientBasedSolverTest : public MultiDeviceTest { this->solver_->net()->Forward(empty_bottom_vec); } } - this->solver_->Solve(); + if (devices == 1) { + this->solver_->Solve(); + } else { + LOG(INFO) << "Multi-GPU test on " << devices << " devices"; + vector gpus; + for (int i = 0; i < devices; ++i) { + gpus.push_back(i); + } + Caffe::set_solver_count(gpus.size()); + this->sync_.reset(new P2PSync( + this->solver_, NULL, this->solver_->param())); + this->sync_->run(gpus); + Caffe::set_solver_count(1); + } if (snapshot) { ostringstream resume_file; resume_file << snapshot_prefix_ << "/_iter_" << num_iters @@ -410,20 +425,38 @@ class GradientBasedSolverTest : public MultiDeviceTest { void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, const int iter_to_check = 0) { - // Initialize the solver and run K (= iter_to_check) solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); - - // Compute the (K+1)th update using the analytic least squares gradient. - vector > > updated_params; - ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - &updated_params); - - // Reinitialize the solver and run K+1 solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - iter_to_check + 1); - - // Check that the solver's solution matches ours. - CheckLeastSquaresUpdate(updated_params); + const int kNum = num_; + const int kIterSize = 1; + // Test over all numbers of devices. 
+ int available_devices = 1; +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaGetDeviceCount(&available_devices)); + } +#endif + for (int devices = 1; devices <= available_devices; ++devices) { + // Configure batch size for single / multi device equivalence. + // Constant data is needed for multi device as for accumulation. + num_ = kNum * devices; + + // Initialize the solver and run K (= iter_to_check) solver iterations + // (on single device). + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check, kIterSize, 1); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + num_ = kNum; + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1, kIterSize, devices); + + // Check that the solver's solution matches ours. + CheckLeastSquaresUpdate(updated_params); + } } void TestSnapshot(const Dtype learning_rate = 1.0, @@ -433,8 +466,9 @@ class GradientBasedSolverTest : public MultiDeviceTest { const int total_num_iters = num_iters * 2; bool snapshot = false; const int kIterSize = 1; + const int kDevices = 1; RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - total_num_iters, kIterSize, snapshot); + total_num_iters, kIterSize, kDevices, snapshot); // Save the resulting param values. vector > > param_copies; @@ -464,12 +498,13 @@ class GradientBasedSolverTest : public MultiDeviceTest { // Run the solver for num_iters iterations and snapshot. snapshot = true; string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay, - momentum, num_iters, kIterSize, snapshot); + momentum, num_iters, kIterSize, kDevices, snapshot); // Reinitialize the solver and run for num_iters more iterations. snapshot = false; RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - total_num_iters, kIterSize, snapshot, snapshot_name.c_str()); + total_num_iters, kIterSize, kDevices, + snapshot, snapshot_name.c_str()); // Check that params now match. const vector*>& params = solver_->net()->learnable_params(); diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp index f7c53f2..d1d1fa8 100644 --- a/src/caffe/util/blocking_queue.cpp +++ b/src/caffe/util/blocking_queue.cpp @@ -3,6 +3,7 @@ #include "caffe/data_layers.hpp" #include "caffe/data_reader.hpp" +#include "caffe/parallel.hpp" #include "caffe/util/blocking_queue.hpp" namespace caffe { @@ -89,5 +90,7 @@ template class BlockingQueue*>; template class BlockingQueue*>; template class BlockingQueue; template class BlockingQueue >; +template class BlockingQueue*>; +template class BlockingQueue*>; } // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 46f9959..9f31b37 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -17,13 +17,17 @@ using caffe::Blob; using caffe::Caffe; using caffe::Net; using caffe::Layer; +using caffe::Solver; using caffe::shared_ptr; +using caffe::string; using caffe::Timer; using caffe::vector; +using std::ostringstream; - -DEFINE_int32(gpu, -1, - "Run in GPU mode on given device ID."); +DEFINE_string(gpu, "", + "Optional; run in GPU mode on given device IDs separated by ','." + "Use '-gpu all' to run on all available GPUs. 
The effective training " + "batch size is multiplied by the number of devices."); DEFINE_string(solver, "", "The solver definition protocol buffer text file."); DEFINE_string(model, "", @@ -31,8 +35,8 @@ DEFINE_string(model, "", DEFINE_string(snapshot, "", "Optional; the snapshot solver state to resume training."); DEFINE_string(weights, "", - "Optional; the pretrained weights to initialize finetuning. " - "Cannot be set simultaneously with snapshot."); + "Optional; the pretrained weights to initialize finetuning, " + "separated by ','. Cannot be set simultaneously with snapshot."); DEFINE_int32(iterations, 50, "The number of iterations to run."); @@ -66,6 +70,29 @@ static BrewFunction GetBrewFunction(const caffe::string& name) { } } +// Parse GPU ids or use all available devices +static void get_gpus(vector* gpus) { + if (FLAGS_gpu == "all") { + int count = 0; +#ifndef CPU_ONLY + CUDA_CHECK(cudaGetDeviceCount(&count)); +#else + NO_GPU; +#endif + for (int i = 0; i < count; ++i) { + gpus->push_back(i); + } + } else if (FLAGS_gpu.size()) { + vector strings; + boost::split(strings, FLAGS_gpu, boost::is_any_of(",")); + for (int i = 0; i < strings.size(); ++i) { + gpus->push_back(boost::lexical_cast(strings[i])); + } + } else { + CHECK_EQ(gpus->size(), 0); + } +} + // caffe commands to call by // caffe // @@ -74,10 +101,13 @@ static BrewFunction GetBrewFunction(const caffe::string& name) { // Device Query: show diagnostic information for a GPU device. int device_query() { - CHECK_GT(FLAGS_gpu, -1) << "Need a device ID to query."; - LOG(INFO) << "Querying device ID = " << FLAGS_gpu; - caffe::Caffe::SetDevice(FLAGS_gpu); - caffe::Caffe::DeviceQuery(); + LOG(INFO) << "Querying GPUs " << FLAGS_gpu; + vector gpus; + get_gpus(&gpus); + for (int i = 0; i < gpus.size(); ++i) { + caffe::Caffe::SetDevice(gpus[i]); + caffe::Caffe::DeviceQuery(); + } return 0; } RegisterBrewFunction(device_query); @@ -106,34 +136,49 @@ int train() { caffe::SolverParameter solver_param; caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); - // If the gpu flag is not provided, allow the mode and device to be set + // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. - if (FLAGS_gpu < 0 + if (FLAGS_gpu.size() == 0 && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { - FLAGS_gpu = solver_param.device_id(); + if (solver_param.has_device_id()) { + FLAGS_gpu = "" + + boost::lexical_cast(solver_param.device_id()); + } else { // Set default GPU if unspecified + FLAGS_gpu = "" + boost::lexical_cast(0); + } } - // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); - Caffe::set_mode(Caffe::GPU); - } else { - LOG(INFO) << "Use CPU."; + vector gpus; + get_gpus(&gpus); + if (gpus.size() == 0) { Caffe::set_mode(Caffe::CPU); + } else { + ostringstream s; + for (int i = 0; i < gpus.size(); ++i) { + s << (i ? 
", " : "") << gpus[i]; + } + LOG(INFO) << "Using GPUs " << s.str(); + + solver_param.set_device_id(gpus[0]); + Caffe::SetDevice(gpus[0]); + Caffe::set_mode(Caffe::GPU); + Caffe::set_solver_count(gpus.size()); } - LOG(INFO) << "Starting Optimization"; - shared_ptr > - solver(caffe::GetSolver(solver_param)); + shared_ptr > solver(caffe::GetSolver(solver_param)); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; - solver->Solve(FLAGS_snapshot); + solver->Restore(FLAGS_snapshot.c_str()); } else if (FLAGS_weights.size()) { - CopyLayers(&*solver, FLAGS_weights); - solver->Solve(); + CopyLayers(solver.get(), FLAGS_weights); + } + + if (gpus.size() > 1) { + caffe::P2PSync sync(solver, NULL, solver->param()); + sync.run(gpus); } else { + LOG(INFO) << "Starting Optimization"; solver->Solve(); } LOG(INFO) << "Optimization Done."; @@ -148,9 +193,11 @@ int test() { CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score."; // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); } else { LOG(INFO) << "Use CPU."; @@ -213,9 +260,11 @@ int time() { CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); } else { LOG(INFO) << "Use CPU."; -- 2.7.4