Remove duplicate code from the ImageDataLayer
author     Kai Li <kaili_kloud@163.com>  Thu, 28 Aug 2014 08:22:41 +0000 (16:22 +0800)
committer  Kai Li <kaili_kloud@163.com>  Wed, 3 Sep 2014 05:25:21 +0000 (13:25 +0800)
include/caffe/data_layers.hpp
src/caffe/layers/base_data_layer.cpp
src/caffe/layers/data_layer.cpp
src/caffe/layers/image_data_layer.cpp

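This commit folds ImageDataLayer's hand-rolled prefetching into the existing BasePrefetchingDataLayer: the layer loses its private copies of the prefetch thread management, the DataTransformer, the prefetch/mean blobs and datum bookkeeping, and its Forward_cpu/Forward_gpu, inheriting all of them from the base instead. Its LayerSetUp is renamed to DataLayerSetUp, and the join of the prefetch thread moves into the base class destructor. In outline (both declarations appear verbatim in the header diff below):

    // Before: ImageDataLayer reimplements the prefetching machinery itself.
    template <typename Dtype>
    class ImageDataLayer : public Layer<Dtype>, public InternalThread { /* ... */ };

    // After: thread handling, transformer, prefetch blobs, and the Forward
    // implementations are all inherited.
    template <typename Dtype>
    class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> { /* ... */ };
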
diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index bc207c0..499ee7f 100644
--- a/include/caffe/data_layers.hpp
+++ b/include/caffe/data_layers.hpp
@@ -60,7 +60,7 @@ class BasePrefetchingDataLayer :
  public:
   explicit BasePrefetchingDataLayer(const LayerParameter& param)
       : BaseDataLayer<Dtype>(param) {}
-  virtual ~BasePrefetchingDataLayer() {}
+  virtual ~BasePrefetchingDataLayer();
 
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
@@ -202,13 +202,12 @@ class HDF5OutputLayer : public Layer<Dtype> {
 };
 
 template <typename Dtype>
-class ImageDataLayer : public Layer<Dtype>, public InternalThread {
+class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
  public:
   explicit ImageDataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param),
-        data_transformer_(param.image_data_param().transform_param()) {}
-  virtual ~ImageDataLayer();
-  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      : BasePrefetchingDataLayer<Dtype>(param) {}
+  virtual ~ImageDataLayer() {}
+  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
@@ -218,34 +217,12 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {
   virtual inline int ExactNumTopBlobs() const { return 2; }
 
  protected:
-  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
-
+  shared_ptr<Caffe::RNG> prefetch_rng_;
   virtual void ShuffleImages();
-
-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
   virtual void InternalThreadEntry();
 
-  DataTransformer<Dtype> data_transformer_;
-  shared_ptr<Caffe::RNG> prefetch_rng_;
-
   vector<std::pair<std::string, int> > lines_;
   int lines_id_;
-  int datum_channels_;
-  int datum_height_;
-  int datum_width_;
-  int datum_size_;
-  Blob<Dtype> prefetch_data_;
-  Blob<Dtype> prefetch_label_;
-  Blob<Dtype> data_mean_;
-  Caffe::Phase phase_;
 };
 
 /* MemoryDataLayer
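
In the header, ~BasePrefetchingDataLayer() changes from an empty inline body to a declaration only, because it now does real work (joining the prefetch thread) defined in base_data_layer.cpp. The fields deleted from ImageDataLayer above do not disappear; judging by the `this->` qualifications in the .cpp diffs further down, they now live in the base class hierarchy. A sketch of that shared state, inferred from this diff alone (the exact split between BaseDataLayer and BasePrefetchingDataLayer is an assumption):

    template <typename Dtype>
    class BasePrefetchingDataLayer : public BaseDataLayer<Dtype> {
      // ... constructors, Forward_cpu/gpu, thread management ...
     protected:
      DataTransformer<Dtype> data_transformer_;  // shared input transformations
      Blob<Dtype> prefetch_data_;                // filled by InternalThreadEntry()
      Blob<Dtype> prefetch_label_;
      Blob<Dtype> data_mean_;                    // optional mean to subtract
      int datum_channels_, datum_height_, datum_width_, datum_size_;
      Caffe::Phase phase_;                       // snapshot taken at thread start
    };
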
diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp
index f8f8ad0..957a771 100644
--- a/src/caffe/layers/base_data_layer.cpp
+++ b/src/caffe/layers/base_data_layer.cpp
@@ -14,6 +14,11 @@ void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
+BasePrefetchingDataLayer<Dtype>::~BasePrefetchingDataLayer<Dtype>() {
+  JoinPrefetchThread();
+}
+
+template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
   this->phase_ = Caffe::phase();
   this->data_transformer_.InitRand();
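
This destructor is the heart of the dedup: every prefetching layer used to pair CreatePrefetchThread() with a JoinPrefetchThread() call in its own destructor; now the base destructor joins once for all of them. A minimal standalone sketch of the idiom, using std::thread in place of Caffe's boost-based InternalThread (class and member names here are illustrative, not Caffe's):

    #include <thread>
    #include <vector>
    #include <iostream>

    class BasePrefetching {
     public:
      virtual ~BasePrefetching() { Join(); }  // role of ~BasePrefetchingDataLayer
      void Start() { thread_ = std::thread(&BasePrefetching::Entry, this); }
      // Consumers join before reading, as Forward_cpu() does in Caffe.
      float Fetch() {
        Join();
        return prefetch_data_.empty() ? 0.0f : prefetch_data_[0];
      }
     protected:
      virtual void Entry() = 0;           // analogous to InternalThreadEntry()
      void Join() { if (thread_.joinable()) thread_.join(); }
      std::vector<float> prefetch_data_;  // base-owned buffer, like prefetch_data_
     private:
      std::thread thread_;
    };

    class ImageLike : public BasePrefetching {
     protected:
      virtual void Entry() { prefetch_data_.assign(4, 1.0f); }  // fill one "batch"
    };

    int main() {
      ImageLike layer;
      layer.Start();                       // background prefetch
      std::cout << layer.Fetch() << "\n";  // join, then consume the batch
    }  // the destructor's Join() is a no-op safety net here
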
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 0d3054f..1b3acbd 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -78,7 +78,6 @@ void DataLayer<Dtype>::InternalThreadEntry() {
 
 template <typename Dtype>
 DataLayer<Dtype>::~DataLayer<Dtype>() {
-  this->JoinPrefetchThread();
   // clean up the database resources
   switch (this->layer_param_.data_param().backend()) {
   case DataParameter_DB_LEVELDB:
diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp
index 5039ad0..e4805ca 100644
--- a/src/caffe/layers/image_data_layer.cpp
+++ b/src/caffe/layers/image_data_layer.cpp
@@ -16,9 +16,9 @@ namespace caffe {
 template <typename Dtype>
 void ImageDataLayer<Dtype>::InternalThreadEntry() {
   Datum datum;
-  CHECK(prefetch_data_.count());
-  Dtype* top_data = prefetch_data_.mutable_cpu_data();
-  Dtype* top_label = prefetch_label_.mutable_cpu_data();
+  CHECK(this->prefetch_data_.count());
+  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
+  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
   ImageDataParameter image_data_param = this->layer_param_.image_data_param();
   const int batch_size = image_data_param.batch_size();
   const int new_height = image_data_param.new_height();
@@ -26,7 +26,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
 
   // datum scales
   const int lines_size = lines_.size();
-  const Dtype* mean = data_mean_.cpu_data();
+  const Dtype* mean = this->data_mean_.cpu_data();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     // get a blob
     CHECK_GT(lines_size, lines_id_);
@@ -37,7 +37,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
     }
 
     // Apply transformations (mirror, crop...) to the data
-    data_transformer_.Transform(item_id, datum, mean, top_data);
+    this->data_transformer_.Transform(item_id, datum, mean, top_data);
 
     top_label[item_id] = datum.label();
     // go to the next iter
@@ -54,12 +54,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
 }
 
 template <typename Dtype>
-ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {
-  JoinPrefetchThread();
-}
-
-template <typename Dtype>
-void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   const int new_height = this->layer_param_.image_data_param().new_height();
   const int new_width  = this->layer_param_.image_data_param().new_width();
@@ -106,11 +101,11 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       .transform_param().mean_file();
   if (crop_size > 0) {
     (*top)[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
-    prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
+    this->prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
   } else {
     (*top)[0]->Reshape(batch_size, datum.channels(), datum.height(),
                        datum.width());
-    prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(),
+    this->prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(),
         datum.width());
   }
   LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
@@ -118,81 +113,48 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       << (*top)[0]->width();
   // label
   (*top)[1]->Reshape(batch_size, 1, 1, 1);
-  prefetch_label_.Reshape(batch_size, 1, 1, 1);
+  this->prefetch_label_.Reshape(batch_size, 1, 1, 1);
   // datum size
-  datum_channels_ = datum.channels();
-  datum_height_ = datum.height();
-  datum_width_ = datum.width();
-  datum_size_ = datum.channels() * datum.height() * datum.width();
-  CHECK_GT(datum_height_, crop_size);
-  CHECK_GT(datum_width_, crop_size);
+  this->datum_channels_ = datum.channels();
+  this->datum_height_ = datum.height();
+  this->datum_width_ = datum.width();
+  this->datum_size_ = datum.channels() * datum.height() * datum.width();
+  CHECK_GT(this->datum_height_, crop_size);
+  CHECK_GT(this->datum_width_, crop_size);
   // check if we want to have mean
   if (this->layer_param_.image_data_param().transform_param().has_mean_file()) {
     BlobProto blob_proto;
     LOG(INFO) << "Loading mean file from" << mean_file;
     ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);
-    data_mean_.FromProto(blob_proto);
-    CHECK_EQ(data_mean_.num(), 1);
-    CHECK_EQ(data_mean_.channels(), datum_channels_);
-    CHECK_EQ(data_mean_.height(), datum_height_);
-    CHECK_EQ(data_mean_.width(), datum_width_);
+    this->data_mean_.FromProto(blob_proto);
+    CHECK_EQ(this->data_mean_.num(), 1);
+    CHECK_EQ(this->data_mean_.channels(), this->datum_channels_);
+    CHECK_EQ(this->data_mean_.height(), this->datum_height_);
+    CHECK_EQ(this->data_mean_.width(), this->datum_width_);
   } else {
     // Simply initialize an all-empty mean.
-    data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
+    this->data_mean_.Reshape(1, this->datum_channels_, this->datum_height_,
+                             this->datum_width_);
   }
   // Now, start the prefetch thread. Before calling prefetch, we make two
   // cpu_data calls so that the prefetch thread does not accidentally make
   // simultaneous cudaMalloc calls when the main thread is running. In some
   // GPUs this seems to cause failures if we do not so.
-  prefetch_data_.mutable_cpu_data();
-  prefetch_label_.mutable_cpu_data();
-  data_mean_.cpu_data();
+  this->prefetch_data_.mutable_cpu_data();
+  this->prefetch_label_.mutable_cpu_data();
+  this->data_mean_.cpu_data();
   DLOG(INFO) << "Initializing prefetch";
-  CreatePrefetchThread();
+  this->CreatePrefetchThread();
   DLOG(INFO) << "Prefetch initialized.";
 }
 
 template <typename Dtype>
-void ImageDataLayer<Dtype>::CreatePrefetchThread() {
-  phase_ = Caffe::phase();
-
-  data_transformer_.InitRand();
-
-  // Create the thread.
-  CHECK(StartInternalThread()) << "Thread execution failed";
-}
-
-template <typename Dtype>
 void ImageDataLayer<Dtype>::ShuffleImages() {
   caffe::rng_t* prefetch_rng =
       static_cast<caffe::rng_t*>(prefetch_rng_->generator());
   shuffle(lines_.begin(), lines_.end(), prefetch_rng);
 }
 
-
-template <typename Dtype>
-void ImageDataLayer<Dtype>::JoinPrefetchThread() {
-  CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
-}
-
-template <typename Dtype>
-void ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  // First, join the thread
-  JoinPrefetchThread();
-  // Copy the data
-  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
-             (*top)[0]->mutable_cpu_data());
-  caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
-             (*top)[1]->mutable_cpu_data());
-  // Start a new prefetch thread
-  CreatePrefetchThread();
-}
-
-#ifdef CPU_ONLY
-STUB_GPU_FORWARD(ImageDataLayer, Forward);
-#endif
-
 INSTANTIATE_CLASS(ImageDataLayer);
 
 }  // namespace caffe
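
After this commit, the surface a prefetching data layer must implement is small: shape the top blobs and the inherited prefetch blobs in DataLayerSetUp(), and fill this->prefetch_data_/this->prefetch_label_ in InternalThreadEntry(); Forward_cpu/gpu and the thread lifetime come from the base. A skeleton of that contract as the header diff above defines it (the layer name is hypothetical and the method bodies are elided):

    template <typename Dtype>
    class MyListDataLayer : public BasePrefetchingDataLayer<Dtype> {
     public:
      explicit MyListDataLayer(const LayerParameter& param)
          : BasePrefetchingDataLayer<Dtype>(param) {}
      virtual ~MyListDataLayer() {}
      // Reshape (*top) and this->prefetch_* for the configured batch size.
      virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
          vector<Blob<Dtype>*>* top);
     protected:
      // Runs on the prefetch thread; load one batch into the prefetch blobs.
      virtual void InternalThreadEntry();
    };
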