#define HDF5_DATA_DATASET_NAME "data"
#define HDF5_DATA_LABEL_NAME "label"
-// TODO: DataLayer, ImageDataLayer, and WindowDataLayer all have the
-// same basic structure and a lot of duplicated code.
+template <typename Dtype>
+class BaseDataLayer : public Layer<Dtype> {
+ public:
+ explicit BaseDataLayer(const LayerParameter& param)
+ : Layer<Dtype>(param),
+ data_transformer_(param.data_param().transform_param()) {}
+ virtual ~BaseDataLayer() {}
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) = 0;
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) = 0;
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+
+ protected:
+ DataTransformer<Dtype> data_transformer_;
+ int datum_channels_;
+ int datum_height_;
+ int datum_width_;
+ int datum_size_;
+ Blob<Dtype> data_mean_;
+ Caffe::Phase phase_;
+};
+
+template <typename Dtype>
+class BasePrefetchingDataLayer :
+ public BaseDataLayer<Dtype>, public InternalThread {
+ public:
+ explicit BasePrefetchingDataLayer(const LayerParameter& param)
+ : BaseDataLayer<Dtype>(param) {}
+ virtual ~BasePrefetchingDataLayer() {}
+
+ virtual void CreatePrefetchThread();
+ virtual void JoinPrefetchThread();
+ // The thread's function
+ virtual void InternalThreadEntry() = 0;
+
+ protected:
+ Blob<Dtype> prefetch_data_;
+ Blob<Dtype> prefetch_label_;
+};
+
template <typename Dtype>
class DataLayer : public Layer<Dtype>, public InternalThread {
public:
--- /dev/null
+#include "caffe/data_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
+ this->phase_ = Caffe::phase();
+
+ this->data_transformer_.InitRand();
+
+ CHECK(StartInternalThread()) << "Pthread execution failed";
+}
+
+template <typename Dtype>
+void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
+ CHECK(!WaitForInternalThreadToExit()) << "Pthread joining failed";
+}
+
+} // namespace caffe