Initial cv::Mat transformation
authorSergio <sguada@gmail.com>
Fri, 26 Sep 2014 02:58:06 +0000 (19:58 -0700)
committerSergio <sguada@gmail.com>
Fri, 3 Oct 2014 18:45:58 +0000 (11:45 -0700)
Added cv::Mat transformation to ImageDataLayer

Conflicts:

src/caffe/layers/image_data_layer.cpp

Added transform Datum to Blob

Conflicts:

src/caffe/layers/base_data_layer.cpp
src/caffe/layers/base_data_layer.cu

Added transform cv::Mat to Blob

Added transform Vector<Datum> to Blob

Conflicts:
src/caffe/data_transformer.cpp

13 files changed:
include/caffe/data_layers.hpp
include/caffe/data_transformer.hpp
include/caffe/util/benchmark.hpp
include/caffe/util/io.hpp
src/caffe/data_transformer.cpp
src/caffe/layers/base_data_layer.cpp
src/caffe/layers/base_data_layer.cu
src/caffe/layers/data_layer.cpp
src/caffe/layers/image_data_layer.cpp
src/caffe/layers/memory_data_layer.cpp
src/caffe/proto/caffe.proto
src/caffe/util/benchmark.cpp
src/caffe/util/io.cpp

index e598a71..e3ba2d1 100644 (file)
@@ -17,6 +17,7 @@
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/benchmark.hpp"
 
 namespace caffe {
 
@@ -51,6 +52,7 @@ class BaseDataLayer : public Layer<Dtype> {
   DataTransformer<Dtype> data_transformer_;
   Caffe::Phase phase_;
   bool output_labels_;
+  Timer timer_forward_;
 };
 
 template <typename Dtype>
@@ -79,6 +81,7 @@ class BasePrefetchingDataLayer :
  protected:
   Blob<Dtype> prefetch_data_;
   Blob<Dtype> prefetch_label_;
+  Blob<Dtype> transformed_data_;
 };
 
 template <typename Dtype>
index fcd1011..2359421 100644 (file)
@@ -1,9 +1,11 @@
 #ifndef CAFFE_DATA_TRANSFORMER_HPP
 #define CAFFE_DATA_TRANSFORMER_HPP
 
+#include <opencv2/core/core.hpp>
+
+#include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
 
 namespace caffe {
 
@@ -33,8 +35,15 @@ class DataTransformer {
    *    written at the appropriate place within the blob's data.
    */
 
-  void Transform(const int batch_item_id, const Datum& datum,
-                 Dtype* transformed_data);
+  void Transform(const Datum& datum, Dtype* transformed_data);
+
+  void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
+
+  void Transform(const vector<Datum> & datum_vector, Blob<Dtype>* transformed_blob);
+
+  void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);
+
+  void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);
 
  protected:
    /**
@@ -47,8 +56,6 @@ class DataTransformer {
    */
   virtual int Rand(int n);
 
-  void Transform(const int batch_item_id, const Datum& datum,
-                 const Dtype* mean, Dtype* transformed_data);
   // Tranformation parameters
   TransformationParameter param_;
 
index f7ef8ea..6c01ac4 100644 (file)
@@ -14,6 +14,7 @@ class Timer {
   void Start();
   void Stop();
   float MilliSeconds();
+  float MicroSeconds();
   float Seconds();
 
   inline bool initted() { return initted_; }
@@ -33,6 +34,7 @@ class Timer {
   boost::posix_time::ptime start_cpu_;
   boost::posix_time::ptime stop_cpu_;
   float elapsed_milliseconds_;
+  float elapsed_microseconds_;
 };
 
 }  // namespace caffe
index 8dd338d..88037c8 100644 (file)
@@ -7,6 +7,7 @@
 #include "google/protobuf/message.h"
 #include "hdf5.h"
 #include "hdf5_hl.h"
+#include <opencv2/core/core.hpp>
 
 #include "caffe/blob.hpp"
 #include "caffe/proto/caffe.pb.h"
@@ -102,6 +103,23 @@ inline bool ReadImageToDatum(const string& filename, const int label,
   return ReadImageToDatum(filename, label, 0, 0, datum);
 }
 
+cv::Mat ReadImageToCVMat(const string& filename, 
+    const int height, const int width, const bool is_color);
+
+inline cv::Mat ReadImageToCVMat(const string& filename, 
+    const int height, const int width) {
+  return ReadImageToCVMat(filename, height, width, true);
+}
+
+inline cv::Mat ReadImageToCVMat(const string& filename,
+    const bool is_color) {
+  return ReadImageToCVMat(filename, 0, 0, is_color);
+}
+
+inline cv::Mat ReadImageToCVMat(const string& filename) {
+  return ReadImageToCVMat(filename, 0, 0, true);
+}
+
 leveldb::Options GetLevelDBOptions();
 
 template <typename Dtype>
index 553717a..836bf3b 100644 (file)
@@ -1,9 +1,12 @@
 #include <string>
 
 #include "caffe/data_transformer.hpp"
+#include "caffe/util/io.hpp"
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/rng.hpp"
 
+#include <opencv2/core/core.hpp>
+
 namespace caffe {
 
 template<typename Dtype>
@@ -21,105 +24,365 @@ DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param)
 }
 
 template<typename Dtype>
-void DataTransformer<Dtype>::Transform(const int batch_item_id,
-                                       const Datum& datum,
-                                       Dtype* transformed_data) {
-  CHECK_GT(datum.channels(), 0);
-  CHECK_GE(datum.height(), param_.crop_size());
-  CHECK_GE(datum.height(), param_.crop_size());
-  const int size = datum.channels() * datum.height() * datum.width();
-  if (data_mean_.count() < size) {
-    data_mean_.Reshape(1, datum.channels(), datum.height(), datum.width());
-    LOG(INFO) << "Transform without mean";
-  }
-  const Dtype* mean = data_mean_.cpu_data();
-  Transform(batch_item_id, datum, mean, transformed_data);
-}
-
-template<typename Dtype>
-void DataTransformer<Dtype>::Transform(const int batch_item_id,
-                                       const Datum& datum,
-                                       const Dtype* mean,
+void DataTransformer<Dtype>::Transform(const Datum& datum,
                                        Dtype* transformed_data) {
   const string& data = datum.data();
-  const int channels = datum.channels();
-  const int height = datum.height();
-  const int width = datum.width();
+  const int datum_channels = datum.channels();
+  const int datum_height = datum.height();
+  const int datum_width = datum.width();
   const int size = datum.channels() * datum.height() * datum.width();
 
   const int crop_size = param_.crop_size();
-  const bool mirror = param_.mirror();
   const Dtype scale = param_.scale();
+  const bool do_mirror = param_.mirror() && Rand() % 2;
+  const bool has_mean_file = param_.has_mean_file();
+  const bool has_unit8 = data.size() > 0;
+
+  CHECK_GT(datum_channels, 0);
+  CHECK_GE(datum_height, crop_size);
+  CHECK_GE(datum_width, crop_size);
 
-  if (mirror && crop_size == 0) {
-    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
-               << "set at the same time.";
+  Dtype* mean = NULL;
+  if (has_mean_file) {
+    CHECK_EQ(datum_channels, data_mean_.channels());
+    CHECK_EQ(datum_height, data_mean_.height());
+    CHECK_EQ(datum_width, data_mean_.width());
+    mean = data_mean_.mutable_cpu_data();
   }
 
+  int h_off = 0;
+  int w_off = 0;
+  Dtype datum_element;
   if (crop_size) {
-    CHECK(data.size()) << "Image cropping only support uint8 data";
-    int h_off, w_off;
     // We only do random crop when we do training.
     if (phase_ == Caffe::TRAIN) {
       h_off = Rand(height - crop_size + 1);
       w_off = Rand(width - crop_size + 1);
     } else {
-      h_off = (height - crop_size) / 2;
-      w_off = (width - crop_size) / 2;
+      h_off = (datum_height - crop_size) / 2;
+      w_off = (datum_width - crop_size) / 2;
     }
-    if (mirror && (Rand(2) == 1)) {
-      // Copy mirrored version
-      for (int c = 0; c < channels; ++c) {
-        for (int h = 0; h < crop_size; ++h) {
-          for (int w = 0; w < crop_size; ++w) {
-            int data_index = (c * height + h + h_off) * width + w + w_off;
-            int top_index = ((batch_item_id * channels + c) * crop_size + h)
-                * crop_size + (crop_size - 1 - w);
-            Dtype datum_element =
-                static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
+
+    int top_index, data_index;
+    for (int c = 0; c < datum_channels; ++c) {
+      int top_index_c = c * crop_size;
+      int data_index_c = c * datum_height + h_off;
+      for (int h = 0; h < crop_size; ++h) {
+        int top_index_h = (top_index_c + h) * crop_size;
+        int data_index_h = (data_index_c + h) * datum_width + w_off;
+        for (int w = 0; w < crop_size; ++w) {
+          data_index = data_index_h + w;
+          if (do_mirror) {
+            top_index = top_index_h + (crop_size - 1 - w);
+          } else {
+            top_index = top_index_h + w;
+          }
+          if (has_unit8) {
+            datum_element =
+              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
+          } else {
+            datum_element = datum.float_data(data_index);
+          }
+          if (has_mean_file) {
             transformed_data[top_index] =
-                (datum_element - mean[data_index]) * scale;
+              (datum_element - mean[data_index]) * scale;
+          } else {
+            transformed_data[top_index] = datum_element * scale;
           }
         }
       }
+    }
+  } else {
+    for (int j = 0; j < size; ++j) {
+      if (has_unit8) {
+          datum_element =
+            static_cast<Dtype>(static_cast<uint8_t>(data[j]));
+      } else {
+        datum_element = datum.float_data(j);
+      }
+      if (has_mean_file) {
+        transformed_data[j] =
+          (datum_element - mean[j]) * scale;
+      } else {
+        transformed_data[j] = datum_element * scale;
+      }
+    }
+  }
+}
+
+template<typename Dtype>
+void DataTransformer<Dtype>::Transform(const Datum& datum,
+                                       Blob<Dtype>* transformed_blob) {
+  const string& data = datum.data();
+  const int datum_channels = datum.channels();
+  const int datum_height = datum.height();
+  const int datum_width = datum.width();
+   
+  const int channels = transformed_blob->channels();
+  const int height = transformed_blob->height();
+  const int width = transformed_blob->width();
+
+  CHECK_EQ(datum_channels, channels);
+  CHECK_GE(datum_height, height);
+  CHECK_GE(datum_width, width);
+
+  CHECK_EQ(transformed_blob->num(), 1) <<
+    "transformed_blob should have num() = 1";
+
+  const int crop_size = param_.crop_size();
+  const Dtype scale = param_.scale();
+  const bool do_mirror = param_.mirror() && Rand() % 2;
+  const bool has_mean_file = param_.has_mean_file();
+  const bool has_unit8 = data.size()>0;
+
+  int h_off = 0;
+  int w_off = 0;
+  if (crop_size) {
+    CHECK_EQ(crop_size, height);
+    CHECK_EQ(crop_size, width);
+    // We only do random crop when we do training.
+    if (phase_ == Caffe::TRAIN) {
+      h_off = Rand() % (datum_height - crop_size);
+      w_off = Rand() % (datum_width - crop_size);
     } else {
-      // Normal copy
-      for (int c = 0; c < channels; ++c) {
-        for (int h = 0; h < crop_size; ++h) {
-          for (int w = 0; w < crop_size; ++w) {
-            int top_index = ((batch_item_id * channels + c) * crop_size + h)
-                * crop_size + w;
-            int data_index = (c * height + h + h_off) * width + w + w_off;
-            Dtype datum_element =
-                static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-            transformed_data[top_index] =
-                (datum_element - mean[data_index]) * scale;
-          }
+      h_off = (datum_height - crop_size) / 2;
+      w_off = (datum_width - crop_size) / 2;
+    }
+  } else {
+    CHECK_EQ(datum_height, height);
+    CHECK_EQ(datum_width, width);
+  }
+
+  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
+  
+  Dtype* mean = NULL;
+  if (has_mean_file) {
+    CHECK_EQ(datum_channels, data_mean_.channels());
+    CHECK_EQ(datum_height, data_mean_.height());
+    CHECK_EQ(datum_width, data_mean_.width());
+    mean = data_mean_.mutable_cpu_data();
+  }
+
+  Dtype datum_element;
+  int top_index, data_index;
+  for (int c = 0; c < channels; ++c) {
+    int top_index_c = c * height;
+    int data_index_c = c * datum_height + h_off;
+    for (int h = 0; h < height; ++h) {
+      int top_index_h = (top_index_c + h) * width;
+      int data_index_h = (data_index_c + h) * datum_width + w_off;
+      for (int w = 0; w < width; ++w) {
+        data_index = data_index_h + w;
+        if (do_mirror) {
+          top_index = top_index_h + (width - 1 - w);
+        } else {
+          top_index = top_index_h + w;
+        }
+        if (has_unit8) {
+          datum_element =
+            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
+        } else {
+          datum_element = datum.float_data(data_index);
+        }
+        if (has_mean_file) {
+          transformed_data[top_index] =
+            (datum_element - mean[data_index]) * scale;
+        } else {
+          transformed_data[top_index] = datum_element * scale;
         }
       }
     }
+  }
+}
+
+template<typename Dtype>
+void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
+                                       Blob<Dtype>* transformed_blob) {
+  const int datum_num = datum_vector.size();
+  const int num = transformed_blob->num();
+  const int channels = transformed_blob->channels();
+  const int height = transformed_blob->height();
+  const int width = transformed_blob->width();
+
+  CHECK_GT(datum_num, 0) << "There is no datum to add";
+  CHECK_LE(datum_num, num) <<
+    "The size of datum_vector must be smaller than transformed_blob->num()";
+  Blob<Dtype> uni_blob(1, channels, height, width);
+  for (int item_id = 0; item_id < datum_num; ++item_id) {
+    int offset = transformed_blob->offset(item_id);
+    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);
+    Transform(datum_vector[item_id], &uni_blob);
+  }
+}
+
+template<typename Dtype>
+void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
+                                       Blob<Dtype>* transformed_blob) {
+
+  const int img_channels = cv_img.channels();
+  const int img_height = cv_img.rows;
+  const int img_width = cv_img.cols;
+
+  const int channels = transformed_blob->channels();
+  const int height = transformed_blob->height();
+  const int width = transformed_blob->width();
+
+  CHECK_EQ(img_channels, channels);
+  CHECK_GE(img_height, height);
+  CHECK_GE(img_width, width);
+
+  CHECK_EQ(transformed_blob->num(), 1) <<
+    "transformed_blob should have num() = 1";
+
+  const int crop_size = param_.crop_size();
+  const Dtype scale = param_.scale();
+  const bool do_mirror = param_.mirror() && Rand() % 2;
+  const bool has_mean_file = param_.has_mean_file();
+  
+  int h_off = 0;
+  int w_off = 0;
+  if (crop_size) {
+    CHECK_EQ(crop_size, height);
+    CHECK_EQ(crop_size, width);
+    // We only do random crop when we do training.
+    if (phase_ == Caffe::TRAIN) {
+      h_off = Rand() % (img_height - crop_size);
+      w_off = Rand() % (img_width - crop_size);
+    } else {
+      h_off = (img_height - crop_size) / 2;
+      w_off = (img_width - crop_size) / 2;
+    }
   } else {
-    // we will prefer to use data() first, and then try float_data()
-    if (data.size()) {
-      for (int j = 0; j < size; ++j) {
-        Dtype datum_element =
-            static_cast<Dtype>(static_cast<uint8_t>(data[j]));
-        transformed_data[j + batch_item_id * size] =
-            (datum_element - mean[j]) * scale;
+    CHECK_EQ(img_height, height);
+    CHECK_EQ(img_width, width);
+  }
+
+  Dtype* mean = NULL;
+  if (has_mean_file) {
+    CHECK_EQ(img_channels, data_mean_.channels());
+    CHECK_EQ(img_height, data_mean_.height());
+    CHECK_EQ(img_width, data_mean_.width());
+    mean = data_mean_.mutable_cpu_data();
+  }
+
+  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
+  Dtype pixel;
+  int top_index;
+  for (int c = 0; c < channels; ++c) {
+    int top_index_c = c * height;
+    int mean_index_c = c * img_height + h_off;
+    for (int h = 0; h < height; ++h) {
+      int top_index_h = (top_index_c + h) * width;
+      int mean_index_h = (mean_index_c + h) * img_width + w_off;
+      for (int w = 0; w < width; ++w) {
+        if (do_mirror) {
+          top_index = top_index_h + (width - 1 - w);
+        } else {
+          top_index = top_index_h + w;
+        }
+        pixel = static_cast<Dtype>(
+              cv_img.at<cv::Vec3b>(h + h_off, w + w_off)[c]);
+        if (has_mean_file) {
+          int mean_index = mean_index_h + w;
+          transformed_data[top_index] = (pixel - mean[mean_index]) * scale;
+        } else {
+          transformed_data[top_index] = pixel * scale;
+        }
       }
+    }
+  }
+}
+
+
+template<typename Dtype>
+void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
+                                       Blob<Dtype>* transformed_blob) {
+
+  const int input_num = input_blob->num();
+  const int input_channels = input_blob->channels();
+  const int input_height = input_blob->height();
+  const int input_width = input_blob->width();
+  const int num = transformed_blob->num();
+  const int channels = transformed_blob->channels();
+  const int height = transformed_blob->height();
+  const int width = transformed_blob->width();
+  const int size = transformed_blob->count();
+
+  CHECK_LE(input_num, num);
+  CHECK_EQ(input_channels, channels);
+  CHECK_GE(input_height, height);
+  CHECK_GE(input_width, width);
+
+  const int crop_size = param_.crop_size();
+  const Dtype scale = param_.scale();
+  const bool do_mirror = param_.mirror() && Rand() % 2;
+  const bool has_mean_file = param_.has_mean_file();
+
+  int h_off = 0;
+  int w_off = 0;
+  if (crop_size) {
+    CHECK_EQ(crop_size, height);
+    CHECK_EQ(crop_size, width);
+    // We only do random crop when we do training.
+    if (phase_ == Caffe::TRAIN) {
+      h_off = Rand() % (input_height - crop_size);
+      w_off = Rand() % (input_width - crop_size);
     } else {
-      for (int j = 0; j < size; ++j) {
-        transformed_data[j + batch_item_id * size] =
-            (datum.float_data(j) - mean[j]) * scale;
+      h_off = (input_height - crop_size) / 2;
+      w_off = (input_width - crop_size) / 2;
+    }
+  } else {
+    CHECK_EQ(input_height, height);
+    CHECK_EQ(input_width, width);
+  }
+
+  Dtype* input_data = input_blob->mutable_cpu_data();
+  if (has_mean_file) {
+    CHECK_EQ(input_channels, data_mean_.channels());
+    CHECK_EQ(input_height, data_mean_.height());
+    CHECK_EQ(input_width, data_mean_.width());
+    for (int n = 0; n < input_num; ++n) {
+      int offset = input_blob->offset(n);
+      caffe_sub(data_mean_.count(), input_data + offset,
+            data_mean_.cpu_data(), input_data + offset);
+    } 
+  }
+
+  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
+
+  for (int n = 0; n < input_num; ++n) {
+    int top_index_n = n * channels;
+    int data_index_n = n * channels;
+    for (int c = 0; c < channels; ++c) {
+      int top_index_c = (top_index_n + c) * height;
+      int data_index_c = (data_index_n + c) * input_height + h_off;
+      for (int h = 0; h < height; ++h) {
+        int top_index_h = (top_index_c + h) * width;
+        int data_index_h = (data_index_c + h) * input_width + w_off;
+        if (do_mirror) {
+          int top_index_w = top_index_h + width - 1;
+          for (int w = 0; w < width; ++w) {
+            transformed_data[top_index_w-w] = input_data[data_index_h + w];
+          }
+        } else {
+          for (int w = 0; w < width; ++w) {
+            transformed_data[top_index_h + w] = input_data[data_index_h + w];
+          }
+        }
       }
     }
   }
+  if (scale!=Dtype(1)) {
+    DLOG(INFO) << "Scale: " << scale;
+    caffe_scal(size, scale, transformed_data);
+  }
 }
 
 template <typename Dtype>
 void DataTransformer<Dtype>::InitRand() {
-  const bool needs_rand = (phase_ == Caffe::TRAIN) &&
-      (param_.mirror() || param_.crop_size());
+  const bool needs_rand = param_.mirror() ||
+      (phase_ == Caffe::TRAIN && param_.crop_size());
   if (needs_rand) {
     const unsigned int rng_seed = caffe_rng_rand();
     rng_.reset(new Caffe::RNG(rng_seed));
index 5038d53..d7d4752 100644 (file)
@@ -59,16 +59,23 @@ template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   // First, join the thread
+  if (this->timer_forward_.has_run_at_least_once()) {
+    DLOG(INFO) << "Proccessing: " << this->timer_forward_.MilliSeconds() << "ms.";
+  }
   JoinPrefetchThread();
+  DLOG(INFO) << "Thread joined";
   // Copy the data
   caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
              top[0]->mutable_cpu_data());
+  DLOG(INFO) << "Prefetch copied";
   if (this->output_labels_) {
     caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
                top[1]->mutable_cpu_data());
   }
   // Start a new prefetch thread
+  DLOG(INFO) << "CreatePrefetchThread";
   CreatePrefetchThread();
+  this->timer_forward_.Start();
 }
 
 #ifdef CPU_ONLY
index ff15103..690858f 100644 (file)
@@ -7,6 +7,7 @@ namespace caffe {
 template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  DLOG(INFO) << "Processing: " << this->timer_forward_.MilliSeconds() << "ms.";
   // First, join the thread
   JoinPrefetchThread();
   // Copy the data
@@ -18,6 +19,7 @@ void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
   }
   // Start a new prefetch thread
   CreatePrefetchThread();
+  this->timer_forward_.Start();
 }
 
 INSTANTIATE_CLASS(BasePrefetchingDataLayer);
index 40c4873..b1d7ef9 100644 (file)
@@ -8,6 +8,7 @@
 #include "caffe/data_layers.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/benchmark.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/rng.hpp"
@@ -118,12 +119,15 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
                        datum.channels(), crop_size, crop_size);
     this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
         datum.channels(), crop_size, crop_size);
+    this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size);
   } else {
     top[0]->Reshape(
         this->layer_param_.data_param().batch_size(), datum.channels(),
         datum.height(), datum.width());
     this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
         datum.channels(), datum.height(), datum.width());
+    this->transformed_data_.Reshape(1, datum.channels(),
+      datum.height(), datum.width());
   }
   LOG(INFO) << "output data size: " << top[0]->num() << ","
       << top[0]->channels() << "," << top[0]->height() << ","
@@ -139,10 +143,13 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
 // This function is used to create a thread that prefetches the data.
 template <typename Dtype>
 void DataLayer<Dtype>::InternalThreadEntry() {
-  Datum datum;
+  Timer batch_timer;
+  batch_timer.Start();
   CHECK(this->prefetch_data_.count());
+  CHECK(this->transformed_data_.count());
   Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
+
   if (this->output_labels_) {
     top_label = this->prefetch_label_.mutable_cpu_data();
   }
@@ -150,6 +157,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
 
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     // get a blob
+    Datum datum;
     switch (this->layer_param_.data_param().backend()) {
     case DataParameter_DB_LEVELDB:
       CHECK(iter_);
@@ -165,10 +173,10 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     default:
       LOG(FATAL) << "Unknown database backend";
     }
-
     // Apply data transformations (mirror, scale, crop...)
-    this->data_transformer_.Transform(item_id, datum, top_data);
-
+    int offset = this->prefetch_data_.offset(item_id);
+    this->transformed_data_.set_cpu_data(top_data + offset);
+    this->data_transformer_.Transform(datum, &(this->transformed_data_));    
     if (this->output_labels_) {
       top_label[item_id] = datum.label();
     }
@@ -196,6 +204,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
       LOG(FATAL) << "Unknown database backend";
     }
   }
+  DLOG(INFO) << "Prefetch: " << batch_timer.MilliSeconds() << " ms.";
 }
 
 INSTANTIATE_CLASS(DataLayer);
index a5d46fd..e167f13 100644 (file)
@@ -22,6 +22,8 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   const int new_height = this->layer_param_.image_data_param().new_height();
   const int new_width  = this->layer_param_.image_data_param().new_width();
+  const bool is_color  = this->layer_param_.image_data_param().is_color();
+
   CHECK((new_height == 0 && new_width == 0) ||
       (new_height > 0 && new_width > 0)) << "Current implementation requires "
       "new_height and new_width to be set at the same time.";
@@ -53,22 +55,23 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
     CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
     lines_id_ = skip;
   }
-  // Read a data point, and use it to initialize the top blob.
-  Datum datum;
-  CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
-                         new_height, new_width, &datum));
+  // Read an image, and use it to initialize the top blob.
+  cv::Mat cv_img = ReadImageToCVMat(lines_[lines_id_].first,
+                                    new_height, new_width, is_color);
+  const int channels = cv_img.channels();
+  const int height = cv_img.rows;
+  const int width = cv_img.cols;
   // image
   const int crop_size = this->layer_param_.transform_param().crop_size();
   const int batch_size = this->layer_param_.image_data_param().batch_size();
   if (crop_size > 0) {
-    top[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
-    this->prefetch_data_.Reshape(batch_size, datum.channels(), crop_size,
-                                 crop_size);
+    top[0]->Reshape(batch_size, channels, crop_size, crop_size);
+    this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
+    this->transformed_data_.Reshape(1, channels, crop_size, crop_size);
   } else {
-    top[0]->Reshape(batch_size, datum.channels(), datum.height(),
-                       datum.width());
-    this->prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(),
-        datum.width());
+    top[0]->Reshape(batch_size, channels, height, width);
+    this->prefetch_data_.Reshape(batch_size, channels, height, width);
+    this->transformed_data_.Reshape(1, channels, height, width);
   }
   LOG(INFO) << "output data size: " << top[0]->num() << ","
       << top[0]->channels() << "," << top[0]->height() << ","
@@ -88,30 +91,35 @@ void ImageDataLayer<Dtype>::ShuffleImages() {
 // This function is used to create a thread that prefetches the data.
 template <typename Dtype>
 void ImageDataLayer<Dtype>::InternalThreadEntry() {
-  Datum datum;
+  Timer batch_timer;
+  batch_timer.Start();
   CHECK(this->prefetch_data_.count());
+  CHECK(this->transformed_data_.count());
   Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
   Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
   ImageDataParameter image_data_param = this->layer_param_.image_data_param();
   const int batch_size = image_data_param.batch_size();
   const int new_height = image_data_param.new_height();
   const int new_width = image_data_param.new_width();
+  const bool is_color = image_data_param.is_color();
 
   // datum scales
   const int lines_size = lines_.size();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     // get a blob
     CHECK_GT(lines_size, lines_id_);
-    if (!ReadImageToDatum(lines_[lines_id_].first,
-          lines_[lines_id_].second,
-          new_height, new_width, &datum)) {
+    cv::Mat cv_img = ReadImageToCVMat(lines_[lines_id_].first,
+                                    new_height, new_width, is_color);
+    if (!cv_img.data) {
       continue;
     }
-
-    // Apply transformations (mirror, crop...) to the data
-    this->data_transformer_.Transform(item_id, datum, top_data);
-
-    top_label[item_id] = datum.label();
+    // Apply transformations (mirror, crop...) to the image
+    // this->data_transformer_.Transform(item_id, cv_img, top_data);
+    int offset = this->prefetch_data_.offset(item_id);
+    this->transformed_data_.set_cpu_data(top_data + offset);
+    this->data_transformer_.Transform(cv_img, &(this->transformed_data_));    
+    
+    top_label[item_id] = lines_[lines_id_].second;
     // go to the next iter
     lines_id_++;
     if (lines_id_ >= lines_size) {
@@ -123,6 +131,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
       }
     }
   }
+  DLOG(INFO) << "Prefetch: " << batch_timer.MilliSeconds() << " ms.";
 }
 
 INSTANTIATE_CLASS(ImageDataLayer);
index 269a267..8f1c21f 100644 (file)
@@ -37,15 +37,15 @@ void MemoryDataLayer<Dtype>::AddDatumVector(const vector<Datum>& datum_vector) {
   CHECK_LE(num, batch_size_) <<
       "The number of added datum must be no greater than the batch size";
 
-  Dtype* top_data = added_data_.mutable_cpu_data();
+  // Apply data transformations (mirror, scale, crop...)
+  this->data_transformer_.Transform(datum_vector, &added_data_);
+  // Copy Labels
   Dtype* top_label = added_label_.mutable_cpu_data();
-  for (int batch_item_id = 0; batch_item_id < num; ++batch_item_id) {
-    // Apply data transformations (mirror, scale, crop...)
-    this->data_transformer_.Transform(
-        batch_item_id, datum_vector[batch_item_id], top_data);
-    top_label[batch_item_id] = datum_vector[batch_item_id].label();
+  for (int item_id = 0; item_id < num; ++item_id) {
+    top_label[item_id] = datum_vector[item_id].label();
   }
   // num_images == batch_size_
+  Dtype* top_data = added_data_.mutable_cpu_data();
   Reset(top_data, top_label, batch_size_);
   has_new_data_ = true;
 }
index 6944ae8..8f88177 100644 (file)
@@ -520,6 +520,8 @@ message ImageDataParameter {
   // It will also resize images if new_height or new_width are not zero.
   optional uint32 new_height = 9 [default = 0];
   optional uint32 new_width = 10 [default = 0];
+  // Specify if the images are color or gray
+  optional bool is_color = 11 [default = true];
   // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
   // simple scaling and subtracting the data mean, if provided. Note that the
   // mean subtraction is always carried out before scaling.
index 566d06a..76829f5 100644 (file)
@@ -55,6 +55,28 @@ void Timer::Stop() {
   }
 }
 
+
+float Timer::MicroSeconds() {
+  if (!has_run_at_least_once()) {
+    LOG(WARNING) << "Timer has never been run before reading time.";
+    return 0;
+  }
+  if (running()) {
+    Stop();
+  }
+  if (Caffe::mode() == Caffe::GPU) {
+#ifndef CPU_ONLY
+    CUDA_CHECK(cudaEventElapsedTime(&elapsed_microseconds_, start_gpu_,
+                                    stop_gpu_));
+#else
+      NO_GPU;
+#endif
+  } else {
+    elapsed_microseconds_ = (stop_cpu_ - start_cpu_).total_microseconds();
+  }
+  return elapsed_microseconds_;
+}
+
 float Timer::MilliSeconds() {
   if (!has_run_at_least_once()) {
     LOG(WARNING) << "Timer has never been run before reading time.";
index 43e5c01..4c32979 100644 (file)
@@ -66,6 +66,23 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
   CHECK(proto.SerializeToOstream(&output));
 }
 
+cv::Mat ReadImageToCVMat(const string& filename,
+    const int height, const int width, const bool is_color) {
+  cv::Mat cv_img;
+  int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
+    CV_LOAD_IMAGE_GRAYSCALE);
+  if (height > 0 && width > 0) {
+    cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);
+    cv::resize(cv_img_origin, cv_img, cv::Size(width, height));
+  } else {
+    cv_img = cv::imread(filename, cv_read_flag);
+  }
+  if (!cv_img.data) {
+    LOG(ERROR) << "Could not open or find file " << filename;
+  }
+  return cv_img;
+}
+
 bool ReadImageToDatum(const string& filename, const int label,
     const int height, const int width, const bool is_color, Datum* datum) {
   cv::Mat cv_img;