Add image retrieval example
authorKai Li <kaili_kloud@163.com>
Sun, 23 Feb 2014 11:56:57 +0000 (19:56 +0800)
committerKai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:41 +0000 (23:04 +0800)
examples/demo_retrieve_images.cpp [new file with mode: 0644]

diff --git a/examples/demo_retrieve_images.cpp b/examples/demo_retrieve_images.cpp
new file mode 100644 (file)
index 0000000..5cfbdea
--- /dev/null
@@ -0,0 +1,331 @@
+// Copyright 2014 kloudkl@github
+//
+// This program takes in a trained network and an input blob, and then
+// extract features of the input blobs produced by the net to retrieve similar images.
+// Usage:
+//    retrieve_image pretrained_net_param input_blob output_filename top_k_results [CPU/GPU] [DEVICE_ID=0]
+
+#include <stdio.h> // for snprintf
+#include <fstream> // for std::ofstream
+#include <queue> // for std::priority_queue
+#include <cuda_runtime.h>
+#include <google/protobuf/text_format.h>
+#include <leveldb/db.h>
+#include <leveldb/write_batch.h>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/net.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using namespace caffe;
+
+template<typename Dtype>
+inline int sign(const Dtype val) {
+  return (Dtype(0) < val) - (val < Dtype(0));
+}
+
+template<typename Dtype>
+void binarize(const int n, const Dtype* real_valued_feature,
+              Dtype* binary_code);
+
+template<typename Dtype>
+void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
+              shared_ptr<Blob<Dtype> > binary_codes);
+
+template<typename Dtype>
+void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
+                       const shared_ptr<Blob<Dtype> > query_image_feature,
+                       const int top_k_results,
+                       shared_ptr<Blob<Dtype> > retrieval_results);
+
+template<typename Dtype>
+int image_retrieval_pipeline(int argc, char** argv);
+
+int main(int argc, char** argv) {
+  return image_retrieval_pipeline<float>(argc, argv);
+//  return image_retrieval_pipeline<double>(argc, argv);
+}
+
+template<typename Dtype>
+int image_retrieval_pipeline(int argc, char** argv) {
+  const int num_required_args = 7;
+  if (argc < num_required_args) {
+    LOG(ERROR)<<
+    "retrieve_image pretrained_net_param extract__feature_blob_name"
+    " sample_images_feature_blob_binaryproto data_prototxt data_layer_name"
+    " save_feature_leveldb_name save_retrieval_result_filename"
+    " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]";
+    return 1;
+  }
+  int arg_pos = num_required_args;
+
+  int top_k_results;
+  if (argc <= num_required_args) {
+    top_k_results = 1;
+  } else {
+    top_k_results = atoi(argv[arg_pos]);
+    CHECK_GE(top_k_results, 0);
+  }
+
+  arg_pos = num_required_args + 1;
+  if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
+    LOG(ERROR)<< "Using GPU";
+    uint device_id = 0;
+    if (argc > arg_pos) {
+      device_id = atoi(argv[arg_pos]);
+    }
+    LOG(ERROR) << "Using Device_id=" << device_id;
+    Caffe::SetDevice(device_id);
+    Caffe::set_mode(Caffe::GPU);
+  } else {
+    LOG(ERROR) << "Using CPU";
+    Caffe::set_mode(Caffe::CPU);
+  }
+  Caffe::set_phase(Caffe::TEST);
+
+  NetParameter pretrained_net_param;
+
+  arg_pos = 0;  // the name of the executable
+  // We directly load the net param from trained file
+  string pretrained_binary_proto(argv[++arg_pos]);
+  ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
+                          &pretrained_net_param);
+  shared_ptr<Net<Dtype> > feature_extraction_net(
+      new Net<Dtype>(pretrained_net_param));
+
+  string extract_feature_blob_name(argv[++arg_pos]);
+  if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
+    LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
+    " in trained network " << pretrained_binary_proto;
+    return 1;
+  }
+
+  string sample_images_feature_blob_binaryproto(argv[++arg_pos]);
+  BlobProto sample_images_feature_blob_proto;
+  ReadProtoFromBinaryFile(argv[++arg_pos], &sample_images_feature_blob_proto);
+  shared_ptr<Blob<Dtype> > sample_images_feature_blob(new Blob<Dtype>());
+  sample_images_feature_blob->FromProto(sample_images_feature_blob_proto);
+
+  // Expected prototxt contains at least one data layer as the query images.
+  /*
+   layers {
+   layer {
+   name: "query_images"
+   type: "data"
+   source: "/path/to/your/images/to/extract/feature/and/retrieve/similar/images_leveldb"
+   meanfile: "/path/to/your/image_mean.binaryproto"
+   batchsize: 128
+   cropsize: 115
+   mirror: false
+   }
+   top: "query_images"
+   top: "ground_truth_labels" // TODO: Add MultiLabelDataLayer support for image retrieval, annotations etc.
+   }
+   */
+  string data_prototxt(argv[++arg_pos]);
+  string data_layer_name(argv[++arg_pos]);
+  NetParameter data_net_param;
+  ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
+  LayerParameter data_layer_param;
+  int num_layer;
+  for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
+    if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
+      data_layer_param = data_net_param.layers(num_layer).layer();
+      break;
+    }
+  }
+  if (num_layer = data_net_param.layers_size()) {
+    LOG(ERROR) << "Unknow data layer name " << data_layer_name <<
+        " in prototxt " << data_prototxt;
+  }
+
+  string save_feature_leveldb_name(argv[++arg_pos]);
+  leveldb::DB* db;
+  leveldb::Options options;
+  options.error_if_exists = true;
+  options.create_if_missing = true;
+  options.write_buffer_size = 268435456;
+  LOG(INFO) << "Opening leveldb " << argv[3];
+  leveldb::Status status = leveldb::DB::Open(
+      options, save_feature_leveldb_name.c_str(), &db);
+  CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+
+  string save_retrieval_result_filename(argv[++arg_pos]);
+  std::ofstream retrieval_result_ofs(save_retrieval_result_filename.c_str(),
+                                     std::ofstream::out);
+
+  LOG(ERROR)<< "Extacting Features and retrieving images";
+  DataLayer<Dtype> data_layer(data_layer_param);
+  vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
+  vector<Blob<Dtype>*> top_vec;
+  data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
+  int batch_index = 0;
+  shared_ptr<Blob<Dtype> > feature_binary_codes;
+  shared_ptr<Blob<Dtype> > retrieval_results;
+  int query_image_index = 0;
+
+  Datum datum;
+  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
+  const int max_key_str_length = 100;
+  char key_str[max_key_str_length];
+  int num_bytes_of_binary_code = sizeof(Dtype);
+  int count_query_images = 0;
+  while (top_vec.size()) { // data_layer still outputs data
+    LOG(ERROR)<< "Batch " << batch_index << " feature extraction";
+    feature_extraction_net->Forward(top_vec);
+    const shared_ptr<Blob<Dtype> > feature_blob =
+    feature_extraction_net->GetBlob(extract_feature_blob_name);
+    feature_binary_codes.reset(new Blob<Dtype>());
+    binarize<Dtype>(feature_blob, feature_binary_codes);
+
+    LOG(ERROR) << "Batch " << batch_index << " save extracted features";
+    const Dtype* retrieval_results_data = retrieval_results->cpu_data();
+    int num_features = feature_binary_codes->num();
+    int dim_features = feature_binary_codes->count() / num_features;
+    for (int n = 0; n < num_features; ++n) {
+       datum.set_height(dim_features);
+       datum.set_width(1);
+       datum.set_channels(1);
+       datum.clear_data();
+       datum.clear_float_data();
+       string* datum_string = datum.mutable_data();
+       for (int d = 0; d < dim_features; ++d) {
+         const Dtype data = feature_binary_codes->data_at(n, d, 0, 0);
+         const char* data_byte = reinterpret_cast<const char*>(&data);
+         for(int i = 0; i < num_bytes_of_binary_code; ++i) {
+           datum_string->push_back(data_byte[i]);
+         }
+       }
+       string value;
+       datum.SerializeToString(&value);
+       snprintf(key_str, max_key_str_length, "%d", query_image_index);
+       batch->Put(string(key_str), value);
+       if (++count_query_images % 1000 == 0) {
+         db->Write(leveldb::WriteOptions(), batch);
+         LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
+         delete batch;
+         batch = new leveldb::WriteBatch();
+       }
+    }
+    // write the last batch
+    if (count_query_images % 1000 != 0) {
+      db->Write(leveldb::WriteOptions(), batch);
+      LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
+      delete batch;
+      batch = new leveldb::WriteBatch();
+    }
+
+    LOG(ERROR) << "Batch " << batch_index << " image retrieval";
+    similarity_search<Dtype>(sample_images_feature_blob, feature_binary_codes,
+        top_k_results, retrieval_results);
+
+    LOG(ERROR) << "Batch " << batch_index << " save image retrieval results";
+    int num_results = retrieval_results->num();
+    int dim_results = retrieval_results->count() / num_results;
+    for (int i = 0; i < num_results; ++i) {
+      retrieval_result_ofs << query_image_index;
+      for (int k = 0; k < dim_results; ++k) {
+        retrieval_result_ofs << " " << retrieval_results->data_at(i, k, 0, 0);
+      }
+      retrieval_result_ofs << "\n";
+    }
+    ++query_image_index;
+
+    data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
+    ++batch_index;
+  } //  while (top_vec.size()) {
+
+  delete batch;
+  delete db;
+  retrieval_result_ofs.close();
+  LOG(ERROR)<< "Successfully ended!";
+  return 0;
+}
+
+template<typename Dtype>
+void binarize(const int n, const Dtype* real_valued_feature,
+              Dtype* binary_codes) {
+  // TODO: more advanced binarization algorithm such as bilinear projection
+  // Yunchao Gong, Sanjiv Kumar, Henry A. Rowley, and Svetlana Lazebnik.
+  // Learning Binary Codes for High-Dimensional Data Using Bilinear Projections.
+  // In IEEE International Conference on Computer Vision and Pattern Recognition (CVPR), 2013.
+  // http://www.unc.edu/~yunchao/bpbc.htm
+  int size_of_code = sizeof(Dtype) * 8;
+  CHECK_EQ(n % size_of_code, 0);
+  int num_binary_codes = n / size_of_code;
+  uint64_t code;
+  int offset;
+  for (int i = 0; i < num_binary_codes; ++i) {
+    code = 0;
+    offset = i * size_of_code;
+    for (int j = 0; j < size_of_code; ++j) {
+      code |= sign(real_valued_feature[offset + j]);
+      code << 1;
+    }
+    binary_codes[i] = static_cast<Dtype>(code);
+  }
+}
+
+template<typename Dtype>
+void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
+              shared_ptr<Blob<Dtype> > binary_codes) {
+  int num = real_valued_features->num();
+  int dim = real_valued_features->count() / num;
+  int size_of_code = sizeof(Dtype) * 8;
+  CHECK_EQ(dim % size_of_code, 0);
+  binary_codes->Reshape(num, dim / size_of_code, 1, 1);
+  const Dtype* real_valued_features_data = real_valued_features->cpu_data();
+  Dtype* binary_codes_data = binary_codes->mutable_cpu_data();
+  for (int n = 0; n < num; ++n) {
+    binarize<Dtype>(dim,
+                    real_valued_features_data + real_valued_features->offset(n),
+                    binary_codes_data + binary_codes->offset(n));
+  }
+}
+
+class MinHeapComparison {
+ public:
+  bool operator()(const std::pair<int, int>& lhs,
+                  const std::pair<int, int>&rhs) const {
+    return (lhs.first > rhs.first);
+  }
+};
+
+template<typename Dtype>
+void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
+                       const shared_ptr<Blob<Dtype> > query_image_feature,
+                       const int top_k_results,
+                       shared_ptr<Blob<Dtype> > retrieval_results) {
+  int num_samples = sample_images_feature->num();
+  int num_queries = query_image_feature->num();
+  int dim = query_image_feature->count() / num_queries;
+  retrieval_results->Reshape(num_queries, std::min(num_samples, top_k_results), 1, 1);
+  Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
+  int hamming_dist;
+  for (int i = 0; i < num_queries; ++i) {
+    std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
+        MinHeapComparison> results;
+    for (int j = 0; j < num_samples; ++j) {
+      hamming_dist = caffe_hamming_distance(
+          dim, query_image_feature->cpu_data() + query_image_feature->offset(i),
+          sample_images_feature->cpu_data() + sample_images_feature->offset(j));
+      if (results.empty()) {
+        results.push(std::make_pair(-hamming_dist, j));
+      } else if (-hamming_dist > results.top().first) { // smaller hamming dist
+        results.push(std::make_pair(-hamming_dist, j));
+        if (results.size() > top_k_results) {
+          results.pop();
+        }
+      }
+    }  // for (int j = 0; j < num_samples; ++j) {
+    retrieval_results_data += retrieval_results->offset(i);
+    for (int k = 0; k < results.size(); ++k) {
+      retrieval_results_data[k] = results.top().second;
+      results.pop();
+    }
+  }  // for (int i = 0; i < num_queries; ++i) {
+}