Extract common data layer functionality out of the DataLayer
author     Kai Li <kaili_kloud@163.com>
           Thu, 28 Aug 2014 08:10:47 +0000 (16:10 +0800)
committer  Kai Li <kaili_kloud@163.com>
           Wed, 3 Sep 2014 05:25:21 +0000 (13:25 +0800)
include/caffe/data_layers.hpp
src/caffe/layers/base_data_layer.cpp
src/caffe/layers/data_layer.cpp

diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index b33d1ee..bc207c0 100644
@@ -30,16 +30,18 @@ class BaseDataLayer : public Layer<Dtype> {
       : Layer<Dtype>(param),
         data_transformer_(param.data_param().transform_param()) {}
   virtual ~BaseDataLayer() {}
+  // LayerSetUp: implements common data layer setup functionality, and calls
+  // DataLayerSetUp to do special data layer setup for individual layer types.
+  // This method may not be overridden.
+  void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {}
 
- protected:
-  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) = 0;
-  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) = 0;
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
  protected:
   DataTransformer<Dtype> data_transformer_;
@@ -49,6 +51,7 @@ class BaseDataLayer : public Layer<Dtype> {
   int datum_size_;
   Blob<Dtype> data_mean_;
   Caffe::Phase phase_;
+  bool output_labels_;
 };
 
 template <typename Dtype>
@@ -59,10 +62,15 @@ class BasePrefetchingDataLayer :
       : BaseDataLayer<Dtype>(param) {}
   virtual ~BasePrefetchingDataLayer() {}
 
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
   virtual void CreatePrefetchThread();
   virtual void JoinPrefetchThread();
   // The thread's function
-  virtual void InternalThreadEntry() = 0;
+  virtual void InternalThreadEntry() {}
 
  protected:
   Blob<Dtype> prefetch_data_;
@@ -70,13 +78,12 @@ class BasePrefetchingDataLayer :
 };
 
 template <typename Dtype>
-class DataLayer : public Layer<Dtype>, public InternalThread {
+class DataLayer : public BasePrefetchingDataLayer<Dtype> {
  public:
   explicit DataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param),
-        data_transformer_(param.data_param().transform_param()) {}
+      : BasePrefetchingDataLayer<Dtype>(param) {}
   virtual ~DataLayer();
-  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
@@ -87,22 +94,8 @@ class DataLayer : public Layer<Dtype>, public InternalThread {
   virtual inline int MaxTopBlobs() const { return 2; }
 
  protected:
-  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
-
-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
-  // The thread's function
   virtual void InternalThreadEntry();
 
-  DataTransformer<Dtype> data_transformer_;
-
   // LEVELDB
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::Iterator> iter_;
@@ -112,16 +105,6 @@ class DataLayer : public Layer<Dtype>, public InternalThread {
   MDB_txn* mdb_txn_;
   MDB_cursor* mdb_cursor_;
   MDB_val mdb_key_, mdb_value_;
-
-  int datum_channels_;
-  int datum_height_;
-  int datum_width_;
-  int datum_size_;
-  Blob<Dtype> prefetch_data_;
-  Blob<Dtype> prefetch_label_;
-  Blob<Dtype> data_mean_;
-  bool output_labels_;
-  Caffe::Phase phase_;
 };
 
 template <typename Dtype>
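
Taken together, the header changes establish a template-method split: BaseDataLayer::LayerSetUp owns the setup common to all data layers (here, deriving output_labels_ from the number of top blobs) and then delegates to the virtual DataLayerSetUp hook, while BasePrefetchingDataLayer contributes the Forward paths and the prefetch-thread plumbing. A minimal sketch of a concrete data layer under the new hierarchy; the class name and member set below are hypothetical, for illustration only:

    // Hypothetical subclass, for illustration only. It inherits Forward_cpu /
    // Forward_gpu, the no-op Backward_cpu / Backward_gpu, and the prefetch
    // thread management, and only fills in the two hooks this commit leaves
    // virtual.
    template <typename Dtype>
    class MyDataLayer : public BasePrefetchingDataLayer<Dtype> {
     public:
      explicit MyDataLayer(const LayerParameter& param)
          : BasePrefetchingDataLayer<Dtype>(param) {}

      // Runs after BaseDataLayer::LayerSetUp has set output_labels_;
      // opens the data source and reshapes (*top) and the prefetch blobs.
      virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
          vector<Blob<Dtype>*>* top);

     protected:
      // Runs on the prefetch thread; fills this->prefetch_data_ (and
      // this->prefetch_label_ when this->output_labels_ is true).
      virtual void InternalThreadEntry();
    };
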
diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp
index 2a22c8c..f8f8ad0 100644
@@ -3,11 +3,20 @@
 namespace caffe {
 
 template <typename Dtype>
+void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  if (top->size() == 1) {
+    output_labels_ = false;
+  } else {
+    output_labels_ = true;
+  }
+  DataLayerSetUp(bottom, top);
+}
+
+template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
   this->phase_ = Caffe::phase();
-
   this->data_transformer_.InitRand();
-
   CHECK(StartInternalThread()) << "Pthread execution failed";
 }
 
@@ -16,4 +25,27 @@ void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
   CHECK(WaitForInternalThreadToExit()) << "Pthread joining failed";
 }
 
+template <typename Dtype>
+void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  // First, join the thread
+  JoinPrefetchThread();
+  // Copy the data
+  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+             (*top)[0]->mutable_cpu_data());
+  if (this->output_labels_) {
+    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
+               (*top)[1]->mutable_cpu_data());
+  }
+  // Start a new prefetch thread
+  CreatePrefetchThread();
+}
+
+#ifdef CPU_ONLY
+STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward);
+#endif
+
+INSTANTIATE_CLASS(BaseDataLayer);
+INSTANTIATE_CLASS(BasePrefetchingDataLayer);
+
 }  // namespace caffe
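
The Forward_cpu moved here is the double-buffering handshake every prefetching data layer now inherits: block until the batch in flight is ready, copy the prefetch blobs into the top blobs, then immediately relaunch the fetch so data loading overlaps with the rest of the forward pass. A self-contained sketch of the same handshake, using std::thread in place of Caffe's pthread-backed InternalThread; all names and types here are illustrative, not Caffe's:

    #include <thread>
    #include <vector>

    // Illustrative stand-in for BasePrefetchingDataLayer's threading pattern.
    class PrefetchSource {
     public:
      // Mirrors the end of DataLayerSetUp(): the first fetch is launched
      // during setup, so Forward() always has a batch in flight to join on.
      void Setup() { CreatePrefetchThread(); }

      // Mirrors Forward_cpu() above: join, copy, restart.
      void Forward(std::vector<float>* top) {
        JoinPrefetchThread();                  // wait for the batch in flight
        top->assign(prefetch_data_.begin(),    // stands in for caffe_copy()
                    prefetch_data_.end());
        CreatePrefetchThread();                // overlap the next batch load
      }

     private:
      void CreatePrefetchThread() {
        thread_ = std::thread(&PrefetchSource::InternalThreadEntry, this);
      }
      void JoinPrefetchThread() { thread_.join(); }
      void InternalThreadEntry() {
        // ... fill prefetch_data_ from the data source ...
      }

      std::vector<float> prefetch_data_;
      std::thread thread_;
    };
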
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 519b666..0d3054f 100644
@@ -17,15 +17,15 @@ namespace caffe {
 template <typename Dtype>
 void DataLayer<Dtype>::InternalThreadEntry() {
   Datum datum;
-  CHECK(prefetch_data_.count());
-  Dtype* top_data = prefetch_data_.mutable_cpu_data();
+  CHECK(this->prefetch_data_.count());
+  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
-  if (output_labels_) {
-    top_label = prefetch_label_.mutable_cpu_data();
+  if (this->output_labels_) {
+    top_label = this->prefetch_label_.mutable_cpu_data();
   }
   const int batch_size = this->layer_param_.data_param().batch_size();
 
-  const Dtype* mean = data_mean_.cpu_data();
+  const Dtype* mean = this->data_mean_.cpu_data();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     // get a blob
     switch (this->layer_param_.data_param().backend()) {
@@ -45,9 +45,9 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     }
 
     // Apply data transformations (mirror, scale, crop...)
-    data_transformer_.Transform(item_id, datum, mean, top_data);
+    this->data_transformer_.Transform(item_id, datum, mean, top_data);
 
-    if (output_labels_) {
+    if (this->output_labels_) {
       top_label[item_id] = datum.label();
     }
 
@@ -78,7 +78,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
 
 template <typename Dtype>
 DataLayer<Dtype>::~DataLayer<Dtype>() {
-  JoinPrefetchThread();
+  this->JoinPrefetchThread();
   // clean up the database resources
   switch (this->layer_param_.data_param().backend()) {
   case DataParameter_DB_LEVELDB:
@@ -95,13 +95,8 @@ DataLayer<Dtype>::~DataLayer<Dtype>() {
 }
 
 template <typename Dtype>
-void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-  if (top->size() == 1) {
-    output_labels_ = false;
-  } else {
-    output_labels_ = true;
-  }
   // Initialize DB
   switch (this->layer_param_.data_param().backend()) {
   case DataParameter_DB_LEVELDB:
@@ -183,31 +178,31 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   if (crop_size > 0) {
     (*top)[0]->Reshape(this->layer_param_.data_param().batch_size(),
                        datum.channels(), crop_size, crop_size);
-    prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
+    this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
         datum.channels(), crop_size, crop_size);
   } else {
     (*top)[0]->Reshape(
         this->layer_param_.data_param().batch_size(), datum.channels(),
         datum.height(), datum.width());
-    prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
+    this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
         datum.channels(), datum.height(), datum.width());
   }
   LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
       << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
       << (*top)[0]->width();
   // label
-  if (output_labels_) {
+  if (this->output_labels_) {
     (*top)[1]->Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1);
-    prefetch_label_.Reshape(this->layer_param_.data_param().batch_size(),
+    this->prefetch_label_.Reshape(this->layer_param_.data_param().batch_size(),
         1, 1, 1);
   }
   // datum size
-  datum_channels_ = datum.channels();
-  datum_height_ = datum.height();
-  datum_width_ = datum.width();
-  datum_size_ = datum.channels() * datum.height() * datum.width();
-  CHECK_GT(datum_height_, crop_size);
-  CHECK_GT(datum_width_, crop_size);
+  this->datum_channels_ = datum.channels();
+  this->datum_height_ = datum.height();
+  this->datum_width_ = datum.width();
+  this->datum_size_ = datum.channels() * datum.height() * datum.width();
+  CHECK_GT(this->datum_height_, crop_size);
+  CHECK_GT(this->datum_width_, crop_size);
   // check if we want to have mean
   if (this->layer_param_.data_param().transform_param().has_mean_file()) {
     const string& mean_file =
@@ -215,63 +210,30 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     LOG(INFO) << "Loading mean file from" << mean_file;
     BlobProto blob_proto;
     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
-    data_mean_.FromProto(blob_proto);
-    CHECK_EQ(data_mean_.num(), 1);
-    CHECK_EQ(data_mean_.channels(), datum_channels_);
-    CHECK_EQ(data_mean_.height(), datum_height_);
-    CHECK_EQ(data_mean_.width(), datum_width_);
+    this->data_mean_.FromProto(blob_proto);
+    CHECK_EQ(this->data_mean_.num(), 1);
+    CHECK_EQ(this->data_mean_.channels(), this->datum_channels_);
+    CHECK_EQ(this->data_mean_.height(), this->datum_height_);
+    CHECK_EQ(this->data_mean_.width(), this->datum_width_);
   } else {
     // Simply initialize an all-empty mean.
-    data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
+    this->data_mean_.Reshape(1, this->datum_channels_, this->datum_height_,
+                             this->datum_width_);
   }
   // Now, start the prefetch thread. Before calling prefetch, we make two
   // cpu_data calls so that the prefetch thread does not accidentally make
   // simultaneous cudaMalloc calls when the main thread is running. In some
  // GPUs this seems to cause failures if we do not do so.
-  prefetch_data_.mutable_cpu_data();
-  if (output_labels_) {
-    prefetch_label_.mutable_cpu_data();
+  this->prefetch_data_.mutable_cpu_data();
+  if (this->output_labels_) {
+    this->prefetch_label_.mutable_cpu_data();
   }
-  data_mean_.cpu_data();
+  this->data_mean_.cpu_data();
   DLOG(INFO) << "Initializing prefetch";
-  CreatePrefetchThread();
+  this->CreatePrefetchThread();
   DLOG(INFO) << "Prefetch initialized.";
 }
 
-template <typename Dtype>
-void DataLayer<Dtype>::CreatePrefetchThread() {
-  phase_ = Caffe::phase();
-
-  data_transformer_.InitRand();
-
-  CHECK(StartInternalThread()) << "Thread execution failed";
-}
-
-template <typename Dtype>
-void DataLayer<Dtype>::JoinPrefetchThread() {
-  CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
-}
-
-template <typename Dtype>
-void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  // First, join the thread
-  JoinPrefetchThread();
-  // Copy the data
-  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
-             (*top)[0]->mutable_cpu_data());
-  if (output_labels_) {
-    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
-               (*top)[1]->mutable_cpu_data());
-  }
-  // Start a new prefetch thread
-  CreatePrefetchThread();
-}
-
-#ifdef CPU_ONLY
-STUB_GPU_FORWARD(DataLayer, Forward);
-#endif
-
 INSTANTIATE_CLASS(DataLayer);
 
 }  // namespace caffe
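
A closing note on the macro boilerplate this commit shuffles: the STUB_GPU_FORWARD(DataLayer, Forward) block is deleted above because the CPU-only stub now lives with BasePrefetchingDataLayer, while each file still instantiates its own class with INSTANTIATE_CLASS. Roughly, and hedged as a paraphrase of the macros' documented behavior rather than their literal definitions in Caffe's headers, the two expand to:

    // Approximate expansions, shown for BasePrefetchingDataLayer; treat the
    // exact text as an assumption, not Caffe's literal macro definitions.
    template class BasePrefetchingDataLayer<float>;    // INSTANTIATE_CLASS
    template class BasePrefetchingDataLayer<double>;

    template <typename Dtype>              // STUB_GPU_FORWARD(klass, Forward)
    void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
        const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
      LOG(FATAL) << "Cannot use GPU in CPU-only build: check mode.";
    }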