Add CVMatToDatum
authorSergio <sguada@gmail.com>
Wed, 24 Sep 2014 14:36:13 +0000 (07:36 -0700)
committerSergio <sguada@gmail.com>
Fri, 3 Oct 2014 18:45:59 +0000 (11:45 -0700)
include/caffe/util/io.hpp
src/caffe/data_transformer.cpp
src/caffe/util/io.cpp

index 8b83534..2032b1f 100644 (file)
@@ -126,6 +126,8 @@ inline cv::Mat ReadImageToCVMat(const string& filename) {
   return ReadImageToCVMat(filename, 0, 0, true);
 }
 
+void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
+
 leveldb::Options GetLevelDBOptions();
 
 template <typename Dtype>
index cdb6e10..f001f1c 100644 (file)
@@ -150,77 +150,9 @@ void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                        Blob<Dtype>* transformed_blob) {
-  const int img_channels = cv_img.channels();
-  const int img_height = cv_img.rows;
-  const int img_width = cv_img.cols;
-
-  const int channels = transformed_blob->channels();
-  const int height = transformed_blob->height();
-  const int width = transformed_blob->width();
-  const int num = transformed_blob->num();
-
-  CHECK_EQ(channels, img_channels);
-  CHECK_LE(height, img_height);
-  CHECK_LE(width, img_width);
-  CHECK_GE(num, 1);
-
-  const int crop_size = param_.crop_size();
-  const Dtype scale = param_.scale();
-  const bool do_mirror = param_.mirror() && Rand() % 2;
-  const bool has_mean_file = param_.has_mean_file();
-
-  int h_off = 0;
-  int w_off = 0;
-  if (crop_size) {
-    CHECK_EQ(crop_size, height);
-    CHECK_EQ(crop_size, width);
-    // We only do random crop when we do training.
-    if (phase_ == Caffe::TRAIN) {
-      h_off = Rand() % (img_height - crop_size);
-      w_off = Rand() % (img_width - crop_size);
-    } else {
-      h_off = (img_height - crop_size) / 2;
-      w_off = (img_width - crop_size) / 2;
-    }
-  } else {
-    CHECK_EQ(img_height, height);
-    CHECK_EQ(img_width, width);
-  }
-
-  Dtype* mean = NULL;
-  if (has_mean_file) {
-    CHECK_EQ(img_channels, data_mean_.channels());
-    CHECK_EQ(img_height, data_mean_.height());
-    CHECK_EQ(img_width, data_mean_.width());
-    mean = data_mean_.mutable_cpu_data();
-  }
-
-  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
-  Dtype pixel;
-  int top_index;
-  for (int c = 0; c < channels; ++c) {
-    int top_index_c = c * height;
-    int mean_index_c = c * img_height + h_off;
-    for (int h = 0; h < height; ++h) {
-      int top_index_h = (top_index_c + h) * width;
-      int mean_index_h = (mean_index_c + h) * img_width + w_off;
-      for (int w = 0; w < width; ++w) {
-        if (do_mirror) {
-          top_index = top_index_h + (width - 1 - w);
-        } else {
-          top_index = top_index_h + w;
-        }
-        pixel = static_cast<Dtype>(
-              cv_img.at<cv::Vec3b>(h + h_off, w + w_off)[c]);
-        if (has_mean_file) {
-          int mean_index = mean_index_h + w;
-          transformed_data[top_index] = (pixel - mean[mean_index]) * scale;
-        } else {
-          transformed_data[top_index] = pixel * scale;
-        }
-      }
-    }
-  }
+  Datum datum;
+  CVMatToDatum(cv_img, &datum);
+  Transform(datum, transformed_blob);
 }
 
 template<typename Dtype>
index a576a91..047b140 100644 (file)
@@ -86,13 +86,21 @@ cv::Mat ReadImageToCVMat(const string& filename,
 bool ReadImageToDatum(const string& filename, const int label,
     const int height, const int width, const bool is_color, Datum* datum) {
   cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
+  if (cv_img.data) {
+    CVMatToDatum(cv_img, datum);
+    datum->set_label(label);
+    return true;
+  } else {
+    return false;
+  }
+}
 
+void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
   CHECK(cv_img.depth() == CV_8U) <<
       "Image data type must be unsigned byte";
   datum->set_channels(cv_img.channels());
   datum->set_height(cv_img.rows);
   datum->set_width(cv_img.cols);
-  datum->set_label(label);
   datum->clear_data();
   datum->clear_float_data();
   int datum_channels = datum->channels();
@@ -111,9 +119,9 @@ bool ReadImageToDatum(const string& filename, const int label,
     }
   }
   datum->set_data(buffer);
-  return true;
 }
 
+
 leveldb::Options GetLevelDBOptions() {
   // In default, we will return the leveldb option and set the max open files
   // in order to avoid using up the operating system's limit.