reshape DATA + IMAGE_DATA for inputs of varying dimension
author: Evan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 17 Oct 2014 05:52:30 +0000 (22:52 -0700)
committer: Evan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 17 Feb 2015 04:10:17 +0000 (20:10 -0800)
To feed inputs of varying dimension, the `DATA` and `IMAGE_DATA` layers
reshape their prefetch and top blobs when the batch size is 1 and no
crop (or, for `IMAGE_DATA`, no resize) is configured.

The `BasePrefetchingDataLayer` always reshapes `top[0]` to the shape of
the prefetched data on forward.

src/caffe/layers/base_data_layer.cpp
src/caffe/layers/base_data_layer.cu
src/caffe/layers/data_layer.cpp
src/caffe/layers/image_data_layer.cpp

index eb0aaf8..c3b9bc4 100644 (file)
@@ -61,6 +61,9 @@ void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
   // First, join the thread
   JoinPrefetchThread();
   DLOG(INFO) << "Thread joined";
+  // Reshape to loaded data.
+  top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(),
+      this->prefetch_data_.height(), this->prefetch_data_.width());
   // Copy the data
   caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
              top[0]->mutable_cpu_data());
index 204a16d..775f6c4 100644 (file)
@@ -9,6 +9,9 @@ void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   // First, join the thread
   JoinPrefetchThread();
+  // Reshape to loaded data.
+  top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(),
+      this->prefetch_data_.height(), this->prefetch_data_.width());
   // Copy the data
   caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
       top[0]->mutable_gpu_data());
index 227db20..98ce26e 100644 (file)
@@ -49,7 +49,7 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   int crop_size = this->layer_param_.transform_param().crop_size();
   if (crop_size > 0) {
     top[0]->Reshape(this->layer_param_.data_param().batch_size(),
-                       datum.channels(), crop_size, crop_size);
+        datum.channels(), crop_size, crop_size);
     this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
         datum.channels(), crop_size, crop_size);
     this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size);
@@ -83,13 +83,25 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   CPUTimer timer;
   CHECK(this->prefetch_data_.count());
   CHECK(this->transformed_data_.count());
+
+  // Reshape on single input batches for inputs of varying dimension.
+  const int batch_size = this->layer_param_.data_param().batch_size();
+  const int crop_size = this->layer_param_.transform_param().crop_size();
+  if (batch_size == 1 && crop_size == 0) {
+    Datum datum;
+    datum.ParseFromString(cursor_->value());
+    this->prefetch_data_.Reshape(1, datum.channels(),
+        datum.height(), datum.width());
+    this->transformed_data_.Reshape(1, datum.channels(),
+        datum.height(), datum.width());
+  }
+
   Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
 
   if (this->output_labels_) {
     top_label = this->prefetch_label_.mutable_cpu_data();
   }
-  const int batch_size = this->layer_param_.data_param().batch_size();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     timer.Start();
     // get a blob
index bd4b8a0..e98ff85 100644 (file)
@@ -102,15 +102,27 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
   CPUTimer timer;
   CHECK(this->prefetch_data_.count());
   CHECK(this->transformed_data_.count());
-  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
-  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
   ImageDataParameter image_data_param = this->layer_param_.image_data_param();
   const int batch_size = image_data_param.batch_size();
   const int new_height = image_data_param.new_height();
   const int new_width = image_data_param.new_width();
+  const int crop_size = this->layer_param_.transform_param().crop_size();
   const bool is_color = image_data_param.is_color();
   string root_folder = image_data_param.root_folder();
 
+  // Reshape on single input batches for inputs of varying dimension.
+  if (batch_size == 1 && crop_size == 0 && new_height == 0 && new_width == 0) {
+    cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
+        0, 0, is_color);
+    this->prefetch_data_.Reshape(1, cv_img.channels(),
+        cv_img.rows, cv_img.cols);
+    this->transformed_data_.Reshape(1, cv_img.channels(),
+        cv_img.rows, cv_img.cols);
+  }
+
+  Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data();
+  Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data();
+
   // datum scales
   const int lines_size = lines_.size();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
@@ -124,11 +136,11 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
     timer.Start();
     // Apply transformations (mirror, crop...) to the image
     int offset = this->prefetch_data_.offset(item_id);
-    this->transformed_data_.set_cpu_data(top_data + offset);
+    this->transformed_data_.set_cpu_data(prefetch_data + offset);
     this->data_transformer_.Transform(cv_img, &(this->transformed_data_));
     trans_time += timer.MicroSeconds();
 
-    top_label[item_id] = lines_[lines_id_].second;
+    prefetch_label[item_id] = lines_[lines_id_].second;
     // go to the next iter
     lines_id_++;
     if (lines_id_ >= lines_size) {