Add fast code for transform(cv::Mat,Blob)
authorSergio <sguada@gmail.com>
Wed, 15 Oct 2014 22:35:26 +0000 (15:35 -0700)
committerSergio <sguada@gmail.com>
Wed, 15 Oct 2014 22:56:13 +0000 (15:56 -0700)
src/caffe/data_transformer.cpp
src/caffe/layers/data_layer.cpp

index dffaba5..8f14599 100644 (file)
@@ -179,9 +179,106 @@ void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                        Blob<Dtype>* transformed_blob) {
-  Datum datum;
-  CVMatToDatum(cv_img, &datum);
-  Transform(datum, transformed_blob);
+  const int img_channels = cv_img.channels();
+  const int img_height = cv_img.rows;
+  const int img_width = cv_img.cols;
+
+  const int channels = transformed_blob->channels();
+  const int height = transformed_blob->height();
+  const int width = transformed_blob->width();
+  const int num = transformed_blob->num();
+
+  CHECK_EQ(channels, img_channels);
+  CHECK_LE(height, img_height);
+  CHECK_LE(width, img_width);
+  CHECK_GE(num, 1);
+
+  CHECK(cv_img.depth() == CV_8U || cv_img.depth() == CV_8S) <<
+      "Image data type must be unsigned or signed byte";
+
+  const int crop_size = param_.crop_size();
+  const Dtype scale = param_.scale();
+  const bool do_mirror = param_.mirror() && Rand(2);
+  const bool has_mean_file = param_.has_mean_file();
+  const bool has_mean_values = mean_values_.size() > 0;
+
+  CHECK_GT(img_channels, 0);
+  CHECK_GE(img_height, crop_size);
+  CHECK_GE(img_width, crop_size);
+
+  Dtype* mean = NULL;
+  if (has_mean_file) {
+    CHECK_EQ(img_channels, data_mean_.channels());
+    CHECK_EQ(img_height, data_mean_.height());
+    CHECK_EQ(img_width, data_mean_.width());
+    mean = data_mean_.mutable_cpu_data();
+  }
+  if (has_mean_values) {
+    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
+     "Specify either 1 mean_value or as many as channels: " << img_channels;
+    if (img_channels > 1 && mean_values_.size() == 1) {
+      // Replicate the mean_value for simplicity
+      for (int c = 1; c < img_channels; ++c) {
+        mean_values_.push_back(mean_values_[0]);
+      }
+    }
+  }
+
+  int h_off = 0;
+  int w_off = 0;
+  cv::Mat cv_cropped_img = cv_img;
+  if (crop_size) {
+    CHECK_EQ(crop_size, height);
+    CHECK_EQ(crop_size, width);
+    // We only do random crop when we do training.
+    if (phase_ == Caffe::TRAIN) {
+      h_off = Rand(img_height - crop_size + 1);
+      w_off = Rand(img_width - crop_size + 1);
+    } else {
+      h_off = (img_height - crop_size) / 2;
+      w_off = (img_width - crop_size) / 2;
+    }
+    cv::Rect roi(h_off, w_off, crop_size, crop_size);
+    cv_cropped_img = cv_img(roi);
+  } else {
+    CHECK_EQ(img_height, height);
+    CHECK_EQ(img_width, width);
+  }
+  
+  // if (do_mirror) {
+  //   cv::flip(cv_cropped_img, cv_cropped_img, 1);
+  // }
+  CHECK(cv_cropped_img.data);
+
+  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
+  int top_index;
+  for (int h = 0; h < height; ++h) {
+    const char* ptr = cv_cropped_img.ptr<char>(h);
+    int img_index = 0;
+    for (int w = 0; w < width; ++w) {
+      for (int c = 0; c < img_channels; ++c) {
+        if (do_mirror) {
+          top_index = (c * height + h) * width + (width - 1 - w);
+        } else {
+          top_index = (c * height + h) * width + w;
+        }
+        // int top_index = (c * height + h) * width + w;
+        Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
+        if (has_mean_file) {
+          int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
+          transformed_data[top_index] =
+            (pixel - mean[mean_index]) * scale;
+        } else {
+          if (has_mean_values) {
+            transformed_data[top_index] =
+              (pixel - mean_values_[c]) * scale;
+          } else {
+            transformed_data[top_index] = pixel * scale;
+          }
+        }
+      }
+    }
+  }
 }
 #endif
 
index 330381d..95c5427 100644 (file)
@@ -101,14 +101,22 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     // get a blob
     CHECK(iter_ != dataset_->end());
     const Datum& datum = iter_->value;
-    cv::Mat cv_img = DecodeDatumToCVMat(datum);
+    cv::Mat cv_img;
+    if (datum.encoded()) {
+       cv_img = DecodeDatumToCVMat(datum);
+    }
     read_time += timer.MilliSeconds();
     timer.Start();
 
     // Apply data transformations (mirror, scale, crop...)
     int offset = this->prefetch_data_.offset(item_id);
     this->transformed_data_.set_cpu_data(top_data + offset);
-    this->data_transformer_.Transform(cv_img, &(this->transformed_data_));
+    if (datum.encoded()) {
+      this->data_transformer_.Transform(cv_img, &(this->transformed_data_));  
+    } else {
+      this->data_transformer_.Transform(datum, &(this->transformed_data_));
+    }
+    
     if (this->output_labels_) {
       top_label[item_id] = datum.label();
     }