From a1ea5ba4b2a6e1af9ef360aefd0c8ab777039a5b Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Sun, 12 Oct 2014 14:15:17 -0400 Subject: [PATCH] Updated extract_features to take a leveldb/lmdb config option. --- examples/feature_extraction/readme.md | 2 +- tools/extract_features.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/feature_extraction/readme.md b/examples/feature_extraction/readme.md index c325ed4..6c8917e 100644 --- a/examples/feature_extraction/readme.md +++ b/examples/feature_extraction/readme.md @@ -51,7 +51,7 @@ Extract Features Now everything necessary is in place. - ./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 + ./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 lmdb The name of feature blob that you extract is `fc7`, which represents the highest level feature of the reference model. We can use any other layer, as well, such as `conv5` or `pool3`. diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 1065d44..c4d1a39 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -32,15 +32,15 @@ int main(int argc, char** argv) { template int feature_extraction_pipeline(int argc, char** argv) { ::google::InitGoogleLogging(argv[0]); - const int num_required_args = 6; + const int num_required_args = 7; if (argc < num_required_args) { LOG(ERROR)<< "This program takes in a trained network and an input data layer, and then" " extract features of the input data produced by the net.\n" "Usage: extract_features pretrained_net_param" " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]" - " save_feature_database_name1[,name2,...] num_mini_batches [CPU/GPU]" - " [DEVICE_ID=0]\n" + " save_feature_database_name1[,name2,...] num_mini_batches db_type" + " [CPU/GPU] [DEVICE_ID=0]\n" "Note: you can extract multiple features in one pass by specifying" " multiple feature blob names and database names seperated by ','." " The names cannot contain white space characters and the number of blobs" @@ -119,16 +119,16 @@ int feature_extraction_pipeline(int argc, char** argv) { << " in the network " << feature_extraction_proto; } + int num_mini_batches = atoi(argv[++arg_pos]); + std::vector > feature_dbs; for (size_t i = 0; i < num_features; ++i) { LOG(INFO)<< "Opening database " << database_names[i]; - shared_ptr database = DatabaseFactory("leveldb"); + shared_ptr database = DatabaseFactory(argv[++arg_pos]); database->open(database_names.at(i), Database::New); feature_dbs.push_back(database); } - int num_mini_batches = atoi(argv[++arg_pos]); - LOG(ERROR)<< "Extacting Features"; Datum datum; -- 2.7.4