Simplify image_data_layer reshapes by letting data_transformer do the job.
authorSergio Guadarrama <sguada@google.com>
Wed, 8 Apr 2015 19:44:48 +0000 (12:44 -0700)
committerSergio Guadarrama <sguada@google.com>
Thu, 9 Apr 2015 00:32:12 +0000 (17:32 -0700)
src/caffe/layers/image_data_layer.cpp

index 38ebbd5..18c035c 100644 (file)
@@ -62,21 +62,15 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // Read an image, and use it to initialize the top blob.
   cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
                                     new_height, new_width, is_color);
-  const int channels = cv_img.channels();
-  const int height = cv_img.rows;
-  const int width = cv_img.cols;
-  // image
-  const int crop_size = this->layer_param_.transform_param().crop_size();
+  // Use data_transformer to infer the expected blob shape from a cv_image.
+  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
+  this->transformed_data_.Reshape(top_shape);
+  // Reshape prefetch_data and top[0] according to the batch_size.
   const int batch_size = this->layer_param_.image_data_param().batch_size();
-  if (crop_size > 0) {
-    top[0]->Reshape(batch_size, channels, crop_size, crop_size);
-    this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
-    this->transformed_data_.Reshape(1, channels, crop_size, crop_size);
-  } else {
-    top[0]->Reshape(batch_size, channels, height, width);
-    this->prefetch_data_.Reshape(batch_size, channels, height, width);
-    this->transformed_data_.Reshape(1, channels, height, width);
-  }
+  top_shape[0] = batch_size;
+  this->prefetch_data_.Reshape(top_shape);
+  top[0]->ReshapeLike(this->prefetch_data_);
+
   LOG(INFO) << "output data size: " << top[0]->num() << ","
       << top[0]->channels() << "," << top[0]->height() << ","
       << top[0]->width();
@@ -107,19 +101,19 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
   const int batch_size = image_data_param.batch_size();
   const int new_height = image_data_param.new_height();
   const int new_width = image_data_param.new_width();
-  const int crop_size = this->layer_param_.transform_param().crop_size();
   const bool is_color = image_data_param.is_color();
   string root_folder = image_data_param.root_folder();
 
-  // Reshape on single input batches for inputs of varying dimension.
-  if (batch_size == 1 && crop_size == 0 && new_height == 0 && new_width == 0) {
-    cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
-        0, 0, is_color);
-    this->prefetch_data_.Reshape(1, cv_img.channels(),
-        cv_img.rows, cv_img.cols);
-    this->transformed_data_.Reshape(1, cv_img.channels(),
-        cv_img.rows, cv_img.cols);
-  }
+  // Reshape according to the first image of each batch
+  // on single input batches allows for inputs of varying dimension.
+  cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
+      new_height, new_width, is_color);
+  // Use data_transformer to infer the expected blob shape from a cv_img.
+  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
+  this->transformed_data_.Reshape(top_shape);
+  // Reshape prefetch_data according to the batch_size.
+  top_shape[0] = batch_size;
+  this->prefetch_data_.Reshape(top_shape);
 
   Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data();
   Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data();