MemoryDataLayer now accepts dynamic batch_size
authormanuele <manuele.tamburrano@gmail.com>
Mon, 10 Nov 2014 18:24:02 +0000 (19:24 +0100)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 7 Feb 2015 04:55:30 +0000 (20:55 -0800)
include/caffe/data_layers.hpp
src/caffe/layers/memory_data_layer.cpp

index b5449b3..117b1b5 100644 (file)
@@ -284,6 +284,7 @@ class MemoryDataLayer : public BaseDataLayer<Dtype> {
   Blob<Dtype> added_data_;
   Blob<Dtype> added_label_;
   bool has_new_data_;
+  bool needs_reshape_;
 };
 
 /**
index 67990f3..f36eb6f 100644 (file)
@@ -32,7 +32,15 @@ void MemoryDataLayer<Dtype>::AddDatumVector(const vector<Datum>& datum_vector) {
   CHECK(!has_new_data_) <<
       "Can't add Datum when earlier ones haven't been consumed"
       << " by the upper layers";
+
   size_t num = datum_vector.size();
+  if (batch_size_ != num) {
+    needs_reshape_ = true;
+    batch_size_ = num;
+    added_data_.Reshape(batch_size_, channels_, height_, width_);
+    added_label_.Reshape(batch_size_, 1, 1, 1);
+  }
+
   CHECK_GT(num, 0) << "There is no datum to add";
   CHECK_LE(num, batch_size_) <<
       "The number of added datum must be no greater than the batch size";
@@ -53,10 +61,22 @@ void MemoryDataLayer<Dtype>::AddDatumVector(const vector<Datum>& datum_vector) {
 template <typename Dtype>
 void MemoryDataLayer<Dtype>::AddMatVector(const vector<cv::Mat>& mat_vector,
     const vector<int>& labels) {
+
   CHECK(!has_new_data_) <<
       "Can't add Mat when earlier ones haven't been consumed"
       << " by the upper layers";
+
+  CHECK_EQ(mat_vector.size(), labels.size()) <<
+      "vector of labels and vector of mats need to be of the same size";
+
   size_t num = mat_vector.size();
+  if (batch_size_ != num) {
+    needs_reshape_ = true;
+    batch_size_ = num;
+    added_data_.Reshape(batch_size_, channels_, height_, width_);
+    added_label_.Reshape(batch_size_, 1, 1, 1);
+  }
+
   CHECK_GT(num, 0) << "There is no mat to add";
   CHECK_LE(num, batch_size_) <<
       "The number of added mat must be no greater than the batch size";
@@ -89,10 +109,15 @@ template <typename Dtype>
 void MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";
+  if (needs_reshape_) {
+    top[0]->Reshape(batch_size_, channels_, height_, width_);
+    top[1]->Reshape(batch_size_, 1, 1, 1);
+  }
   top[0]->set_cpu_data(data_ + pos_ * size_);
   top[1]->set_cpu_data(labels_ + pos_);
   pos_ = (pos_ + batch_size_) % n_;
   has_new_data_ = false;
+  needs_reshape_ = false;
 }
 
 INSTANTIATE_CLASS(MemoryDataLayer);