Merge pull request #511 from kloudkl/extract_multiple_features
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 27 Jun 2014 02:13:28 +0000 (19:13 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 27 Jun 2014 02:13:28 +0000 (19:13 -0700)
Extract multiple features in a single Forward pass

1  2 
tools/extract_features.cpp

@@@ -113,52 -136,57 +136,56 @@@ int feature_extraction_pipeline(int arg
    LOG(ERROR)<< "Extacting Features";
  
    Datum datum;
-   leveldb::WriteBatch* batch = new leveldb::WriteBatch();
+   vector<shared_ptr<leveldb::WriteBatch> > feature_batches(
+       num_features,
+       shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch()));
    const int kMaxKeyStrLength = 100;
    char key_str[kMaxKeyStrLength];
 -  int num_bytes_of_binary_code = sizeof(Dtype);
    vector<Blob<float>*> input_vec;
-   int image_index = 0;
+   vector<int> image_indices(num_features, 0);
    for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
      feature_extraction_net->Forward(input_vec);
-     const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
-         ->blob_by_name(extract_feature_blob_name);
-     int num_features = feature_blob->num();
-     int dim_features = feature_blob->count() / num_features;
-     Dtype* feature_blob_data;
-     for (int n = 0; n < num_features; ++n) {
-       datum.set_height(dim_features);
-       datum.set_width(1);
-       datum.set_channels(1);
-       datum.clear_data();
-       datum.clear_float_data();
-       feature_blob_data = feature_blob->mutable_cpu_data() +
-           feature_blob->offset(n);
-       for (int d = 0; d < dim_features; ++d) {
-         datum.add_float_data(feature_blob_data[d]);
-       }
-       string value;
-       datum.SerializeToString(&value);
-       snprintf(key_str, kMaxKeyStrLength, "%d", image_index);
-       batch->Put(string(key_str), value);
-       ++image_index;
-       if (image_index % 1000 == 0) {
-         db->Write(leveldb::WriteOptions(), batch);
-         LOG(ERROR)<< "Extracted features of " << image_index <<
-             " query images.";
-         delete batch;
-         batch = new leveldb::WriteBatch();
-       }
-     }  // for (int n = 0; n < num_features; ++n)
+     for (int i = 0; i < num_features; ++i) {
+       const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
+           ->blob_by_name(blob_names[i]);
+       int batch_size = feature_blob->num();
+       int dim_features = feature_blob->count() / batch_size;
+       Dtype* feature_blob_data;
+       for (int n = 0; n < batch_size; ++n) {
+         datum.set_height(dim_features);
+         datum.set_width(1);
+         datum.set_channels(1);
+         datum.clear_data();
+         datum.clear_float_data();
+         feature_blob_data = feature_blob->mutable_cpu_data() +
+             feature_blob->offset(n);
+         for (int d = 0; d < dim_features; ++d) {
+           datum.add_float_data(feature_blob_data[d]);
+         }
+         string value;
+         datum.SerializeToString(&value);
+         snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
+         feature_batches[i]->Put(string(key_str), value);
+         ++image_indices[i];
+         if (image_indices[i] % 1000 == 0) {
+           feature_dbs[i]->Write(leveldb::WriteOptions(),
+                                 feature_batches[i].get());
+           LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
+               " query images for feature blob " << blob_names[i];
+           feature_batches[i].reset(new leveldb::WriteBatch());
+         }
+       }  // for (int n = 0; n < batch_size; ++n)
+     }  // for (int i = 0; i < num_features; ++i)
    }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
    // write the last batch
-   if (image_index % 1000 != 0) {
-     db->Write(leveldb::WriteOptions(), batch);
-     LOG(ERROR)<< "Extracted features of " << image_index <<
-         " query images.";
+   for (int i = 0; i < num_features; ++i) {
+     if (image_indices[i] % 1000 != 0) {
+       feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
+     }
+     LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
+         " query images for feature blob " << blob_names[i];
    }
  
-   delete batch;
-   delete db;
    LOG(ERROR)<< "Successfully extracted the features!";
    return 0;
  }