+#include <gflags/gflags.h>
#include <glog/logging.h>
#include <stdint.h>
using caffe::Datum;
using caffe::BlobProto;
using std::max;
+using std::pair;
+
+
+DEFINE_string(backend, "lmdb", "The backend for containing the images");
int main(int argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
- if (argc < 3 || argc > 4) {
- LOG(ERROR) << "Usage: compute_image_mean input_db output_file"
- << " db_backend[leveldb or lmdb]";
+
+#ifndef GFLAGS_GFLAGS_H_
+ namespace gflags = google;
+#endif
+
+ gflags::SetUsageMessage("Compute the mean_image of a set of images given by"
+ " a leveldb/lmdb or a list of images\n"
+ "Usage:\n"
+ " compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n");
+
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ if (argc < 2 || argc > 3) {
+ gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean");
return 1;
}
- std::string db_backend = "lmdb";
- if (argc == 4) {
- db_backend = std::string(argv[3]);
- }
+ std::string db_backend = FLAGS_backend;
caffe::shared_ptr<Dataset<std::string, Datum> > dataset =
caffe::DatasetFactory<std::string, Datum>(db_backend);
int count = 0;
// load first datum
Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
- const Datum& datum = iter->value;
+ Datum datum = iter->value;
+
+ if (DecodeDatum(&datum)) {
+ LOG(INFO) << "Decoding Datum";
+ }
sum_blob.set_num(1);
sum_blob.set_channels(datum.channels());
LOG(INFO) << "Starting Iteration";
for (Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
iter != dataset->end(); ++iter) {
- // just a dummy operation
- const Datum& datum = iter->value;
+ Datum datum = iter->value;
+ DecodeDatum(&datum);
+
const std::string& data = datum.data();
size_in_datum = std::max<int>(datum.data().size(),
datum.float_data_size());
}
++count;
if (count % 10000 == 0) {
- LOG(ERROR) << "Processed " << count << " files.";
+ LOG(INFO) << "Processed " << count << " files.";
}
}
if (count % 10000 != 0) {
- LOG(ERROR) << "Processed " << count << " files.";
+ LOG(INFO) << "Processed " << count << " files.";
}
for (int i = 0; i < sum_blob.data_size(); ++i) {
sum_blob.set_data(i, sum_blob.data(i) / count);
}
// Write to disk
- LOG(INFO) << "Write to " << argv[2];
- WriteProtoToBinaryFile(sum_blob, argv[2]);
-
+ if (argc == 3) {
+ LOG(INFO) << "Write to " << argv[2];
+ WriteProtoToBinaryFile(sum_blob, argv[2]);
+ }
+ const int channels = sum_blob.channels();
+ const int dim = sum_blob.height() * sum_blob.width();
+ std::vector<float> mean_values(channels, 0.0);
+ LOG(INFO) << "Number of channels: " << channels;
+ for (int c = 0; c < channels; ++c) {
+ for (int i = 0; i < dim; ++i) {
+ mean_values[c] += sum_blob.data(dim * c + i);
+ }
+ LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c] / dim;
+ }
// Clean up
dataset->close();
return 0;