#include <google/protobuf/text_format.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>
+#include <boost/algorithm/string.hpp>
#include <string>
#include <vector>
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net.\n"
- "Usage: demo_extract_features pretrained_net_param"
- " feature_extraction_proto_file extract_feature_blob_name"
- " save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
+ "Usage: extract_features pretrained_net_param"
+ " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]"
+ " save_feature_leveldb_name1[,name2,...] num_mini_batches [CPU/GPU]"
+ " [DEVICE_ID=0]\n"
+ // Blob names and leveldb names are paired by position; their counts are
+ // verified to match with CHECK_EQ after both lists are split on ','.
+ "Note: you can extract multiple features in one pass by specifying"
+ " multiple feature blob names and leveldb names separated by ','."
+ " The names cannot contain white space characters and the number of blobs"
+ " and leveldbs must be equal.";
return 1;
}
int arg_pos = num_required_args;
new Net<Dtype>(feature_extraction_proto));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
- string extract_feature_blob_name(argv[++arg_pos]);
- CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
- << "Unknown feature blob name " << extract_feature_blob_name
- << " in the network " << feature_extraction_proto;
+ string extract_feature_blob_names(argv[++arg_pos]);
+ vector<string> blob_names;
+ boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
+
+ string save_feature_leveldb_names(argv[++arg_pos]);
+ vector<string> leveldb_names;
+ boost::split(leveldb_names, save_feature_leveldb_names,
+ boost::is_any_of(","));
+ CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
+ " the number of blob names and leveldb names must be equal";
+ size_t num_features = blob_names.size();
+
+ for (size_t i = 0; i < num_features; i++) {
+ CHECK(feature_extraction_net->has_blob(blob_names[i]))
+ << "Unknown feature blob name " << blob_names[i]
+ << " in the network " << feature_extraction_proto;
+ }
- string save_feature_leveldb_name(argv[++arg_pos]);
- leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456;
- LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
- leveldb::Status status = leveldb::DB::Open(options,
- save_feature_leveldb_name.c_str(),
- &db);
- CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+ vector<leveldb::DB*> feature_dbs;
+ for (size_t i = 0; i < num_features; ++i) {
+ LOG(INFO)<< "Opening leveldb " << leveldb_names[i];
+ leveldb::DB* db;
+ leveldb::Status status = leveldb::DB::Open(options,
+ leveldb_names[i].c_str(),
+ &db);
+ CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
+ feature_dbs.push_back(db);
+ }
int num_mini_batches = atoi(argv[++arg_pos]);
LOG(ERROR)<< "Extacting Features";
Datum datum;
- leveldb::WriteBatch* batch = new leveldb::WriteBatch();
+ vector<leveldb::WriteBatch*> feature_batches(
+ num_features, new leveldb::WriteBatch());
const int kMaxKeyStrLength = 100;
char key_str[kMaxKeyStrLength];
int num_bytes_of_binary_code = sizeof(Dtype);
int image_index = 0;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
- const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
- ->blob_by_name(extract_feature_blob_name);
- int num_features = feature_blob->num();
- int dim_features = feature_blob->count() / num_features;
- Dtype* feature_blob_data;
- for (int n = 0; n < num_features; ++n) {
- datum.set_height(dim_features);
- datum.set_width(1);
- datum.set_channels(1);
- datum.clear_data();
- datum.clear_float_data();
- feature_blob_data = feature_blob->mutable_cpu_data() +
- feature_blob->offset(n);
- for (int d = 0; d < dim_features; ++d) {
- datum.add_float_data(feature_blob_data[d]);
- }
- string value;
- datum.SerializeToString(&value);
- snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
- batch->Put(string(key_str), value);
- ++image_index;
- if (image_index % 1000 == 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR)<< "Extracted features of " << image_index <<
- " query images.";
- delete batch;
- batch = new leveldb::WriteBatch();
- }
- } // for (int n = 0; n < num_features; ++n)
+ for (int i = 0; i < num_features; ++i) {
+ const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
+ ->blob_by_name(blob_names[i]);
+ int batch_size = feature_blob->num();
+ int dim_features = feature_blob->count() / batch_size;
+ Dtype* feature_blob_data;
+ for (int n = 0; n < batch_size; ++n) {
+ datum.set_height(dim_features);
+ datum.set_width(1);
+ datum.set_channels(1);
+ datum.clear_data();
+ datum.clear_float_data();
+ feature_blob_data = feature_blob->mutable_cpu_data() +
+ feature_blob->offset(n);
+ for (int d = 0; d < dim_features; ++d) {
+ datum.add_float_data(feature_blob_data[d]);
+ }
+ string value;
+ datum.SerializeToString(&value);
+ snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
+ feature_batches[i]->Put(string(key_str), value);
+ ++image_index;
+ if (image_index % 1000 == 0) {
+ feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i]);
+ LOG(ERROR)<< "Extracted features of " << image_index <<
+ " query images.";
+ delete feature_batches[i];
+ feature_batches[i] = new leveldb::WriteBatch();
+ }
+ } // for (int n = 0; n < batch_size; ++n)
+ } // for (int i = 0; i < num_features; ++i)
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
if (image_index % 1000 != 0) {
- db->Write(leveldb::WriteOptions(), batch);
+ for (int i = 0; i < num_features; ++i) {
+ feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i]);
+ }
LOG(ERROR)<< "Extracted features of " << image_index <<
" query images.";
}
- delete batch;
- delete db;
+ for (int i = 0; i < num_features; ++i) {
+ delete feature_batches[i];
+ delete feature_dbs[i];
+ }
LOG(ERROR)<< "Successfully extracted the features!";
return 0;
}