Persistent prefetch thread
authorCyprien Noel <cyprien.noel@gmail.com>
Tue, 19 May 2015 00:45:20 +0000 (17:45 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 9 Aug 2015 22:13:10 +0000 (15:13 -0700)
include/caffe/data_layers.hpp
include/caffe/syncedmem.hpp
src/caffe/internal_thread.cpp
src/caffe/layers/base_data_layer.cpp
src/caffe/layers/base_data_layer.cu
src/caffe/layers/data_layer.cpp
src/caffe/layers/image_data_layer.cpp
src/caffe/layers/window_data_layer.cpp
src/caffe/syncedmem.cpp
src/caffe/util/blocking_queue.cpp

index 3958cb7..f57ab6b 100644 (file)
@@ -15,6 +15,7 @@
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/blocking_queue.hpp"
 #include "caffe/util/db.hpp"
 
 namespace caffe {
@@ -51,11 +52,16 @@ class BaseDataLayer : public Layer<Dtype> {
 };
 
 template <typename Dtype>
+class Batch {
+ public:
+  Blob<Dtype> data_, label_;
+};
+
+template <typename Dtype>
 class BasePrefetchingDataLayer :
     public BaseDataLayer<Dtype>, public InternalThread {
  public:
-  explicit BasePrefetchingDataLayer(const LayerParameter& param)
-      : BaseDataLayer<Dtype>(param) {}
+  explicit BasePrefetchingDataLayer(const LayerParameter& param);
   // LayerSetUp: implements common data layer setup functionality, and calls
   // DataLayerSetUp to do special data layer setup for individual layer types.
   // This method may not be overridden.
@@ -67,14 +73,17 @@ class BasePrefetchingDataLayer :
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
 
-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
-  // The thread's function
-  virtual void InternalThreadEntry() {}
+  // Prefetches batches (asynchronously if to GPU memory)
+  static const int PREFETCH_COUNT = 3;
 
  protected:
-  Blob<Dtype> prefetch_data_;
-  Blob<Dtype> prefetch_label_;
+  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch) = 0;
+
+  Batch<Dtype> prefetch_[PREFETCH_COUNT];
+  BlockingQueue<Batch<Dtype>*> prefetch_free_;
+  BlockingQueue<Batch<Dtype>*> prefetch_full_;
+
   Blob<Dtype> transformed_data_;
 };
 
@@ -93,7 +102,7 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
   virtual inline int MaxTopBlobs() const { return 2; }
 
  protected:
-  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch);
 
   shared_ptr<db::DB> db_;
   shared_ptr<db::Cursor> cursor_;
@@ -235,7 +244,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   shared_ptr<Caffe::RNG> prefetch_rng_;
   virtual void ShuffleImages();
-  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch);
 
   vector<std::pair<std::string, int> > lines_;
   int lines_id_;
@@ -307,7 +316,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {
 
  protected:
   virtual unsigned int PrefetchRand();
-  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch);
 
   shared_ptr<Caffe::RNG> prefetch_rng_;
   vector<std::pair<std::string, vector<int> > > image_database_;
index 1b726de..4d339bf 100644 (file)
@@ -56,6 +56,10 @@ class SyncedMemory {
   SyncedHead head() { return head_; }
   size_t size() { return size_; }
 
+#ifndef CPU_ONLY
+  void async_gpu_push(const cudaStream_t& stream);
+#endif
+
  private:
   void to_cpu();
   void to_gpu();
index d6c2655..b193826 100644 (file)
@@ -19,10 +19,7 @@ bool InternalThread::must_stop() {
 }
 
 void InternalThread::StartInternalThread() {
-  // TODO switch to failing once Caffe prefetch thread is persistent.
-  // Threads should not be started and stopped repeatedly.
-  // CHECK(!is_started());
-  StopInternalThread();
+  CHECK(!is_started()) << "Threads should persist and not be restarted.";
 
   int device = 0;
 #ifndef CPU_ONLY
index facaed7..9288d91 100644 (file)
@@ -1,7 +1,9 @@
+#include <boost/thread.hpp>
 #include <string>
 #include <vector>
 
 #include "caffe/data_layers.hpp"
+#include "caffe/net.hpp"
 #include "caffe/util/io.hpp"
 
 namespace caffe {
@@ -28,55 +30,91 @@ void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
+BasePrefetchingDataLayer<Dtype>::BasePrefetchingDataLayer(
+    const LayerParameter& param)
+    : BaseDataLayer<Dtype>(param),
+      prefetch_free_(), prefetch_full_() {
+  for (int i = 0; i < PREFETCH_COUNT; ++i) {
+    prefetch_free_.push(&prefetch_[i]);
+  }
+}
+
+template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::LayerSetUp(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   BaseDataLayer<Dtype>::LayerSetUp(bottom, top);
-  // Now, start the prefetch thread. Before calling prefetch, we make two
-  // cpu_data calls so that the prefetch thread does not accidentally make
-  // simultaneous cudaMalloc calls when the main thread is running. In some
-  // GPUs this seems to cause failures if we do not so.
-  this->prefetch_data_.mutable_cpu_data();
-  if (this->output_labels_) {
-    this->prefetch_label_.mutable_cpu_data();
+  // Before starting the prefetch thread, we make cpu_data and gpu_data
+  // calls so that the prefetch thread does not accidentally make simultaneous
+  // cudaMalloc calls when the main thread is running. In some GPUs this
+  // seems to cause failures if we do not so.
+  for (int i = 0; i < PREFETCH_COUNT; ++i) {
+    prefetch_[i].data_.mutable_cpu_data();
+    if (this->output_labels_) {
+      prefetch_[i].label_.mutable_cpu_data();
+    }
   }
+#ifndef CPU_ONLY
+  if (Caffe::mode() == Caffe::GPU) {
+    for (int i = 0; i < PREFETCH_COUNT; ++i) {
+      prefetch_[i].data_.mutable_gpu_data();
+      if (this->output_labels_) {
+        prefetch_[i].label_.mutable_gpu_data();
+      }
+    }
+  }
+#endif
   DLOG(INFO) << "Initializing prefetch";
-  this->CreatePrefetchThread();
-  DLOG(INFO) << "Prefetch initialized.";
-}
-
-template <typename Dtype>
-void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
   this->data_transformer_->InitRand();
   StartInternalThread();
+  DLOG(INFO) << "Prefetch initialized.";
 }
 
 template <typename Dtype>
-void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
-  StopInternalThread();
+void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {
+#ifndef CPU_ONLY
+  cudaStream_t stream;
+  if (Caffe::mode() == Caffe::GPU) {
+    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
+  }
+#endif
+
+  try {
+    while (!must_stop()) {
+      Batch<Dtype>* batch = prefetch_free_.pop();
+      load_batch(batch);
+#ifndef CPU_ONLY
+      if (Caffe::mode() == Caffe::GPU) {
+        batch->data_.data().get()->async_gpu_push(stream);
+        cudaStreamSynchronize(stream);
+      }
+#endif
+      prefetch_full_.push(batch);
+    }
+  } catch (boost::thread_interrupted&) {
+    // Interrupted exception is expected on shutdown
+  }
 }
 
 template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
-  // First, join the thread
-  JoinPrefetchThread();
-  DLOG(INFO) << "Thread joined";
+  Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");
   // Reshape to loaded data.
-  top[0]->ReshapeLike(prefetch_data_);
+  top[0]->Reshape(batch->data_.num(), batch->data_.channels(),
+      batch->data_.height(), batch->data_.width());
   // Copy the data
-  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+  caffe_copy(batch->data_.count(), batch->data_.cpu_data(),
              top[0]->mutable_cpu_data());
   DLOG(INFO) << "Prefetch copied";
   if (this->output_labels_) {
     // Reshape to loaded labels.
     top[1]->ReshapeLike(prefetch_label_);
     // Copy the labels.
-    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
-               top[1]->mutable_cpu_data());
+    caffe_copy(batch->label_.count(), batch->label_.cpu_data(),
+        top[1]->mutable_cpu_data());
   }
-  // Start a new prefetch thread
-  DLOG(INFO) << "CreatePrefetchThread";
-  CreatePrefetchThread();
+
+  prefetch_free_.push(batch);
 }
 
 #ifdef CPU_ONLY
index 9335a5b..56439bc 100644 (file)
@@ -7,22 +7,21 @@ namespace caffe {
 template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
-  // First, join the thread
-  JoinPrefetchThread();
+  Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");
   // Reshape to loaded data.
-  top[0]->ReshapeLike(this->prefetch_data_);
+  top[0]->ReshapeLike(batch->data_);
   // Copy the data
-  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+  caffe_copy(batch->data_.count(), batch->data_.gpu_data(),
       top[0]->mutable_gpu_data());
   if (this->output_labels_) {
     // Reshape to loaded labels.
-    top[1]->ReshapeLike(prefetch_label_);
+    top[1]->ReshapeLike(batch->label_);
     // Copy the labels.
-    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
+    caffe_copy(batch->label_.count(), batch->label_.gpu_data(),
         top[1]->mutable_gpu_data());
   }
-  // Start a new prefetch thread
-  CreatePrefetchThread();
+
+  prefetch_free_.push(batch);
 }
 
 INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer);
index 161a75e..22d9f43 100644 (file)
@@ -17,8 +17,8 @@
 namespace caffe {
 
 template <typename Dtype>
-DataLayer<Dtype>::~DataLayer<Dtype>() {
-  this->JoinPrefetchThread();
+DataLayer<Dtype>::~DataLayer() {
+  this->StopInternalThread();
 }
 
 template <typename Dtype>
@@ -54,21 +54,23 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       << top[0]->width();
   // label
   if (this->output_labels_) {
-    vector<int> label_shape(1, this->layer_param_.data_param().batch_size());
+    vector<int> label_shape(1, batch_size);
     top[1]->Reshape(label_shape);
-    this->prefetch_label_.Reshape(label_shape);
+    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+      this->prefetch_[i].label_.Reshape(label_shape);
+    }
   }
 }
 
-// This function is used to create a thread that prefetches the data.
-template <typename Dtype>
-void DataLayer<Dtype>::InternalThreadEntry() {
+// This function is called on prefetch thread
+template<typename Dtype>
+void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   CPUTimer batch_timer;
   batch_timer.Start();
   double read_time = 0;
   double trans_time = 0;
   CPUTimer timer;
-  CHECK(this->prefetch_data_.count());
+  CHECK(batch->data_.count());
   CHECK(this->transformed_data_.count());
 
   // Reshape according to the first datum of each batch
@@ -81,13 +83,13 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   this->transformed_data_.Reshape(top_shape);
   // Reshape prefetch_data according to the batch_size.
   top_shape[0] = batch_size;
-  this->prefetch_data_.Reshape(top_shape);
+  batch->data_.Reshape(top_shape);
 
-  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
+  Dtype* top_data = batch->data_.mutable_cpu_data();
   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
 
   if (this->output_labels_) {
-    top_label = this->prefetch_label_.mutable_cpu_data();
+    top_label = batch->label_.mutable_cpu_data();
   }
   timer.Start();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
@@ -97,7 +99,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     read_time += timer.MicroSeconds();
     timer.Start();
     // Apply data transformations (mirror, scale, crop...)
-    int offset = this->prefetch_data_.offset(item_id);
+    int offset = batch->data_.offset(item_id);
     this->transformed_data_.set_cpu_data(top_data + offset);
     this->data_transformer_->Transform(datum, &(this->transformed_data_));
     // Copy label.
index dcc5334..223ba3a 100644 (file)
@@ -17,7 +17,7 @@ namespace caffe {
 
 template <typename Dtype>
 ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {
-  this->JoinPrefetchThread();
+  this->StopInternalThread();
 }
 
 template <typename Dtype>
@@ -70,8 +70,10 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   const int batch_size = this->layer_param_.image_data_param().batch_size();
   CHECK_GT(batch_size, 0) << "Positive batch size required";
   top_shape[0] = batch_size;
-  this->prefetch_data_.Reshape(top_shape);
-  top[0]->ReshapeLike(this->prefetch_data_);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+    this->prefetch_[i].data_.Reshape(top_shape);
+  }
+  top[0]->Reshape(top_shape);
 
   LOG(INFO) << "output data size: " << top[0]->num() << ","
       << top[0]->channels() << "," << top[0]->height() << ","
@@ -79,7 +81,9 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // label
   vector<int> label_shape(1, batch_size);
   top[1]->Reshape(label_shape);
-  this->prefetch_label_.Reshape(label_shape);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+    this->prefetch_[i].label_.Reshape(label_shape);
+  }
 }
 
 template <typename Dtype>
@@ -89,15 +93,15 @@ void ImageDataLayer<Dtype>::ShuffleImages() {
   shuffle(lines_.begin(), lines_.end(), prefetch_rng);
 }
 
-// This function is used to create a thread that prefetches the data.
+// This function is called on prefetch thread
 template <typename Dtype>
-void ImageDataLayer<Dtype>::InternalThreadEntry() {
+void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   CPUTimer batch_timer;
   batch_timer.Start();
   double read_time = 0;
   double trans_time = 0;
   CPUTimer timer;
-  CHECK(this->prefetch_data_.count());
+  CHECK(batch->data_.count());
   CHECK(this->transformed_data_.count());
   ImageDataParameter image_data_param = this->layer_param_.image_data_param();
   const int batch_size = image_data_param.batch_size();
@@ -114,12 +118,12 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
   // Use data_transformer to infer the expected blob shape from a cv_img.
   vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
   this->transformed_data_.Reshape(top_shape);
-  // Reshape prefetch_data according to the batch_size.
+  // Reshape batch according to the batch_size.
   top_shape[0] = batch_size;
-  this->prefetch_data_.Reshape(top_shape);
+  batch->data_.Reshape(top_shape);
 
-  Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data();
-  Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data();
+  Dtype* prefetch_data = batch->data_.mutable_cpu_data();
+  Dtype* prefetch_label = batch->label_.mutable_cpu_data();
 
   // datum scales
   const int lines_size = lines_.size();
@@ -133,7 +137,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
     read_time += timer.MicroSeconds();
     timer.Start();
     // Apply transformations (mirror, crop...) to the image
-    int offset = this->prefetch_data_.offset(item_id);
+    int offset = batch->data_.offset(item_id);
     this->transformed_data_.set_cpu_data(prefetch_data + offset);
     this->data_transformer_->Transform(cv_img, &(this->transformed_data_));
     trans_time += timer.MicroSeconds();
index c127d56..f637f2e 100644 (file)
@@ -27,7 +27,7 @@ namespace caffe {
 
 template <typename Dtype>
 WindowDataLayer<Dtype>::~WindowDataLayer<Dtype>() {
-  this->JoinPrefetchThread();
+  this->StopInternalThread();
 }
 
 template <typename Dtype>
@@ -171,7 +171,9 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_GT(crop_size, 0);
   const int batch_size = this->layer_param_.window_data_param().batch_size();
   top[0]->Reshape(batch_size, channels, crop_size, crop_size);
-  this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+    this->prefetch_[i].data_.Reshape(
+        batch_size, channels, crop_size, crop_size);
 
   LOG(INFO) << "output data size: " << top[0]->num() << ","
       << top[0]->channels() << "," << top[0]->height() << ","
@@ -179,7 +181,9 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // label
   vector<int> label_shape(1, batch_size);
   top[1]->Reshape(label_shape);
-  this->prefetch_label_.Reshape(label_shape);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+    this->prefetch_[i].label_.Reshape(label_shape);
+  }
 
   // data mean
   has_mean_file_ = this->transform_param_.has_mean_file();
@@ -217,9 +221,9 @@ unsigned int WindowDataLayer<Dtype>::PrefetchRand() {
   return (*prefetch_rng)();
 }
 
-// Thread fetching the data
+// This function is called on prefetch thread
 template <typename Dtype>
-void WindowDataLayer<Dtype>::InternalThreadEntry() {
+void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   // At each iteration, sample N windows where N*p are foreground (object)
   // windows and N*(1-p) are background (non-object) windows
   CPUTimer batch_timer;
@@ -227,8 +231,8 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
   double read_time = 0;
   double trans_time = 0;
   CPUTimer timer;
-  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
-  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
+  Dtype* top_data = batch->data_.mutable_cpu_data();
+  Dtype* top_label = batch->label_.mutable_cpu_data();
   const Dtype scale = this->layer_param_.window_data_param().scale();
   const int batch_size = this->layer_param_.window_data_param().batch_size();
   const int context_pad = this->layer_param_.window_data_param().context_pad();
@@ -252,7 +256,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
   bool use_square = (crop_mode == "square") ? true : false;
 
   // zero out batch
-  caffe_set(this->prefetch_data_.count(), Dtype(0), top_data);
+  caffe_set(batch->data_.count(), Dtype(0), top_data);
 
   const int num_fg = static_cast<int>(static_cast<float>(batch_size)
       * fg_fraction);
index 7617ccf..0da7a3b 100644 (file)
@@ -108,6 +108,18 @@ void* SyncedMemory::mutable_gpu_data() {
 #endif
 }
 
+#ifndef CPU_ONLY
+void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
+  CHECK(head_ == HEAD_AT_CPU);
+  if (gpu_ptr_ == NULL) {
+    CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+  }
+  const cudaMemcpyKind put = cudaMemcpyHostToDevice;
+  CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream));
+  // Assume caller will synchronize on the stream before use
+  head_ = SYNCED;
+}
+#endif
 
 }  // namespace caffe
 
index 73c9564..6ab6ba0 100644 (file)
@@ -1,6 +1,7 @@
 #include <boost/thread.hpp>
 #include <string>
 
+#include "caffe/data_layers.hpp"
 #include "caffe/util/blocking_queue.hpp"
 
 namespace caffe {
@@ -83,4 +84,7 @@ size_t BlockingQueue<T>::size() const {
   return queue_.size();
 }
 
+template class BlockingQueue<Batch<float>*>;
+template class BlockingQueue<Batch<double>*>;
+
 }  // namespace caffe