Added cache_images to WindowDataLayer
authorSergio <sguada@gmail.com>
Wed, 8 Oct 2014 00:19:15 +0000 (17:19 -0700)
committerSergio <sguada@gmail.com>
Thu, 16 Oct 2014 00:03:07 +0000 (17:03 -0700)
Added root_folder to WindowDataLayer to locate images

include/caffe/data_layers.hpp
src/caffe/layers/window_data_layer.cpp
src/caffe/proto/caffe.proto

index c4903ce..34b9b30 100644 (file)
@@ -328,6 +328,8 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {
   vector<Dtype> mean_values_;
   bool has_mean_file_;
   bool has_mean_values_;
+  bool cache_images_;
+  vector<std::pair<std::string, Datum > > image_database_cache_;
 };
 
 }  // namespace caffe
index fc0ffc8..8f75557 100644 (file)
@@ -59,7 +59,14 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       << "  background (non-object) overlap threshold: "
       << this->layer_param_.window_data_param().bg_threshold() << std::endl
       << "  foreground sampling fraction: "
-      << this->layer_param_.window_data_param().fg_fraction();
+      << this->layer_param_.window_data_param().fg_fraction() << std::endl
+      << "  cache_images: "
+      << this->layer_param_.window_data_param().cache_images() << std::endl
+      << "  root_folder: "
+      << this->layer_param_.window_data_param().root_folder();
+
+  cache_images_ = this->layer_param_.window_data_param().cache_images();
+  string root_folder = this->layer_param_.window_data_param().root_folder();
 
   const bool prefetch_needs_rand =
       this->transform_param_.mirror() ||
@@ -88,12 +95,21 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
     // read image path
     string image_path;
     infile >> image_path;
+    image_path = root_folder + image_path;
     // read image dimensions
     vector<int> image_size(3);
     infile >> image_size[0] >> image_size[1] >> image_size[2];
     channels = image_size[0];
     image_database_.push_back(std::make_pair(image_path, image_size));
 
+    if (cache_images_) {
+      Datum datum;
+      if (!ReadFileToDatum(image_path, &datum)) {
+        LOG(ERROR) << "Could not open or find file " << image_path;
+        return;
+      }
+      image_database_cache_.push_back(std::make_pair(image_path, datum));
+    }
     // read each box
     int num_windows;
     infile >> num_windows;
@@ -227,7 +243,9 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
   const float fg_fraction =
       this->layer_param_.window_data_param().fg_fraction();
   Dtype* mean = NULL;
-  int mean_off, mean_width, mean_height;
+  int mean_off = 0;
+  int mean_width = 0;
+  int mean_height = 0;
   if (this->has_mean_file_) {
     mean = this->data_mean_.mutable_cpu_data();
     mean_off = (this->data_mean_.width() - crop_size) / 2;
@@ -265,10 +283,17 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
       pair<std::string, vector<int> > image =
           image_database_[window[WindowDataLayer<Dtype>::IMAGE_INDEX]];
 
-      cv::Mat cv_img = cv::imread(image.first, CV_LOAD_IMAGE_COLOR);
-      if (!cv_img.data) {
-        LOG(ERROR) << "Could not open or find file " << image.first;
-        return;
+      cv::Mat cv_img;
+      if (this->cache_images_) {
+        pair<std::string, Datum> image_cached =
+          image_database_cache_[window[WindowDataLayer<Dtype>::IMAGE_INDEX]];
+        cv_img = DecodeDatumToCVMat(image_cached.second);
+      } else {
+        cv_img = cv::imread(image.first, CV_LOAD_IMAGE_COLOR);
+        if (!cv_img.data) {
+          LOG(ERROR) << "Could not open or find file " << image.first;
+          return;
+        }
       }
       #ifdef TIMING
       read_time += timer.MilliSeconds();
@@ -442,6 +467,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
     }
   }
   #ifdef TIMING
+  batch_timer.Stop();
   LOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << "ms.";
   LOG(INFO) << "Read time: " << read_time << "ms.";
   LOG(INFO) << "Transform time: " << trans_time << "ms.";
index b602d0e..03d955f 100644 (file)
@@ -714,6 +714,10 @@ message WindowDataParameter {
   // warp: cropped window is warped to a fixed size and aspect ratio
   // square: the tightest square around the window is cropped
   optional string crop_mode = 11 [default = "warp"];
+  // cache_images: will load all images in memory for faster access
+  optional bool cache_images = 12 [default = false];
+  // append root_folder to locate images
+  optional string root_folder = 13 [default = ""];
 }
 
 // DEPRECATED: V0LayerParameter is the old way of specifying layer parameters