Add transformer to the memory data layer
author Kai Li <kaili_kloud@163.com>
Thu, 28 Aug 2014 19:02:10 +0000 (03:02 +0800)
committer Kai Li <kaili_kloud@163.com>
Wed, 3 Sep 2014 05:25:22 +0000 (13:25 +0800)
include/caffe/data_layers.hpp
src/caffe/layers/memory_data_layer.cpp
src/caffe/proto/caffe.proto

index 4f599e4..952fcb7 100644 (file)
@@ -9,6 +9,7 @@
 #include "hdf5.h"
 #include "leveldb/db.h"
 #include "lmdb.h"
+#include <opencv2/opencv.hpp>
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
@@ -43,6 +44,11 @@ class BaseDataLayer : public Layer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
+  int datum_channels() const { return datum_channels_; }
+  int datum_height() const { return datum_height_; }
+  int datum_width() const { return datum_width_; }
+  int datum_size() const { return datum_size_; }
+
  protected:
   DataTransformer<Dtype> data_transformer_;
   int datum_channels_;
@@ -228,10 +234,10 @@ class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
 /* MemoryDataLayer
 */
 template <typename Dtype>
-class MemoryDataLayer : public Layer<Dtype> {
+class MemoryDataLayer : public BaseDataLayer<Dtype> {
  public:
   explicit MemoryDataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
+      : BaseDataLayer<Dtype>(param), is_data_set_up_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
@@ -244,28 +250,22 @@ class MemoryDataLayer : public Layer<Dtype> {
   // Reset should accept const pointers, but can't, because the memory
   //  will be given to Blob, which is mutable
   void Reset(Dtype* data, Dtype* label, int n);
-  int datum_channels() { return datum_channels_; }
-  int datum_height() { return datum_height_; }
-  int datum_width() { return datum_width_; }
+
+  virtual void AddImagesAndLabels(const vector<cv::Mat>& images,
+                                  const vector<int>& labels);
+
   int batch_size() { return batch_size_; }
 
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
-  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
   Dtype* data_;
   Dtype* labels_;
-  int datum_channels_;
-  int datum_height_;
-  int datum_width_;
-  int datum_size_;
   int batch_size_;
   int n_;
   int pos_;
+  bool is_data_set_up_;
 };
 
 template <typename Dtype>
index fda9297..d482df5 100644 (file)
@@ -1,7 +1,7 @@
 #include <vector>
 
+#include "caffe/data_layers.hpp"
 #include "caffe/layer.hpp"
-#include "caffe/vision_layers.hpp"
 
 namespace caffe {
 
@@ -9,13 +9,16 @@ template <typename Dtype>
 void MemoryDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
   batch_size_ = this->layer_param_.memory_data_param().batch_size();
-  datum_channels_ = this->layer_param_.memory_data_param().channels();
-  datum_height_ = this->layer_param_.memory_data_param().height();
-  datum_width_ = this->layer_param_.memory_data_param().width();
-  datum_size_ = datum_channels_ * datum_height_ * datum_width_;
-  CHECK_GT(batch_size_ * datum_size_, 0) << "batch_size, channels, height,"
-    " and width must be specified and positive in memory_data_param";
-  (*top)[0]->Reshape(batch_size_, datum_channels_, datum_height_, datum_width_);
+  this->datum_channels_ = this->layer_param_.memory_data_param().channels();
+  this->datum_height_ = this->layer_param_.memory_data_param().height();
+  this->datum_width_ = this->layer_param_.memory_data_param().width();
+  this->datum_size_ = this->datum_channels_ * this->datum_height_ *
+      this->datum_width_;
+  CHECK_GT(batch_size_ * this->datum_size_, 0) <<
+      "batch_size, channels, height, and width must be specified and"
+      " positive in memory_data_param";
+  (*top)[0]->Reshape(batch_size_, this->datum_channels_, this->datum_height_,
+                     this->datum_width_);
   (*top)[1]->Reshape(batch_size_, 1, 1, 1);
   data_ = NULL;
   labels_ = NULL;
@@ -36,11 +39,37 @@ template <typename Dtype>
 void MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";
-  (*top)[0]->set_cpu_data(data_ + pos_ * datum_size_);
+  (*top)[0]->set_cpu_data(data_ + pos_ * this->datum_size_);
   (*top)[1]->set_cpu_data(labels_ + pos_);
   pos_ = (pos_ + batch_size_) % n_;
 }
 
+template <typename Dtype>
+void MemoryDataLayer<Dtype>::AddImagesAndLabels(
+    const vector<cv::Mat>& images, const vector<int>& labels) {
+  size_t num_images = images.size();
+  CHECK_GT(num_images, 0) << "There is no image to add";
+  CHECK_LE(num_images, batch_size_)<<
+      "The number of added images " << images.size() <<
+      " must be no greater than the batch size " << batch_size;
+  CHECK_LE(num_images, labels.size()) <<
+      "The number of images " << images.size() <<
+      " must be no greater than the number of labels " << labels.size();
+
+  Datum datum;
+  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
+  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
+  // Apply data transformations (mirror, scale, crop...)
+  this->data_transformer_.Transform(item_id, datum, mean, top_data);
+  int image_id;
+  for (int item_id = 0; item_id < batch_size; ++item_id) {
+    image_id = item_id % num_images;
+    OpenCVImageToDatum(images[image_id], labels[image_id], new_height,
+                       new_width, &datum);
+    this->data_transformer_.Transform(item_id, datum, mean, top_data);
+  }
+}
+
 INSTANTIATE_CLASS(MemoryDataLayer);
 
 }  // namespace caffe
index ff18779..5acc776 100644 (file)
@@ -529,6 +529,7 @@ message MemoryDataParameter {
   optional uint32 channels = 2;
   optional uint32 height = 3;
   optional uint32 width = 4;
+  optional TransformationParameter transform_param = 5;
 }
 
 // Message that stores parameters used by MVNLayer