Extract multiple features in a single Forward pass
authorKai Li <kaili_kloud@163.com>
Tue, 17 Jun 2014 06:53:00 +0000 (14:53 +0800)
committerKai Li <kaili_kloud@163.com>
Tue, 24 Jun 2014 14:12:44 +0000 (22:12 +0800)
tools/extract_features.cpp

index cdad667..99eb0db 100644 (file)
@@ -5,6 +5,7 @@
 #include <google/protobuf/text_format.h>
 #include <leveldb/db.h>
 #include <leveldb/write_batch.h>
+#include <boost/algorithm/string.hpp>
 #include <string>
 #include <vector>
 
@@ -32,9 +33,14 @@ int feature_extraction_pipeline(int argc, char** argv) {
     LOG(ERROR)<<
     "This program takes in a trained network and an input data layer, and then"
     " extract features of the input data produced by the net.\n"
-    "Usage: demo_extract_features  pretrained_net_param"
-    "  feature_extraction_proto_file  extract_feature_blob_name"
-    "  save_feature_leveldb_name  num_mini_batches  [CPU/GPU]  [DEVICE_ID=0]";
+    "Usage: extract_features  pretrained_net_param"
+    "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
+    "  save_feature_leveldb_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
+    "  [DEVICE_ID=0]\n"
+    "Note: you can extract multiple features in one pass by specifying"
+    " multiple feature blob names and leveldb names seperated by ','."
+    " The names cannot contain white space characters and the number of blobs"
+    " and leveldbs must be equal.";
     return 1;
   }
   int arg_pos = num_required_args;
@@ -91,29 +97,46 @@ int feature_extraction_pipeline(int argc, char** argv) {
       new Net<Dtype>(feature_extraction_proto));
   feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
 
-  string extract_feature_blob_name(argv[++arg_pos]);
-  CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
-      << "Unknown feature blob name " << extract_feature_blob_name
-      << " in the network " << feature_extraction_proto;
+  string extract_feature_blob_names(argv[++arg_pos]);
+  vector<string> blob_names;
+  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
+
+  string save_feature_leveldb_names(argv[++arg_pos]);
+  vector<string> leveldb_names;
+  boost::split(leveldb_names, save_feature_leveldb_names,
+               boost::is_any_of(","));
+  CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
+      " the number of blob names and leveldb names must be equal";
+  size_t num_features = blob_names.size();
+
+  for (size_t i = 0; i < num_features; i++) {
+    CHECK(feature_extraction_net->has_blob(blob_names[i]))
+        << "Unknown feature blob name " << blob_names[i]
+        << " in the network " << feature_extraction_proto;
+  }
 
-  string save_feature_leveldb_name(argv[++arg_pos]);
-  leveldb::DB* db;
   leveldb::Options options;
   options.error_if_exists = true;
   options.create_if_missing = true;
   options.write_buffer_size = 268435456;
-  LOG(INFO)<< "Opening leveldb " << save_feature_leveldb_name;
-  leveldb::Status status = leveldb::DB::Open(options,
-                                             save_feature_leveldb_name.c_str(),
-                                             &db);
-  CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+  vector<leveldb::DB*> feature_dbs;
+  for (size_t i = 0; i < num_features; ++i) {
+    LOG(INFO)<< "Opening leveldb " << leveldb_names[i];
+    leveldb::DB* db;
+    leveldb::Status status = leveldb::DB::Open(options,
+                                               leveldb_names[i].c_str(),
+                                               &db);
+    CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
+    feature_dbs.push_back(db);
+  }
 
   int num_mini_batches = atoi(argv[++arg_pos]);
 
   LOG(ERROR)<< "Extacting Features";
 
   Datum datum;
-  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
+  vector<leveldb::WriteBatch*> feature_batches(
+      num_features, new leveldb::WriteBatch());
   const int kMaxKeyStrLength = 100;
   char key_str[kMaxKeyStrLength];
   int num_bytes_of_binary_code = sizeof(Dtype);
@@ -121,45 +144,51 @@ int feature_extraction_pipeline(int argc, char** argv) {
   int image_index = 0;
   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
     feature_extraction_net->Forward(input_vec);
-    const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
-        ->blob_by_name(extract_feature_blob_name);
-    int num_features = feature_blob->num();
-    int dim_features = feature_blob->count() / num_features;
-    Dtype* feature_blob_data;
-    for (int n = 0; n < num_features; ++n) {
-      datum.set_height(dim_features);
-      datum.set_width(1);
-      datum.set_channels(1);
-      datum.clear_data();
-      datum.clear_float_data();
-      feature_blob_data = feature_blob->mutable_cpu_data() +
-          feature_blob->offset(n);
-      for (int d = 0; d < dim_features; ++d) {
-        datum.add_float_data(feature_blob_data[d]);
-      }
-      string value;
-      datum.SerializeToString(&value);
-      snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
-      batch->Put(string(key_str), value);
-      ++image_index;
-      if (image_index % 1000 == 0) {
-        db->Write(leveldb::WriteOptions(), batch);
-        LOG(ERROR)<< "Extracted features of " << image_index <<
-            " query images.";
-        delete batch;
-        batch = new leveldb::WriteBatch();
-      }
-    }  // for (int n = 0; n < num_features; ++n)
+    for (int i = 0; i < num_features; ++i) {
+      const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
+          ->blob_by_name(blob_names[i]);
+      int batch_size = feature_blob->num();
+      int dim_features = feature_blob->count() / batch_size;
+      Dtype* feature_blob_data;
+      for (int n = 0; n < batch_size; ++n) {
+        datum.set_height(dim_features);
+        datum.set_width(1);
+        datum.set_channels(1);
+        datum.clear_data();
+        datum.clear_float_data();
+        feature_blob_data = feature_blob->mutable_cpu_data() +
+            feature_blob->offset(n);
+        for (int d = 0; d < dim_features; ++d) {
+          datum.add_float_data(feature_blob_data[d]);
+        }
+        string value;
+        datum.SerializeToString(&value);
+        snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
+        feature_batches[i]->Put(string(key_str), value);
+        ++image_index;
+        if (image_index % 1000 == 0) {
+          feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i]);
+          LOG(ERROR)<< "Extracted features of " << image_index <<
+              " query images.";
+          delete feature_batches[i];
+          feature_batches[i] = new leveldb::WriteBatch();
+        }
+      }  // for (int n = 0; n < batch_size; ++n)
+    }  // for (int i = 0; i < num_features; ++i)
   }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
   // write the last batch
   if (image_index % 1000 != 0) {
-    db->Write(leveldb::WriteOptions(), batch);
+    for (int i = 0; i < num_features; ++i) {
+      feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i]);
+    }
     LOG(ERROR)<< "Extracted features of " << image_index <<
         " query images.";
   }
 
-  delete batch;
-  delete db;
+  for (int i = 0; i < num_features; ++i) {
+    delete feature_batches[i];
+    delete feature_dbs[i];
+  }
   LOG(ERROR)<< "Successfully extracted the features!";
   return 0;
 }