Added an iscolor flag to io.cpp method ReadImageToDatum to handle grayscale images...
authorJerod Weinman <jerod@acm.org>
Mon, 2 Jun 2014 20:54:25 +0000 (15:54 -0500)
committerJerod Weinman <jerod@acm.org>
Mon, 2 Jun 2014 20:54:25 +0000 (15:54 -0500)
include/caffe/util/io.hpp
src/caffe/util/io.cpp
tools/convert_imageset.cpp

index 056b573..829293e 100644 (file)
@@ -61,13 +61,19 @@ inline void WriteProtoToBinaryFile(
 }
 
 bool ReadImageToDatum(const string& filename, const int label,
-    const int height, const int width, Datum* datum);
+    const int height, const int width, const bool iscolor, Datum* datum);
+
+inline bool ReadImageToDatum(const string& filename, const int label,
+    const int height, const int width, Datum* datum) {
+  return ReadImageToDatum(filename, label, height, width, true, datum);
+}
 
 inline bool ReadImageToDatum(const string& filename, const int label,
     Datum* datum) {
   return ReadImageToDatum(filename, label, 0, 0, datum);
 }
 
+
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
   hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
index 44858f4..f2650e9 100644 (file)
@@ -72,32 +72,43 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
 }
 
 bool ReadImageToDatum(const string& filename, const int label,
-    const int height, const int width, Datum* datum) {
+    const int height, const int width, const bool iscolor, Datum* datum) {
   cv::Mat cv_img;
+  int cv_read_flag = (iscolor ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE);
   if (height > 0 && width > 0) {
-    cv::Mat cv_img_origin = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+    cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);
     cv::resize(cv_img_origin, cv_img, cv::Size(height, width));
   } else {
-    cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
+    cv_img = cv::imread(filename, cv_read_flag);
   }
   if (!cv_img.data) {
     LOG(ERROR) << "Could not open or find file " << filename;
     return false;
   }
-  datum->set_channels(3);
+  int num_channels = (iscolor ? 3 : 1);
+  datum->set_channels(num_channels);
   datum->set_height(cv_img.rows);
   datum->set_width(cv_img.cols);
   datum->set_label(label);
   datum->clear_data();
   datum->clear_float_data();
   string* datum_string = datum->mutable_data();
-  for (int c = 0; c < 3; ++c) {
+  if (iscolor) {
+    for (int c = 0; c < num_channels; ++c) {
+      for (int h = 0; h < cv_img.rows; ++h) {
+        for (int w = 0; w < cv_img.cols; ++w) {
+          datum_string->push_back(
+            static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
+        }
+      }
+    }
+  } else {  // Faster than repeatedly testing iscolor for each pixel w/i loop
     for (int h = 0; h < cv_img.rows; ++h) {
       for (int w = 0; w < cv_img.cols; ++w) {
         datum_string->push_back(
-            static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
+          static_cast<char>(cv_img.at<uchar>(h, w)));
+        }
       }
-    }
   }
   return true;
 }
index 2420f95..bbf848a 100644 (file)
@@ -2,7 +2,7 @@
 // This program converts a set of images to a leveldb by storing them as Datum
 // proto buffers.
 // Usage:
-//    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME RANDOM_SHUFFLE[0 or 1] \
+//   convert_imageset [-g] ROOTFOLDER/ LISTFILE DB_NAME RANDOM_SHUFFLE[0 or 1] \
 //                     [resize_height] [resize_width]
 // where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
 // should be a list of files as well as their labels, in the format as
@@ -10,6 +10,9 @@
 //   ....
 // if RANDOM_SHUFFLE is 1, a random shuffle will be carried out before we
 // process the file lines.
+// Optional flag -g indicates the images should be read as
+// single-channel grayscale. If omitted, grayscale images will be
+// converted to color.
 
 #include <glog/logging.h>
 #include <leveldb/db.h>
@@ -30,24 +33,28 @@ using std::string;
 
 int main(int argc, char** argv) {
   ::google::InitGoogleLogging(argv[0]);
-  if (argc < 4 || argc > 7) {
+  if (argc < 4 || argc > 8) {
     printf("Convert a set of images to the leveldb format used\n"
         "as input for Caffe.\n"
         "Usage:\n"
-        "    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME"
+        "    convert_imageset [-g] ROOTFOLDER/ LISTFILE DB_NAME"
         " RANDOM_SHUFFLE_DATA[0 or 1] [resize_height] [resize_width]\n"
         "The ImageNet dataset for the training demo is at\n"
         "    http://www.image-net.org/download-images\n");
     return 1;
   }
-  std::ifstream infile(argv[2]);
+
+  // Test whether argv[1] == "-g"
+  bool iscolor= !(string("-g") == string(argv[1]));
+  int  arg_offset = (iscolor ? 0 : 1);
+  std::ifstream infile(argv[arg_offset+2]);
   std::vector<std::pair<string, int> > lines;
   string filename;
   int label;
   while (infile >> filename >> label) {
     lines.push_back(std::make_pair(filename, label));
   }
-  if (argc >= 5 && argv[4][0] == '1') {
+  if (argc >= (arg_offset+5) && argv[arg_offset+4][0] == '1') {
     // randomly shuffle data
     LOG(INFO) << "Shuffling data";
     std::random_shuffle(lines.begin(), lines.end());
@@ -55,11 +62,11 @@ int main(int argc, char** argv) {
   LOG(INFO) << "A total of " << lines.size() << " images.";
   int resize_height = 0;
   int resize_width = 0;
-  if (argc >= 6) {
-    resize_height = atoi(argv[5]);
+  if (argc >= (arg_offset+6)) {
+    resize_height = atoi(argv[arg_offset+5]);
   }
-  if (argc >= 7) {
-    resize_width = atoi(argv[6]);
+  if (argc >= (arg_offset+7)) {
+    resize_width = atoi(argv[arg_offset+6]);
   }
 
   leveldb::DB* db;
@@ -67,12 +74,12 @@ int main(int argc, char** argv) {
   options.error_if_exists = true;
   options.create_if_missing = true;
   options.write_buffer_size = 268435456;
-  LOG(INFO) << "Opening leveldb " << argv[3];
+  LOG(INFO) << "Opening leveldb " << argv[arg_offset+3];
   leveldb::Status status = leveldb::DB::Open(
-      options, argv[3], &db);
-  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];
+      options, argv[arg_offset+3], &db);
+  CHECK(status.ok()) << "Failed to open leveldb " << argv[arg_offset+3];
 
-  string root_folder(argv[1]);
+  string root_folder(argv[arg_offset+1]);
   Datum datum;
   int count = 0;
   const int kMaxKeyLength = 256;
@@ -82,7 +89,7 @@ int main(int argc, char** argv) {
   bool data_size_initialized = false;
   for (int line_id = 0; line_id < lines.size(); ++line_id) {
     if (!ReadImageToDatum(root_folder + lines[line_id].first,
-        lines[line_id].second, resize_height, resize_width, &datum)) {
+         lines[line_id].second, resize_height, resize_width, iscolor, &datum)) {
       continue;
     }
     if (!data_size_initialized) {