Upgrade compute_image_mean to use gflags, accept list_of_images, and print mean_values
authorSergio <sguada@gmail.com>
Wed, 15 Oct 2014 22:36:42 +0000 (15:36 -0700)
committerSergio <sguada@gmail.com>
Thu, 16 Oct 2014 00:03:18 +0000 (17:03 -0700)
examples/cifar10/create_cifar10.sh
src/caffe/layers/data_layer.cpp
tools/compute_image_mean.cpp

index ebfca46..a42725c 100755 (executable)
@@ -13,7 +13,7 @@ rm -rf $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/cifar10_test_$DBTYPE
 
 echo "Computing image mean..."
 
-./build/tools/compute_image_mean $EXAMPLE/cifar10_train_$DBTYPE \
-  $EXAMPLE/mean.binaryproto $DBTYPE
+./build/tools/compute_image_mean -backend=$DBTYPE \
+  $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/mean.binaryproto
 
 echo "Done."
index 7ddd466..05f2e29 100644 (file)
@@ -46,9 +46,9 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   }
   // Read a data point, and use it to initialize the top blob.
   CHECK(iter_ != dataset_->end());
-  const Datum& datum = iter_->value;
+  Datum datum = iter_->value;
 
-  if (DecodeDatum(datum)) {
+  if (DecodeDatum(&datum)) {
     LOG(INFO) << "Decoding Datum";
   }
   // image
index a720f16..4c2c658 100644 (file)
@@ -1,3 +1,4 @@
+#include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <stdint.h>
 
@@ -12,19 +13,31 @@ using caffe::Dataset;
 using caffe::Datum;
 using caffe::BlobProto;
 using std::max;
+using std::pair;
+
+
+DEFINE_string(backend, "lmdb", "The backend for containing the images");
 
 int main(int argc, char** argv) {
   ::google::InitGoogleLogging(argv[0]);
-  if (argc < 3 || argc > 4) {
-    LOG(ERROR) << "Usage: compute_image_mean input_db output_file"
-               << " db_backend[leveldb or lmdb]";
+
+#ifndef GFLAGS_GFLAGS_H_
+  namespace gflags = google;
+#endif
+
+  gflags::SetUsageMessage("Compute the mean_image of a set of images given by"
+        " a leveldb/lmdb or a list of images\n"
+        "Usage:\n"
+        "    compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n");
+
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  if (argc < 2 || argc > 3) {
+    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean");
     return 1;
   }
 
-  std::string db_backend = "lmdb";
-  if (argc == 4) {
-    db_backend = std::string(argv[3]);
-  }
+  std::string db_backend = FLAGS_backend;
 
   caffe::shared_ptr<Dataset<std::string, Datum> > dataset =
       caffe::DatasetFactory<std::string, Datum>(db_backend);
@@ -36,7 +49,11 @@ int main(int argc, char** argv) {
   int count = 0;
   // load first datum
   Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
-  const Datum& datum = iter->value;
+  Datum datum = iter->value;
+
+  if (DecodeDatum(&datum)) {
+    LOG(INFO) << "Decoding Datum";
+  }
 
   sum_blob.set_num(1);
   sum_blob.set_channels(datum.channels());
@@ -51,8 +68,9 @@ int main(int argc, char** argv) {
   LOG(INFO) << "Starting Iteration";
   for (Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
       iter != dataset->end(); ++iter) {
-    // just a dummy operation
-    const Datum& datum = iter->value;
+    Datum datum = iter->value;
+    DecodeDatum(&datum);
+
     const std::string& data = datum.data();
     size_in_datum = std::max<int>(datum.data().size(),
         datum.float_data_size());
@@ -72,20 +90,31 @@ int main(int argc, char** argv) {
     }
     ++count;
     if (count % 10000 == 0) {
-      LOG(ERROR) << "Processed " << count << " files.";
+      LOG(INFO) << "Processed " << count << " files.";
     }
   }
 
   if (count % 10000 != 0) {
-    LOG(ERROR) << "Processed " << count << " files.";
+    LOG(INFO) << "Processed " << count << " files.";
   }
   for (int i = 0; i < sum_blob.data_size(); ++i) {
     sum_blob.set_data(i, sum_blob.data(i) / count);
   }
   // Write to disk
-  LOG(INFO) << "Write to " << argv[2];
-  WriteProtoToBinaryFile(sum_blob, argv[2]);
-
+  if (argc == 3) {
+    LOG(INFO) << "Write to " << argv[2];
+    WriteProtoToBinaryFile(sum_blob, argv[2]);
+  }
+  const int channels = sum_blob.channels();
+  const int dim = sum_blob.height() * sum_blob.width();
+  std::vector<float> mean_values(channels, 0.0);
+  LOG(INFO) << "Number of channels: " << channels;
+  for (int c = 0; c < channels; ++c) {
+    for (int i = 0; i < dim; ++i) {
+      mean_values[c] += sum_blob.data(dim * c + i);
+    }
+    LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c] / dim;
+  }
   // Clean up
   dataset->close();
   return 0;