From da0f89bbd495c604f63744374165cdc9f3cbba9f Mon Sep 17 00:00:00 2001 From: Kai Li Date: Tue, 17 Jun 2014 14:53:00 +0800 Subject: [PATCH] Extract multiple features in a single Forward pass --- tools/extract_features.cpp | 123 ++++++++++++++++++++++++++++----------------- 1 file changed, 76 insertions(+), 47 deletions(-) diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index cdad667..99eb0db 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -32,9 +33,14 @@ int feature_extraction_pipeline(int argc, char** argv) { LOG(ERROR)<< "This program takes in a trained network and an input data layer, and then" " extract features of the input data produced by the net.\n" - "Usage: demo_extract_features pretrained_net_param" - " feature_extraction_proto_file extract_feature_blob_name" - " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]"; + "Usage: extract_features pretrained_net_param" + " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]" + " save_feature_leveldb_name1[,name2,...] num_mini_batches [CPU/GPU]" + " [DEVICE_ID=0]\n" + "Note: you can extract multiple features in one pass by specifying" + " multiple feature blob names and leveldb names seperated by ','." + " The names cannot contain white space characters and the number of blobs" + " and leveldbs must be equal."; return 1; } int arg_pos = num_required_args; @@ -91,29 +97,46 @@ int feature_extraction_pipeline(int argc, char** argv) { new Net(feature_extraction_proto)); feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); - string extract_feature_blob_name(argv[++arg_pos]); - CHECK(feature_extraction_net->has_blob(extract_feature_blob_name)) - << "Unknown feature blob name " << extract_feature_blob_name - << " in the network " << feature_extraction_proto; + string extract_feature_blob_names(argv[++arg_pos]); + vector blob_names; + boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); + + string save_feature_leveldb_names(argv[++arg_pos]); + vector leveldb_names; + boost::split(leveldb_names, save_feature_leveldb_names, + boost::is_any_of(",")); + CHECK_EQ(blob_names.size(), leveldb_names.size()) << + " the number of blob names and leveldb names must be equal"; + size_t num_features = blob_names.size(); + + for (size_t i = 0; i < num_features; i++) { + CHECK(feature_extraction_net->has_blob(blob_names[i])) + << "Unknown feature blob name " << blob_names[i] + << " in the network " << feature_extraction_proto; + } - string save_feature_leveldb_name(argv[++arg_pos]); - leveldb::DB* db; leveldb::Options options; options.error_if_exists = true; options.create_if_missing = true; options.write_buffer_size = 268435456; - LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name; - leveldb::Status status = leveldb::DB::Open(options, - save_feature_leveldb_name.c_str(), - &db); - CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name; + vector feature_dbs; + for (size_t i = 0; i < num_features; ++i) { + LOG(INFO)<< "Opening leveldb " << leveldb_names[i]; + leveldb::DB* db; + leveldb::Status status = leveldb::DB::Open(options, + leveldb_names[i].c_str(), + &db); + CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i]; + feature_dbs.push_back(db); + } int num_mini_batches = atoi(argv[++arg_pos]); LOG(ERROR)<< "Extacting Features"; Datum datum; - leveldb::WriteBatch* batch = new leveldb::WriteBatch(); + vector feature_batches( + num_features, new leveldb::WriteBatch()); const int kMaxKeyStrLength = 100; char key_str[kMaxKeyStrLength]; int num_bytes_of_binary_code = sizeof(Dtype); @@ -121,45 +144,51 @@ int feature_extraction_pipeline(int argc, char** argv) { int image_index = 0; for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) { feature_extraction_net->Forward(input_vec); - const shared_ptr > feature_blob = feature_extraction_net - ->blob_by_name(extract_feature_blob_name); - int num_features = feature_blob->num(); - int dim_features = feature_blob->count() / num_features; - Dtype* feature_blob_data; - for (int n = 0; n < num_features; ++n) { - datum.set_height(dim_features); - datum.set_width(1); - datum.set_channels(1); - datum.clear_data(); - datum.clear_float_data(); - feature_blob_data = feature_blob->mutable_cpu_data() + - feature_blob->offset(n); - for (int d = 0; d < dim_features; ++d) { - datum.add_float_data(feature_blob_data[d]); - } - string value; - datum.SerializeToString(&value); - snprintf(key_str, kMaxKeyStrLength, "%d", image_index); - batch->Put(string(key_str), value); - ++image_index; - if (image_index % 1000 == 0) { - db->Write(leveldb::WriteOptions(), batch); - LOG(ERROR)<< "Extracted features of " << image_index << - " query images."; - delete batch; - batch = new leveldb::WriteBatch(); - } - } // for (int n = 0; n < num_features; ++n) + for (int i = 0; i < num_features; ++i) { + const shared_ptr > feature_blob = feature_extraction_net + ->blob_by_name(blob_names[i]); + int batch_size = feature_blob->num(); + int dim_features = feature_blob->count() / batch_size; + Dtype* feature_blob_data; + for (int n = 0; n < batch_size; ++n) { + datum.set_height(dim_features); + datum.set_width(1); + datum.set_channels(1); + datum.clear_data(); + datum.clear_float_data(); + feature_blob_data = feature_blob->mutable_cpu_data() + + feature_blob->offset(n); + for (int d = 0; d < dim_features; ++d) { + datum.add_float_data(feature_blob_data[d]); + } + string value; + datum.SerializeToString(&value); + snprintf(key_str, kMaxKeyStrLength, "%d", image_index); + feature_batches[i]->Put(string(key_str), value); + ++image_index; + if (image_index % 1000 == 0) { + feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i]); + LOG(ERROR)<< "Extracted features of " << image_index << + " query images."; + delete feature_batches[i]; + feature_batches[i] = new leveldb::WriteBatch(); + } + } // for (int n = 0; n < batch_size; ++n) + } // for (int i = 0; i < num_features; ++i) } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) // write the last batch if (image_index % 1000 != 0) { - db->Write(leveldb::WriteOptions(), batch); + for (int i = 0; i < num_features; ++i) { + feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i]); + } LOG(ERROR)<< "Extracted features of " << image_index << " query images."; } - delete batch; - delete db; + for (int i = 0; i < num_features; ++i) { + delete feature_batches[i]; + delete feature_dbs[i]; + } LOG(ERROR)<< "Successfully extracted the features!"; return 0; } -- 2.7.4