Speed up WindowDataLayer and add mean_values
authorSergio <sguada@gmail.com>
Tue, 7 Oct 2014 21:14:50 +0000 (14:14 -0700)
committerSergio <sguada@gmail.com>
Thu, 16 Oct 2014 00:03:07 +0000 (17:03 -0700)
include/caffe/data_layers.hpp
src/caffe/layers/window_data_layer.cpp

index a2ea854..c4903ce 100644 (file)
@@ -325,6 +325,9 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {
   vector<vector<float> > fg_windows_;
   vector<vector<float> > bg_windows_;
   Blob<Dtype> data_mean_;
+  vector<Dtype> mean_values_;
+  bool has_mean_file_;
+  bool has_mean_values_;
 };
 
 }  // namespace caffe
index 8e65615..fc0ffc8 100644 (file)
@@ -170,15 +170,30 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   this->prefetch_label_.Reshape(batch_size, 1, 1, 1);
 
   // data mean
-  if (this->transform_param_.has_mean_file()) {
+  has_mean_file_ = this->transform_param_.has_mean_file();
+  has_mean_values_ = this->transform_param_.mean_value_size() > 0;
+  if (has_mean_file_) {
     const string& mean_file =
           this->transform_param_.mean_file();
     LOG(INFO) << "Loading mean file from" << mean_file;
     BlobProto blob_proto;
     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
     data_mean_.FromProto(blob_proto);
-  } else {
-    data_mean_.Reshape(1, channels, crop_size, crop_size);
+  }
+  if (has_mean_values_) {
+    CHECK(has_mean_file_ == false) <<
+      "Cannot specify mean_file and mean_value at the same time";
+    for (int c = 0; c < this->transform_param_.mean_value_size(); ++c) {
+      mean_values_.push_back(this->transform_param_.mean_value(c));
+    }
+    CHECK(mean_values_.size() == 1 || mean_values_.size() == channels) <<
+     "Specify either 1 mean_value or as many as channels: " << channels;
+    if (channels > 1 && mean_values_.size() == 1) {
+      // Replicate the mean_value for simplicity
+      for (int c = 1; c < channels; ++c) {
+        mean_values_.push_back(mean_values_[0]);
+      }
+    }
   }
 }
 
@@ -211,10 +226,14 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
   const bool mirror = this->transform_param_.mirror();
   const float fg_fraction =
       this->layer_param_.window_data_param().fg_fraction();
-  const Dtype* mean = this->data_mean_.cpu_data();
-  const int mean_off = (this->data_mean_.width() - crop_size) / 2;
-  const int mean_width = this->data_mean_.width();
-  const int mean_height = this->data_mean_.height();
+  Dtype* mean = NULL;
+  int mean_off, mean_width, mean_height;
+  if (this->has_mean_file_) {
+    mean = this->data_mean_.mutable_cpu_data();
+    mean_off = (this->data_mean_.width() - crop_size) / 2;
+    mean_width = this->data_mean_.width();
+    mean_height = this->data_mean_.height();
+  }
   cv::Size cv_crop_size(crop_size, crop_size);
   const string& crop_mode = this->layer_param_.window_data_param().crop_mode();
 
@@ -357,18 +376,26 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
       }
 
       // copy the warped window into top_data
-      for (int c = 0; c < channels; ++c) {
-        for (int h = 0; h < cv_cropped_img.rows; ++h) {
-          for (int w = 0; w < cv_cropped_img.cols; ++w) {
-            Dtype pixel =
-                static_cast<Dtype>(cv_cropped_img.at<cv::Vec3b>(h, w)[c]);
-
-            top_data[((item_id * channels + c) * crop_size + h + pad_h)
-                     * crop_size + w + pad_w]
-                = (pixel
-                    - mean[(c * mean_height + h + mean_off + pad_h)
-                           * mean_width + w + mean_off + pad_w])
-                  * scale;
+      for (int h = 0; h < cv_cropped_img.rows; ++h) {
+        const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
+        int img_index = 0;
+        for (int w = 0; w < cv_cropped_img.cols; ++w) {
+          for (int c = 0; c < channels; ++c) {
+            int top_index = ((item_id * channels + c) * crop_size + h + pad_h)
+                     * crop_size + w + pad_w;
+            // int top_index = (c * height + h) * width + w;
+            Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
+            if (this->has_mean_file_) {
+              int mean_index = (c * mean_height + h + mean_off + pad_h)
+                           * mean_width + w + mean_off + pad_w;
+              top_data[top_index] = (pixel - mean[mean_index]) * scale;
+            } else {
+              if (this->has_mean_values_) {
+                top_data[top_index] = (pixel - this->mean_values_[c]) * scale;
+              } else {
+                top_data[top_index] = pixel * scale;
+              }
+            }
           }
         }
       }