added a force_encoded_color flag to the data layer. Printing a warning if images...
authorphilkr <philkr@users.noreply.github.com>
Thu, 19 Feb 2015 21:46:21 +0000 (13:46 -0800)
committerphilkr <philkr@users.noreply.github.com>
Fri, 20 Feb 2015 01:27:25 +0000 (17:27 -0800)
examples/images/cat_gray.jpg [new file with mode: 0644]
include/caffe/util/io.hpp
src/caffe/layers/data_layer.cpp
src/caffe/layers/window_data_layer.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_io.cpp
src/caffe/util/io.cpp

diff --git a/examples/images/cat_gray.jpg b/examples/images/cat_gray.jpg
new file mode 100644 (file)
index 0000000..43c5ce3
Binary files /dev/null and b/examples/images/cat_gray.jpg differ
index 9d7540d..3a62c3c 100644 (file)
@@ -122,6 +122,7 @@ inline bool ReadImageToDatum(const string& filename, const int label,
 }
 
 bool DecodeDatumNative(Datum* datum);
+bool DecodeDatum(Datum* datum, bool is_color);
 
 cv::Mat ReadImageToCVMat(const string& filename,
     const int height, const int width, const bool is_color);
@@ -135,6 +136,7 @@ cv::Mat ReadImageToCVMat(const string& filename,
 cv::Mat ReadImageToCVMat(const string& filename);
 
 cv::Mat DecodeDatumToCVMatNative(const Datum& datum);
+cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
 
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
 
index 891d039..7716406 100644 (file)
@@ -42,7 +42,9 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   Datum datum;
   datum.ParseFromString(cursor_->value());
 
-  if (DecodeDatumNative(&datum)) {
+  bool force_color = this->layer_param_.data_param().force_encoded_color();
+  if ((force_color && DecodeDatum(&datum, true)) ||
+      DecodeDatumNative(&datum)) {
     LOG(INFO) << "Decoding Datum";
   }
   // image
@@ -90,6 +92,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     top_label = this->prefetch_label_.mutable_cpu_data();
   }
   const int batch_size = this->layer_param_.data_param().batch_size();
+  bool force_color = this->layer_param_.data_param().force_encoded_color();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     timer.Start();
     // get a blob
@@ -98,7 +101,15 @@ void DataLayer<Dtype>::InternalThreadEntry() {
 
     cv::Mat cv_img;
     if (datum.encoded()) {
-       cv_img = DecodeDatumToCVMatNative(datum);
+      if (force_color)
+        cv_img = DecodeDatumToCVMat(datum, true);
+      else
+        cv_img = DecodeDatumToCVMatNative(datum);
+      if (cv_img.channels() != this->transformed_data_.channels())
+        LOG(WARNING) << "Your dataset contains encoded images with mixed "
+        << "channel sizes. Consider adding a 'force_color' flag to the "
+        << "model definition, or rebuild your dataset using "
+        << "convert_imageset.";
     }
     read_time += timer.MicroSeconds();
     timer.Start();
index cceb4ff..73408c6 100644 (file)
@@ -281,7 +281,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
       if (this->cache_images_) {
         pair<std::string, Datum> image_cached =
           image_database_cache_[window[WindowDataLayer<Dtype>::IMAGE_INDEX]];
-        cv_img = DecodeDatumToCVMatNative(image_cached.second);
+        cv_img = DecodeDatumToCVMat(image_cached.second, true);
       } else {
         cv_img = cv::imread(image.first, CV_LOAD_IMAGE_COLOR);
         if (!cv_img.data) {
index 8d93742..8ba6075 100644 (file)
@@ -426,6 +426,8 @@ message DataParameter {
   // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
   // data.
   optional bool mirror = 6 [default = false];
+  // Force the encoded image to have 3 color channels
+  optional bool force_encoded_color = 9 [default = false];
 }
 
 // Message that stores parameters used by DropoutLayer
index 6b135ef..4ab9631 100644 (file)
@@ -289,6 +289,60 @@ TEST_F(IOTest, TestDecodeDatum) {
   string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
   Datum datum;
   EXPECT_TRUE(ReadFileToDatum(filename, &datum));
+  EXPECT_TRUE(DecodeDatum(&datum, true));
+  EXPECT_FALSE(DecodeDatum(&datum, true));
+  Datum datum_ref;
+  ReadImageToDatumReference(filename, 0, 0, 0, true, &datum_ref);
+  EXPECT_EQ(datum.channels(), datum_ref.channels());
+  EXPECT_EQ(datum.height(), datum_ref.height());
+  EXPECT_EQ(datum.width(), datum_ref.width());
+  EXPECT_EQ(datum.data().size(), datum_ref.data().size());
+
+  const string& data = datum.data();
+  const string& data_ref = datum_ref.data();
+  for (int i = 0; i < datum.data().size(); ++i) {
+    EXPECT_TRUE(data[i] == data_ref[i]);
+  }
+}
+
+TEST_F(IOTest, TestDecodeDatumToCVMat) {
+  string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
+  Datum datum;
+  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
+  cv::Mat cv_img = DecodeDatumToCVMat(datum, true);
+  EXPECT_EQ(cv_img.channels(), 3);
+  EXPECT_EQ(cv_img.rows, 360);
+  EXPECT_EQ(cv_img.cols, 480);
+  cv_img = DecodeDatumToCVMat(datum, false);
+  EXPECT_EQ(cv_img.channels(), 1);
+  EXPECT_EQ(cv_img.rows, 360);
+  EXPECT_EQ(cv_img.cols, 480);
+}
+
+TEST_F(IOTest, TestDecodeDatumToCVMatContent) {
+  string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
+  Datum datum;
+  EXPECT_TRUE(ReadImageToDatum(filename, 0, std::string("jpg"), &datum));
+  cv::Mat cv_img = DecodeDatumToCVMat(datum, true);
+  cv::Mat cv_img_ref = ReadImageToCVMat(filename);
+  EXPECT_EQ(cv_img_ref.channels(), cv_img.channels());
+  EXPECT_EQ(cv_img_ref.rows, cv_img.rows);
+  EXPECT_EQ(cv_img_ref.cols, cv_img.cols);
+
+  for (int c = 0; c < datum.channels(); ++c) {
+    for (int h = 0; h < datum.height(); ++h) {
+      for (int w = 0; w < datum.width(); ++w) {
+        EXPECT_TRUE(cv_img.at<cv::Vec3b>(h, w)[c]==
+          cv_img_ref.at<cv::Vec3b>(h, w)[c]);
+      }
+    }
+  }
+}
+
+TEST_F(IOTest, TestDecodeDatumNative) {
+  string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
+  Datum datum;
+  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
   EXPECT_TRUE(DecodeDatumNative(&datum));
   EXPECT_FALSE(DecodeDatumNative(&datum));
   Datum datum_ref;
@@ -305,7 +359,7 @@ TEST_F(IOTest, TestDecodeDatum) {
   }
 }
 
-TEST_F(IOTest, TestDecodeDatumToCVMat) {
+TEST_F(IOTest, TestDecodeDatumToCVMatNative) {
   string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
   Datum datum;
   EXPECT_TRUE(ReadFileToDatum(filename, &datum));
@@ -315,7 +369,37 @@ TEST_F(IOTest, TestDecodeDatumToCVMat) {
   EXPECT_EQ(cv_img.cols, 480);
 }
 
-TEST_F(IOTest, TestDecodeDatumToCVMatContent) {
+TEST_F(IOTest, TestDecodeDatumNativeGray) {
+  string filename = EXAMPLES_SOURCE_DIR "images/cat_gray.jpg";
+  Datum datum;
+  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
+  EXPECT_TRUE(DecodeDatumNative(&datum));
+  EXPECT_FALSE(DecodeDatumNative(&datum));
+  Datum datum_ref;
+  ReadImageToDatumReference(filename, 0, 0, 0, false, &datum_ref);
+  EXPECT_EQ(datum.channels(), datum_ref.channels());
+  EXPECT_EQ(datum.height(), datum_ref.height());
+  EXPECT_EQ(datum.width(), datum_ref.width());
+  EXPECT_EQ(datum.data().size(), datum_ref.data().size());
+
+  const string& data = datum.data();
+  const string& data_ref = datum_ref.data();
+  for (int i = 0; i < datum.data().size(); ++i) {
+    EXPECT_TRUE(data[i] == data_ref[i]);
+  }
+}
+
+TEST_F(IOTest, TestDecodeDatumToCVMatNativeGray) {
+  string filename = EXAMPLES_SOURCE_DIR "images/cat_gray.jpg";
+  Datum datum;
+  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
+  cv::Mat cv_img = DecodeDatumToCVMatNative(datum);
+  EXPECT_EQ(cv_img.channels(), 1);
+  EXPECT_EQ(cv_img.rows, 360);
+  EXPECT_EQ(cv_img.cols, 480);
+}
+
+TEST_F(IOTest, TestDecodeDatumToCVMatContentNative) {
   string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
   Datum datum;
   EXPECT_TRUE(ReadImageToDatum(filename, 0, std::string("jpg"), &datum));
index 6553168..b243a98 100644 (file)
@@ -167,9 +167,21 @@ cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {
   }
   return cv_img;
 }
+cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {
+  cv::Mat cv_img;
+  CHECK(datum.encoded()) << "Datum not encoded";
+  const string& data = datum.data();
+  std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());
+  int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
+    CV_LOAD_IMAGE_GRAYSCALE);
+  cv_img = cv::imdecode(vec_data, cv_read_flag);
+  if (!cv_img.data) {
+    LOG(ERROR) << "Could not decode datum ";
+  }
+  return cv_img;
+}
 
 // If Datum is encoded will decoded using DecodeDatumToCVMat and CVMatToDatum
-// if height and width are set it will resize it
 // If Datum is not encoded will do nothing
 bool DecodeDatumNative(Datum* datum) {
   if (datum->encoded()) {
@@ -180,6 +192,15 @@ bool DecodeDatumNative(Datum* datum) {
     return false;
   }
 }
+bool DecodeDatum(Datum* datum, bool is_color) {
+  if (datum->encoded()) {
+    cv::Mat cv_img = DecodeDatumToCVMat((*datum), is_color);
+    CVMatToDatum(cv_img, datum);
+    return true;
+  } else {
+    return false;
+  }
+}
 
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
   CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";