From: Cyprien Noel Date: Tue, 19 May 2015 00:45:20 +0000 (-0700) Subject: Persistent prefetch thread X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ddcdc9d711e81312caf127e8aa512c3298101297;p=platform%2Fupstream%2Fcaffe.git Persistent prefetch thread --- diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 3958cb7..f57ab6b 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -15,6 +15,7 @@ #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/blocking_queue.hpp" #include "caffe/util/db.hpp" namespace caffe { @@ -51,11 +52,16 @@ class BaseDataLayer : public Layer { }; template +class Batch { + public: + Blob data_, label_; +}; + +template class BasePrefetchingDataLayer : public BaseDataLayer, public InternalThread { public: - explicit BasePrefetchingDataLayer(const LayerParameter& param) - : BaseDataLayer(param) {} + explicit BasePrefetchingDataLayer(const LayerParameter& param); // LayerSetUp: implements common data layer setup functionality, and calls // DataLayerSetUp to do special data layer setup for individual layer types. // This method may not be overridden. @@ -67,14 +73,17 @@ class BasePrefetchingDataLayer : virtual void Forward_gpu(const vector*>& bottom, const vector*>& top); - virtual void CreatePrefetchThread(); - virtual void JoinPrefetchThread(); - // The thread's function - virtual void InternalThreadEntry() {} + // Prefetches batches (asynchronously if to GPU memory) + static const int PREFETCH_COUNT = 3; protected: - Blob prefetch_data_; - Blob prefetch_label_; + virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch) = 0; + + Batch prefetch_[PREFETCH_COUNT]; + BlockingQueue*> prefetch_free_; + BlockingQueue*> prefetch_full_; + Blob transformed_data_; }; @@ -93,7 +102,7 @@ class DataLayer : public BasePrefetchingDataLayer { virtual inline int MaxTopBlobs() const { return 2; } protected: - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); shared_ptr db_; shared_ptr cursor_; @@ -235,7 +244,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer { protected: shared_ptr prefetch_rng_; virtual void ShuffleImages(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); vector > lines_; int lines_id_; @@ -307,7 +316,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer { protected: virtual unsigned int PrefetchRand(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); shared_ptr prefetch_rng_; vector > > image_database_; diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp index 1b726de..4d339bf 100644 --- a/include/caffe/syncedmem.hpp +++ b/include/caffe/syncedmem.hpp @@ -56,6 +56,10 @@ class SyncedMemory { SyncedHead head() { return head_; } size_t size() { return size_; } +#ifndef CPU_ONLY + void async_gpu_push(const cudaStream_t& stream); +#endif + private: void to_cpu(); void to_gpu(); diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index d6c2655..b193826 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -19,10 +19,7 @@ bool InternalThread::must_stop() { } void InternalThread::StartInternalThread() { - // TODO switch to failing once Caffe prefetch thread is persistent. - // Threads should not be started and stopped repeatedly. - // CHECK(!is_started()); - StopInternalThread(); + CHECK(!is_started()) << "Threads should persist and not be restarted."; int device = 0; #ifndef CPU_ONLY diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index facaed7..9288d91 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -1,7 +1,9 @@ +#include #include #include #include "caffe/data_layers.hpp" +#include "caffe/net.hpp" #include "caffe/util/io.hpp" namespace caffe { @@ -28,55 +30,91 @@ void BaseDataLayer::LayerSetUp(const vector*>& bottom, } template +BasePrefetchingDataLayer::BasePrefetchingDataLayer( + const LayerParameter& param) + : BaseDataLayer(param), + prefetch_free_(), prefetch_full_() { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_free_.push(&prefetch_[i]); + } +} + +template void BasePrefetchingDataLayer::LayerSetUp( const vector*>& bottom, const vector*>& top) { BaseDataLayer::LayerSetUp(bottom, top); - // Now, start the prefetch thread. Before calling prefetch, we make two - // cpu_data calls so that the prefetch thread does not accidentally make - // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. - this->prefetch_data_.mutable_cpu_data(); - if (this->output_labels_) { - this->prefetch_label_.mutable_cpu_data(); + // Before starting the prefetch thread, we make cpu_data and gpu_data + // calls so that the prefetch thread does not accidentally make simultaneous + // cudaMalloc calls when the main thread is running. In some GPUs this + // seems to cause failures if we do not so. + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_cpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_cpu_data(); + } } +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_gpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_gpu_data(); + } + } + } +#endif DLOG(INFO) << "Initializing prefetch"; - this->CreatePrefetchThread(); - DLOG(INFO) << "Prefetch initialized."; -} - -template -void BasePrefetchingDataLayer::CreatePrefetchThread() { this->data_transformer_->InitRand(); StartInternalThread(); + DLOG(INFO) << "Prefetch initialized."; } template -void BasePrefetchingDataLayer::JoinPrefetchThread() { - StopInternalThread(); +void BasePrefetchingDataLayer::InternalThreadEntry() { +#ifndef CPU_ONLY + cudaStream_t stream; + if (Caffe::mode() == Caffe::GPU) { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } +#endif + + try { + while (!must_stop()) { + Batch* batch = prefetch_free_.pop(); + load_batch(batch); +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + batch->data_.data().get()->async_gpu_push(stream); + cudaStreamSynchronize(stream); + } +#endif + prefetch_full_.push(batch); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } } template void BasePrefetchingDataLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); - DLOG(INFO) << "Thread joined"; + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->ReshapeLike(prefetch_data_); + top[0]->Reshape(batch->data_.num(), batch->data_.channels(), + batch->data_.height(), batch->data_.width()); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.cpu_data(), top[0]->mutable_cpu_data()); DLOG(INFO) << "Prefetch copied"; if (this->output_labels_) { // Reshape to loaded labels. top[1]->ReshapeLike(prefetch_label_); // Copy the labels. - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), - top[1]->mutable_cpu_data()); + caffe_copy(batch->label_.count(), batch->label_.cpu_data(), + top[1]->mutable_cpu_data()); } - // Start a new prefetch thread - DLOG(INFO) << "CreatePrefetchThread"; - CreatePrefetchThread(); + + prefetch_free_.push(batch); } #ifdef CPU_ONLY diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 9335a5b..56439bc 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -7,22 +7,21 @@ namespace caffe { template void BasePrefetchingDataLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->ReshapeLike(this->prefetch_data_); + top[0]->ReshapeLike(batch->data_); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.gpu_data(), top[0]->mutable_gpu_data()); if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(prefetch_label_); + top[1]->ReshapeLike(batch->label_); // Copy the labels. - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), + caffe_copy(batch->label_.count(), batch->label_.gpu_data(), top[1]->mutable_gpu_data()); } - // Start a new prefetch thread - CreatePrefetchThread(); + + prefetch_free_.push(batch); } INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 161a75e..22d9f43 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -17,8 +17,8 @@ namespace caffe { template -DataLayer::~DataLayer() { - this->JoinPrefetchThread(); +DataLayer::~DataLayer() { + this->StopInternalThread(); } template @@ -54,21 +54,23 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, << top[0]->width(); // label if (this->output_labels_) { - vector label_shape(1, this->layer_param_.data_param().batch_size()); + vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } } -// This function is used to create a thread that prefetches the data. -template -void DataLayer::InternalThreadEntry() { +// This function is called on prefetch thread +template +void DataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); // Reshape according to the first datum of each batch @@ -81,13 +83,13 @@ void DataLayer::InternalThreadEntry() { this->transformed_data_.Reshape(top_shape); // Reshape prefetch_data according to the batch_size. top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); + batch->data_.Reshape(top_shape); - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); Dtype* top_label = NULL; // suppress warnings about uninitialized variables if (this->output_labels_) { - top_label = this->prefetch_label_.mutable_cpu_data(); + top_label = batch->label_.mutable_cpu_data(); } timer.Start(); for (int item_id = 0; item_id < batch_size; ++item_id) { @@ -97,7 +99,7 @@ void DataLayer::InternalThreadEntry() { read_time += timer.MicroSeconds(); timer.Start(); // Apply data transformations (mirror, scale, crop...) - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(top_data + offset); this->data_transformer_->Transform(datum, &(this->transformed_data_)); // Copy label. diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp index dcc5334..223ba3a 100644 --- a/src/caffe/layers/image_data_layer.cpp +++ b/src/caffe/layers/image_data_layer.cpp @@ -17,7 +17,7 @@ namespace caffe { template ImageDataLayer::~ImageDataLayer() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template @@ -70,8 +70,10 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, const int batch_size = this->layer_param_.image_data_param().batch_size(); CHECK_GT(batch_size, 0) << "Positive batch size required"; top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); - top[0]->ReshapeLike(this->prefetch_data_); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(top_shape); + } + top[0]->Reshape(top_shape); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -79,7 +81,9 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } template @@ -89,15 +93,15 @@ void ImageDataLayer::ShuffleImages() { shuffle(lines_.begin(), lines_.end(), prefetch_rng); } -// This function is used to create a thread that prefetches the data. +// This function is called on prefetch thread template -void ImageDataLayer::InternalThreadEntry() { +void ImageDataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); ImageDataParameter image_data_param = this->layer_param_.image_data_param(); const int batch_size = image_data_param.batch_size(); @@ -114,12 +118,12 @@ void ImageDataLayer::InternalThreadEntry() { // Use data_transformer to infer the expected blob shape from a cv_img. vector top_shape = this->data_transformer_->InferBlobShape(cv_img); this->transformed_data_.Reshape(top_shape); - // Reshape prefetch_data according to the batch_size. + // Reshape batch according to the batch_size. top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); + batch->data_.Reshape(top_shape); - Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* prefetch_data = batch->data_.mutable_cpu_data(); + Dtype* prefetch_label = batch->label_.mutable_cpu_data(); // datum scales const int lines_size = lines_.size(); @@ -133,7 +137,7 @@ void ImageDataLayer::InternalThreadEntry() { read_time += timer.MicroSeconds(); timer.Start(); // Apply transformations (mirror, crop...) to the image - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(prefetch_data + offset); this->data_transformer_->Transform(cv_img, &(this->transformed_data_)); trans_time += timer.MicroSeconds(); diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index c127d56..f637f2e 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -27,7 +27,7 @@ namespace caffe { template WindowDataLayer::~WindowDataLayer() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template @@ -171,7 +171,9 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape( + batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -179,7 +181,9 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } // data mean has_mean_file_ = this->transform_param_.has_mean_file(); @@ -217,9 +221,9 @@ unsigned int WindowDataLayer::PrefetchRand() { return (*prefetch_rng)(); } -// Thread fetching the data +// This function is called on prefetch thread template -void WindowDataLayer::InternalThreadEntry() { +void WindowDataLayer::load_batch(Batch* batch) { // At each iteration, sample N windows where N*p are foreground (object) // windows and N*(1-p) are background (non-object) windows CPUTimer batch_timer; @@ -227,8 +231,8 @@ void WindowDataLayer::InternalThreadEntry() { double read_time = 0; double trans_time = 0; CPUTimer timer; - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); + Dtype* top_label = batch->label_.mutable_cpu_data(); const Dtype scale = this->layer_param_.window_data_param().scale(); const int batch_size = this->layer_param_.window_data_param().batch_size(); const int context_pad = this->layer_param_.window_data_param().context_pad(); @@ -252,7 +256,7 @@ void WindowDataLayer::InternalThreadEntry() { bool use_square = (crop_mode == "square") ? true : false; // zero out batch - caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); + caffe_set(batch->data_.count(), Dtype(0), top_data); const int num_fg = static_cast(static_cast(batch_size) * fg_fraction); diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 7617ccf..0da7a3b 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -108,6 +108,18 @@ void* SyncedMemory::mutable_gpu_data() { #endif } +#ifndef CPU_ONLY +void SyncedMemory::async_gpu_push(const cudaStream_t& stream) { + CHECK(head_ == HEAD_AT_CPU); + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + } + const cudaMemcpyKind put = cudaMemcpyHostToDevice; + CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream)); + // Assume caller will synchronize on the stream before use + head_ = SYNCED; +} +#endif } // namespace caffe diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp index 73c9564..6ab6ba0 100644 --- a/src/caffe/util/blocking_queue.cpp +++ b/src/caffe/util/blocking_queue.cpp @@ -1,6 +1,7 @@ #include #include +#include "caffe/data_layers.hpp" #include "caffe/util/blocking_queue.hpp" namespace caffe { @@ -83,4 +84,7 @@ size_t BlockingQueue::size() const { return queue_.size(); } +template class BlockingQueue*>; +template class BlockingQueue*>; + } // namespace caffe