Create base data layer and base prefetching data layer
authorKai Li <kaili_kloud@163.com>
Thu, 28 Aug 2014 07:29:25 +0000 (15:29 +0800)
committerKai Li <kaili_kloud@163.com>
Wed, 3 Sep 2014 05:25:21 +0000 (13:25 +0800)
include/caffe/data_layers.hpp
src/caffe/layers/base_data_layer.cpp [new file with mode: 0644]

index f972965..b33d1ee 100644 (file)
@@ -23,8 +23,52 @@ namespace caffe {
 #define HDF5_DATA_DATASET_NAME "data"
 #define HDF5_DATA_LABEL_NAME "label"
 
-// TODO: DataLayer, ImageDataLayer, and WindowDataLayer all have the
-// same basic structure and a lot of duplicated code.
+template <typename Dtype>
+class BaseDataLayer : public Layer<Dtype> {
+ public:
+  explicit BaseDataLayer(const LayerParameter& param)
+      : Layer<Dtype>(param),
+        data_transformer_(param.data_param().transform_param()) {}
+  virtual ~BaseDataLayer() {}
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+
+ protected:
+  DataTransformer<Dtype> data_transformer_;
+  int datum_channels_;
+  int datum_height_;
+  int datum_width_;
+  int datum_size_;
+  Blob<Dtype> data_mean_;
+  Caffe::Phase phase_;
+};
+
+template <typename Dtype>
+class BasePrefetchingDataLayer :
+    public BaseDataLayer<Dtype>, public InternalThread {
+ public:
+  explicit BasePrefetchingDataLayer(const LayerParameter& param)
+      : BaseDataLayer<Dtype>(param) {}
+  virtual ~BasePrefetchingDataLayer() {}
+
+  virtual void CreatePrefetchThread();
+  virtual void JoinPrefetchThread();
+  // The thread's function
+  virtual void InternalThreadEntry() = 0;
+
+ protected:
+  Blob<Dtype> prefetch_data_;
+  Blob<Dtype> prefetch_label_;
+};
+
 template <typename Dtype>
 class DataLayer : public Layer<Dtype>, public InternalThread {
  public:
diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp
new file mode 100644 (file)
index 0000000..2a22c8c
--- /dev/null
@@ -0,0 +1,19 @@
+#include "caffe/data_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
+  this->phase_ = Caffe::phase();
+
+  this->data_transformer_.InitRand();
+
+  CHECK(StartInternalThread()) << "Pthread execution failed";
+}
+
+template <typename Dtype>
+void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
+  CHECK(!WaitForInternalThreadToExit()) << "Pthread joining failed";
+}
+
+}  // namespace caffe