datalayer random cropping, not tested
authorYangqing Jia <jiayq84@gmail.com>
Tue, 1 Oct 2013 23:50:44 +0000 (16:50 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Tue, 1 Oct 2013 23:50:44 +0000 (16:50 -0700)
src/caffe/layers/data_layer.cpp
src/caffe/proto/caffe.proto
src/caffe/vision_layers.hpp

index d42a810..fdf9807 100644 (file)
@@ -34,13 +34,24 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   Datum datum;
   datum.ParseFromString(iter_->value().ToString());
   // image
-  (*top)[0]->Reshape(
-      this->layer_param_.batchsize(), datum.channels(), datum.height(),
-      datum.width());
+  int cropsize = this->layer_param_.cropsize();
+  if (cropsize > 0) {
+    (*top)[0]->Reshape(
+        this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
+  } else {
+    (*top)[0]->Reshape(
+        this->layer_param_.batchsize(), datum.channels(), datum.height(),
+        datum.width());
+  }
   // label
   (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
   // datum size
-    datum_size_ = datum.channels() * datum.height() * datum.width();
+  datum_channels_ = datum.channels();
+  datum_height_ = datum.height();
+  datum_width_ = datum.width();
+  datum_size_ = datum.channels() * datum.height() * datum.width();
+  CHECK_GT(datum_height_, cropsize);
+  CHECK_GT(datum_width_, cropsize);
 }
 
 template <typename Dtype>
@@ -51,24 +62,38 @@ void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* top_label = (*top)[1]->mutable_cpu_data();
   const Dtype scale = this->layer_param_.scale();
   const Dtype subtraction = this->layer_param_.subtraction();
-  // LOG(ERROR) << "Debug code on";
-  // if (true) {
-  //   iter_->SeekToFirst();
-  // }
+  int cropsize = this->layer_param_.cropsize();
   for (int i = 0; i < this->layer_param_.batchsize(); ++i) {
     // get a blob
     datum.ParseFromString(iter_->value().ToString());
     const string& data = datum.data();
-    // we will prefer to use data() first, and then try float_data()
-    if (data.size()) {
-      for (int j = 0; j < datum_size_; ++j) {
-        top_data[i * datum_size_ + j] =
-            (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+    if (cropsize) {
+      CHECK(data.size()) << "Image cropping only support uint8 data";
+      int h_offset = rand() % (datum_height_ - cropsize);
+      int w_offset = rand() % (datum_width_ - cropsize);
+      for (int c = 0; c < datum_channels_; ++i) {
+        for (int h = 0; h < cropsize; ++h) {
+          for (int w = 0; w < cropsize; ++w) {
+            top_data[((i * datum_channels_ + c) * cropsize + h) * cropsize + w] =
+                static_cast<Dtype>((uint8_t)data[
+                    (c * datum_height_ + h + h_offset) * datum_width_
+                    + w + w_offset]
+                ) * scale - subtraction;
+          }
+        }
       }
     } else {
-      for (int j = 0; j < datum_size_; ++j) {
-        top_data[i * datum_size_ + j] =
-            (datum.float_data(j) * scale) - subtraction;
+      // we will prefer to use data() first, and then try float_data()
+      if (data.size()) {
+        for (int j = 0; j < datum_size_; ++j) {
+          top_data[i * datum_size_ + j] =
+              (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+        }
+      } else {
+        for (int j = 0; j < datum_size_; ++j) {
+          top_data[i * datum_size_ + j] =
+              (datum.float_data(j) * scale) - subtraction;
+        }
       }
     }
     top_label[i] = datum.label();
index c3632c1..eef6058 100644 (file)
@@ -63,8 +63,10 @@ message LayerParameter {
   // For data pre-processing, we can do simple scaling and constant subtraction
   optional float scale = 17 [ default = 1 ];
   optional float subtraction = 18 [ default = 0 ];
-  // For datay layers, specify the batch size.
+  // For data layers, specify the batch size.
   optional uint32 batchsize = 19;
+  // For data layers, specify if we would like to randomly crop an image.
+  optional uint32 cropsize = 20 [default = 0];
 
   // The blobs containing the numeric parameters of the layer
   repeated BlobProto blobs = 50;
index 74c5978..23678bf 100644 (file)
@@ -252,6 +252,9 @@ class DataLayer : public Layer<Dtype> {
 
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::Iterator> iter_;
+  int datum_channels_;
+  int datum_height_;
+  int datum_width_;
   int datum_size_;
 };