Simplify image retrieval example to use binary features directly
authorKai Li <kaili_kloud@163.com>
Sun, 23 Feb 2014 14:26:35 +0000 (22:26 +0800)
committerKai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:41 +0000 (23:04 +0800)
examples/demo_retrieve_images.cpp

index 5cfbdea..e12ad36 100644 (file)
@@ -1,17 +1,9 @@
 // Copyright 2014 kloudkl@github
-//
-// This program takes in a trained network and an input blob, and then
-// extract features of the input blobs produced by the net to retrieve similar images.
-// Usage:
-//    retrieve_image pretrained_net_param input_blob output_filename top_k_results [CPU/GPU] [DEVICE_ID=0]
 
-#include <stdio.h> // for snprintf
 #include <fstream> // for std::ofstream
 #include <queue> // for std::priority_queue
 #include <cuda_runtime.h>
 #include <google/protobuf/text_format.h>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 using namespace caffe;
 
 template<typename Dtype>
-inline int sign(const Dtype val) {
-  return (Dtype(0) < val) - (val < Dtype(0));
-}
-
-template<typename Dtype>
-void binarize(const int n, const Dtype* real_valued_feature,
-              Dtype* binary_code);
-
-template<typename Dtype>
-void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
-              shared_ptr<Blob<Dtype> > binary_codes);
-
-template<typename Dtype>
-void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
-                       const shared_ptr<Blob<Dtype> > query_image_feature,
-                       const int top_k_results,
-                       shared_ptr<Blob<Dtype> > retrieval_results);
+void similarity_search(
+    const vector<shared_ptr<Blob<Dtype> > >& sample_binary_feature_blobs,
+    const shared_ptr<Blob<Dtype> > query_binary_feature,
+    const int top_k_results, shared_ptr<Blob<Dtype> > retrieval_results);
 
 template<typename Dtype>
 int image_retrieval_pipeline(int argc, char** argv);
@@ -52,13 +31,14 @@ int main(int argc, char** argv) {
 
 template<typename Dtype>
 int image_retrieval_pipeline(int argc, char** argv) {
-  const int num_required_args = 7;
+  const int num_required_args = 4;
   if (argc < num_required_args) {
     LOG(ERROR)<<
-    "retrieve_image pretrained_net_param extract__feature_blob_name"
-    " sample_images_feature_blob_binaryproto data_prototxt data_layer_name"
-    " save_feature_leveldb_name save_retrieval_result_filename"
-    " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]";
+    "This program takes in binarized features of query images and sample images"
+    "  extracted by Caffe to retrieve similar images."
+    "Usage: demo_retrieve_images  sample_binary_features_binaryproto_file"
+    "  query_binary_features_binaryproto_file  save_retrieval_result_filename"
+    "  [top_k_results=1]  [CPU/GPU]  [DEVICE_ID=0]";
     return 1;
   }
   int arg_pos = num_required_args;
@@ -75,8 +55,8 @@ int image_retrieval_pipeline(int argc, char** argv) {
   if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
     LOG(ERROR)<< "Using GPU";
     uint device_id = 0;
-    if (argc > arg_pos) {
-      device_id = atoi(argv[arg_pos]);
+    if (argc > arg_pos + 1) {
+      device_id = atoi(argv[arg_pos + 1]);
     }
     LOG(ERROR) << "Using Device_id=" << device_id;
     Caffe::SetDevice(device_id);
@@ -90,157 +70,63 @@ int image_retrieval_pipeline(int argc, char** argv) {
   NetParameter pretrained_net_param;
 
   arg_pos = 0;  // the name of the executable
-  // We directly load the net param from trained file
-  string pretrained_binary_proto(argv[++arg_pos]);
-  ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
-                          &pretrained_net_param);
-  shared_ptr<Net<Dtype> > feature_extraction_net(
-      new Net<Dtype>(pretrained_net_param));
 
-  string extract_feature_blob_name(argv[++arg_pos]);
-  if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
-    LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
-    " in trained network " << pretrained_binary_proto;
-    return 1;
+  string sample_binary_features_binaryproto_file(argv[++arg_pos]);
+  BlobProtoVector sample_binary_features;
+  ReadProtoFromBinaryFile(sample_binary_features_binaryproto_file,
+                          &sample_binary_features);
+  vector<shared_ptr<Blob<Dtype> > > sample_binary_feature_blobs;
+  int num_samples;
+  for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
+    shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
+    blob->FromProto(sample_binary_features.blobs(i));
+    sample_binary_feature_blobs.push_back(blob);
+    num_samples += blob->num();
   }
-
-  string sample_images_feature_blob_binaryproto(argv[++arg_pos]);
-  BlobProto sample_images_feature_blob_proto;
-  ReadProtoFromBinaryFile(argv[++arg_pos], &sample_images_feature_blob_proto);
-  shared_ptr<Blob<Dtype> > sample_images_feature_blob(new Blob<Dtype>());
-  sample_images_feature_blob->FromProto(sample_images_feature_blob_proto);
-
-  // Expected prototxt contains at least one data layer as the query images.
-  /*
-   layers {
-   layer {
-   name: "query_images"
-   type: "data"
-   source: "/path/to/your/images/to/extract/feature/and/retrieve/similar/images_leveldb"
-   meanfile: "/path/to/your/image_mean.binaryproto"
-   batchsize: 128
-   cropsize: 115
-   mirror: false
-   }
-   top: "query_images"
-   top: "ground_truth_labels" // TODO: Add MultiLabelDataLayer support for image retrieval, annotations etc.
-   }
-   */
-  string data_prototxt(argv[++arg_pos]);
-  string data_layer_name(argv[++arg_pos]);
-  NetParameter data_net_param;
-  ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
-  LayerParameter data_layer_param;
-  int num_layer;
-  for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
-    if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
-      data_layer_param = data_net_param.layers(num_layer).layer();
-      break;
-    }
-  }
-  if (num_layer = data_net_param.layers_size()) {
-    LOG(ERROR) << "Unknow data layer name " << data_layer_name <<
-        " in prototxt " << data_prototxt;
+  if (top_k_results > num_samples) {
+    top_k_results = num_samples;
   }
 
-  string save_feature_leveldb_name(argv[++arg_pos]);
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  LOG(INFO) << "Opening leveldb " << argv[3];
-  leveldb::Status status = leveldb::DB::Open(
-      options, save_feature_leveldb_name.c_str(), &db);
-  CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+  string query_images_feature_blob_binaryproto(argv[++arg_pos]);
+  BlobProtoVector query_images_features;
+  ReadProtoFromBinaryFile(query_images_feature_blob_binaryproto,
+                          &query_images_features);
+  vector<shared_ptr<Blob<Dtype> > > query_binary_feature_blobs;
+  for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
+    shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
+    blob->FromProto(query_images_features.blobs(i));
+    query_binary_feature_blobs.push_back(blob);
+  }
 
   string save_retrieval_result_filename(argv[++arg_pos]);
   std::ofstream retrieval_result_ofs(save_retrieval_result_filename.c_str(),
                                      std::ofstream::out);
 
-  LOG(ERROR)<< "Extacting Features and retrieving images";
-  DataLayer<Dtype> data_layer(data_layer_param);
-  vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
-  vector<Blob<Dtype>*> top_vec;
-  data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
-  int batch_index = 0;
-  shared_ptr<Blob<Dtype> > feature_binary_codes;
+  LOG(ERROR)<< "Retrieving images";
   shared_ptr<Blob<Dtype> > retrieval_results;
   int query_image_index = 0;
 
-  Datum datum;
-  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
-  const int max_key_str_length = 100;
-  char key_str[max_key_str_length];
   int num_bytes_of_binary_code = sizeof(Dtype);
-  int count_query_images = 0;
-  while (top_vec.size()) { // data_layer still outputs data
-    LOG(ERROR)<< "Batch " << batch_index << " feature extraction";
-    feature_extraction_net->Forward(top_vec);
-    const shared_ptr<Blob<Dtype> > feature_blob =
-    feature_extraction_net->GetBlob(extract_feature_blob_name);
-    feature_binary_codes.reset(new Blob<Dtype>());
-    binarize<Dtype>(feature_blob, feature_binary_codes);
-
-    LOG(ERROR) << "Batch " << batch_index << " save extracted features";
-    const Dtype* retrieval_results_data = retrieval_results->cpu_data();
-    int num_features = feature_binary_codes->num();
-    int dim_features = feature_binary_codes->count() / num_features;
-    for (int n = 0; n < num_features; ++n) {
-       datum.set_height(dim_features);
-       datum.set_width(1);
-       datum.set_channels(1);
-       datum.clear_data();
-       datum.clear_float_data();
-       string* datum_string = datum.mutable_data();
-       for (int d = 0; d < dim_features; ++d) {
-         const Dtype data = feature_binary_codes->data_at(n, d, 0, 0);
-         const char* data_byte = reinterpret_cast<const char*>(&data);
-         for(int i = 0; i < num_bytes_of_binary_code; ++i) {
-           datum_string->push_back(data_byte[i]);
-         }
-       }
-       string value;
-       datum.SerializeToString(&value);
-       snprintf(key_str, max_key_str_length, "%d", query_image_index);
-       batch->Put(string(key_str), value);
-       if (++count_query_images % 1000 == 0) {
-         db->Write(leveldb::WriteOptions(), batch);
-         LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
-         delete batch;
-         batch = new leveldb::WriteBatch();
-       }
-    }
-    // write the last batch
-    if (count_query_images % 1000 != 0) {
-      db->Write(leveldb::WriteOptions(), batch);
-      LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
-      delete batch;
-      batch = new leveldb::WriteBatch();
-    }
-
-    LOG(ERROR) << "Batch " << batch_index << " image retrieval";
-    similarity_search<Dtype>(sample_images_feature_blob, feature_binary_codes,
+  int num_query_batches = query_binary_feature_blobs.size();
+  for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
+    LOG(ERROR)<< "Batch " << batch_index << " image retrieval";
+    similarity_search<Dtype>(sample_binary_feature_blobs,
+        query_binary_feature_blobs[batch_index],
         top_k_results, retrieval_results);
 
     LOG(ERROR) << "Batch " << batch_index << " save image retrieval results";
     int num_results = retrieval_results->num();
-    int dim_results = retrieval_results->count() / num_results;
+    const Dtype* retrieval_results_data = retrieval_results->cpu_data();
     for (int i = 0; i < num_results; ++i) {
-      retrieval_result_ofs << query_image_index;
-      for (int k = 0; k < dim_results; ++k) {
-        retrieval_result_ofs << " " << retrieval_results->data_at(i, k, 0, 0);
+      retrieval_result_ofs << ++query_image_index;
+      retrieval_results_data += retrieval_results->offset(i);
+      for (int j = 0; j < top_k_results; ++j) {
+        retrieval_result_ofs << " " << retrieval_results_data[j];
       }
       retrieval_result_ofs << "\n";
     }
-    ++query_image_index;
-
-    data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
-    ++batch_index;
-  } //  while (top_vec.size()) {
+  }  //  for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
 
-  delete batch;
-  delete db;
   retrieval_result_ofs.close();
   LOG(ERROR)<< "Successfully ended!";
   return 0;
@@ -296,36 +182,42 @@ class MinHeapComparison {
 };
 
 template<typename Dtype>
-void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
-                       const shared_ptr<Blob<Dtype> > query_image_feature,
-                       const int top_k_results,
-                       shared_ptr<Blob<Dtype> > retrieval_results) {
-  int num_samples = sample_images_feature->num();
+void similarity_search(
+    const vector<shared_ptr<Blob<Dtype> > >& sample_images_feature_blobs,
+    const shared_ptr<Blob<Dtype> > query_image_feature, const int top_k_results,
+    shared_ptr<Blob<Dtype> > retrieval_results) {
   int num_queries = query_image_feature->num();
   int dim = query_image_feature->count() / num_queries;
-  retrieval_results->Reshape(num_queries, std::min(num_samples, top_k_results), 1, 1);
-  Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
   int hamming_dist;
+  retrieval_results->Reshape(num_queries, top_k_results, 1, 1);
+  Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
   for (int i = 0; i < num_queries; ++i) {
-    std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
-        MinHeapComparison> results;
-    for (int j = 0; j < num_samples; ++j) {
-      hamming_dist = caffe_hamming_distance(
-          dim, query_image_feature->cpu_data() + query_image_feature->offset(i),
-          sample_images_feature->cpu_data() + sample_images_feature->offset(j));
-      if (results.empty()) {
-        results.push(std::make_pair(-hamming_dist, j));
-      } else if (-hamming_dist > results.top().first) { // smaller hamming dist
-        results.push(std::make_pair(-hamming_dist, j));
-        if (results.size() > top_k_results) {
+    std::priority_queue<std::pair<int, int>,
+        std::vector<std::pair<int, int> >, MinHeapComparison> results;
+    for (int num_sample_blob;
+        num_sample_blob < sample_images_feature_blobs.size();
+        ++num_sample_blob) {
+      shared_ptr<Blob<Dtype> > sample_images_feature =
+          sample_images_feature_blobs[num_sample_blob];
+      int num_samples = sample_images_feature->num();
+      for (int j = 0; j < num_samples; ++j) {
+        hamming_dist = caffe_hamming_distance(
+            dim,
+            query_image_feature->cpu_data() + query_image_feature->offset(i),
+            sample_images_feature->cpu_data()
+                + sample_images_feature->offset(j));
+        if (results.size() < top_k_results) {
+          results.push(std::make_pair(-hamming_dist, j));
+        } else if (-hamming_dist > results.top().first) {  // smaller hamming dist
           results.pop();
+          results.push(std::make_pair(-hamming_dist, j));
         }
+      }  // for (int j = 0; j < num_samples; ++j) {
+      retrieval_results_data += retrieval_results->offset(i);
+      for (int k = 0; k < results.size(); ++k) {
+        retrieval_results_data[k] = results.top().second;
+        results.pop();
       }
-    }  // for (int j = 0; j < num_samples; ++j) {
-    retrieval_results_data += retrieval_results->offset(i);
-    for (int k = 0; k < results.size(); ++k) {
-      retrieval_results_data[k] = results.top().second;
-      results.pop();
-    }
-  }  // for (int i = 0; i < num_queries; ++i) {
+    } // for(...; sample_images_feature_blobs.size(); ...)
+  } // for (int i = 0; i < num_queries; ++i) {
 }