Add and transform Datum vector in the MemeoryDataLayer
authorKai Li <kaili_kloud@163.com>
Tue, 2 Sep 2014 06:43:58 +0000 (14:43 +0800)
committerKai Li <kaili_kloud@163.com>
Wed, 3 Sep 2014 05:25:22 +0000 (13:25 +0800)
include/caffe/data_layers.hpp
src/caffe/layers/memory_data_layer.cpp
src/caffe/test/test_memory_data_layer.cpp

index 0d435dd..e6a1410 100644 (file)
@@ -251,6 +251,8 @@ class MemoryDataLayer : public BaseDataLayer<Dtype> {
   virtual inline int ExactNumBottomBlobs() const { return 0; }
   virtual inline int ExactNumTopBlobs() const { return 2; }
 
+  virtual void AddDatumVector(const vector<Datum>& datum_vector);
+
   // Reset should accept const pointers, but can't, because the memory
   //  will be given to Blob, which is mutable
   void Reset(Dtype* data, Dtype* label, int n);
index 4a77806..03b3e37 100644 (file)
@@ -31,6 +31,29 @@ void MemoryDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
+void MemoryDataLayer<Dtype>::AddDatumVector(const vector<Datum>& datum_vector) {
+  CHECK(!has_new_data_) <<
+      "Can't add Datum when earlier ones haven't been consumed"
+      << " by the upper layers";
+  size_t num = datum_vector.size();
+  CHECK_GT(num, 0) << "There is no datum to add";
+  CHECK_EQ(num, batch_size_) <<
+      "The number of added datum must be equal to the batch size";
+
+  Dtype* top_data = added_data_.mutable_cpu_data();
+  Dtype* top_label = added_label_.mutable_cpu_data();
+  for (int batch_item_id = 0; batch_item_id < num; ++batch_item_id) {
+    // Apply data transformations (mirror, scale, crop...)
+    this->data_transformer_.Transform(
+        batch_item_id, datum_vector[batch_item_id], this->mean_, top_data);
+    top_label[batch_item_id] = datum_vector[batch_item_id].label();
+  }
+  // num_images == batch_size_
+  Reset(top_data, top_label, batch_size_);
+  has_new_data_ = true;
+}
+
+template <typename Dtype>
 void MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n) {
   CHECK(data);
   CHECK(labels);
index 748a6a9..3dc0034 100644 (file)
@@ -1,3 +1,4 @@
+#include <string>
 #include <vector>
 
 #include "caffe/data_layers.hpp"
@@ -110,4 +111,57 @@ TYPED_TEST(MemoryDataLayerTest, TestForward) {
   }
 }
 
+TYPED_TEST(MemoryDataLayerTest, AddDatumVectorDefaultTransform) {
+  typedef typename TypeParam::Dtype Dtype;
+
+  LayerParameter param;
+  MemoryDataParameter* memory_data_param = param.mutable_memory_data_param();
+  memory_data_param->set_batch_size(this->batch_size_);
+  memory_data_param->set_channels(this->channels_);
+  memory_data_param->set_height(this->height_);
+  memory_data_param->set_width(this->width_);
+  MemoryDataLayer<Dtype> layer(param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+
+  vector<Datum> datum_vector(this->batch_size_);
+  const size_t count = this->channels_ * this->height_ * this->width_;
+  size_t pixel_index = 0;
+  for (int i = 0; i < this->batch_size_; ++i) {
+    LOG(ERROR) << "i " << i;
+    datum_vector[i].set_channels(this->channels_);
+    datum_vector[i].set_height(this->height_);
+    datum_vector[i].set_width(this->width_);
+    datum_vector[i].set_label(i);
+    vector<char> pixels(count);
+    for (int j = 0; j < count; ++j) {
+      pixels[j] = pixel_index++ % 256;
+    }
+    datum_vector[i].set_data(&(pixels[0]), count);
+  }
+
+  layer.AddDatumVector(datum_vector);
+
+  int data_index;
+  // Go through the data 5 times
+  for (int iter = 0; iter < 5; ++iter) {
+    layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+    const Dtype* data = this->data_blob_->cpu_data();
+    size_t index = 0;
+    for (int i = 0; i < this->batch_size_; ++i) {
+      const string& data_string = datum_vector[i].data();
+      EXPECT_EQ(i, this->label_blob_->cpu_data()[i]);
+      for (int c = 0; c < this->channels_; ++c) {
+        for (int h = 0; h < this->height_; ++h) {
+          for (int w = 0; w < this->width_; ++w) {
+            data_index = (c * this->height_ + h) * this->width_ + w;
+            EXPECT_EQ(static_cast<Dtype>(
+                static_cast<uint8_t>(data_string[data_index])),
+                      data[index++]);
+          }
+        }
+      }
+    }
+  }
+}
+
 }  // namespace caffe