Updated extract_features to take a leveldb/lmdb config option.
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Sun, 12 Oct 2014 18:15:17 +0000 (14:15 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:31:30 +0000 (19:31 -0400)
examples/feature_extraction/readme.md
tools/extract_features.cpp

index c325ed4..6c8917e 100644 (file)
@@ -51,7 +51,7 @@ Extract Features
 
 Now everything necessary is in place.
 
-    ./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10
+    ./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 lmdb
 
 The name of feature blob that you extract is `fc7`, which represents the highest level feature of the reference model.
 We can use any other layer, as well, such as `conv5` or `pool3`.
index 1065d44..c4d1a39 100644 (file)
@@ -32,15 +32,15 @@ int main(int argc, char** argv) {
 template<typename Dtype>
 int feature_extraction_pipeline(int argc, char** argv) {
   ::google::InitGoogleLogging(argv[0]);
-  const int num_required_args = 6;
+  const int num_required_args = 7;
   if (argc < num_required_args) {
     LOG(ERROR)<<
     "This program takes in a trained network and an input data layer, and then"
     " extract features of the input data produced by the net.\n"
     "Usage: extract_features  pretrained_net_param"
     "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
-    "  save_feature_database_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
-    "  [DEVICE_ID=0]\n"
+    "  save_feature_database_name1[,name2,...]  num_mini_batches  db_type"
+    "  [CPU/GPU] [DEVICE_ID=0]\n"
     "Note: you can extract multiple features in one pass by specifying"
     " multiple feature blob names and database names seperated by ','."
     " The names cannot contain white space characters and the number of blobs"
@@ -119,16 +119,16 @@ int feature_extraction_pipeline(int argc, char** argv) {
         << " in the network " << feature_extraction_proto;
   }
 
+  int num_mini_batches = atoi(argv[++arg_pos]);
+
   std::vector<shared_ptr<Database> > feature_dbs;
   for (size_t i = 0; i < num_features; ++i) {
     LOG(INFO)<< "Opening database " << database_names[i];
-    shared_ptr<Database> database = DatabaseFactory("leveldb");
+    shared_ptr<Database> database = DatabaseFactory(argv[++arg_pos]);
     database->open(database_names.at(i), Database::New);
     feature_dbs.push_back(database);
   }
 
-  int num_mini_batches = atoi(argv[++arg_pos]);
-
   LOG(ERROR)<< "Extacting Features";
 
   Datum datum;