// Copyright 2014 kloudkl@github
-//
-// This program takes in a trained network and an input blob, and then
-// extract features of the input blobs produced by the net to retrieve similar images.
-// Usage:
-// retrieve_image pretrained_net_param input_blob output_filename top_k_results [CPU/GPU] [DEVICE_ID=0]
-#include <stdio.h> // for snprintf
#include <fstream> // for std::ofstream
#include <queue> // for std::priority_queue
#include <cuda_runtime.h>
#include <google/protobuf/text_format.h>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
using namespace caffe;
template<typename Dtype>
-inline int sign(const Dtype val) {
- return (Dtype(0) < val) - (val < Dtype(0));
-}
-
-template<typename Dtype>
-void binarize(const int n, const Dtype* real_valued_feature,
- Dtype* binary_code);
-
-template<typename Dtype>
-void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
- shared_ptr<Blob<Dtype> > binary_codes);
-
-template<typename Dtype>
-void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
- const shared_ptr<Blob<Dtype> > query_image_feature,
- const int top_k_results,
- shared_ptr<Blob<Dtype> > retrieval_results);
+void similarity_search(
+ const vector<shared_ptr<Blob<Dtype> > >& sample_binary_feature_blobs,
+ const shared_ptr<Blob<Dtype> > query_binary_feature,
+ const int top_k_results, shared_ptr<Blob<Dtype> > retrieval_results);
template<typename Dtype>
int image_retrieval_pipeline(int argc, char** argv);
template<typename Dtype>
int image_retrieval_pipeline(int argc, char** argv) {
- const int num_required_args = 7;
+ const int num_required_args = 4;
if (argc < num_required_args) {
LOG(ERROR)<<
- "retrieve_image pretrained_net_param extract__feature_blob_name"
- " sample_images_feature_blob_binaryproto data_prototxt data_layer_name"
- " save_feature_leveldb_name save_retrieval_result_filename"
- " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]";
+ "This program takes in binarized features of query images and sample images"
+ " extracted by Caffe to retrieve similar images."
+ "Usage: demo_retrieve_images sample_binary_features_binaryproto_file"
+ " query_binary_features_binaryproto_file save_retrieval_result_filename"
+ " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]";
return 1;
}
int arg_pos = num_required_args;
if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
LOG(ERROR)<< "Using GPU";
uint device_id = 0;
- if (argc > arg_pos) {
- device_id = atoi(argv[arg_pos]);
+ if (argc > arg_pos + 1) {
+ device_id = atoi(argv[arg_pos + 1]);
}
LOG(ERROR) << "Using Device_id=" << device_id;
Caffe::SetDevice(device_id);
NetParameter pretrained_net_param;
arg_pos = 0; // the name of the executable
- // We directly load the net param from trained file
- string pretrained_binary_proto(argv[++arg_pos]);
- ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
- &pretrained_net_param);
- shared_ptr<Net<Dtype> > feature_extraction_net(
- new Net<Dtype>(pretrained_net_param));
- string extract_feature_blob_name(argv[++arg_pos]);
- if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
- LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
- " in trained network " << pretrained_binary_proto;
- return 1;
+ string sample_binary_features_binaryproto_file(argv[++arg_pos]);
+ BlobProtoVector sample_binary_features;
+ ReadProtoFromBinaryFile(sample_binary_features_binaryproto_file,
+ &sample_binary_features);
+ vector<shared_ptr<Blob<Dtype> > > sample_binary_feature_blobs;
+ int num_samples;
+ for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
+ shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
+ blob->FromProto(sample_binary_features.blobs(i));
+ sample_binary_feature_blobs.push_back(blob);
+ num_samples += blob->num();
}
-
- string sample_images_feature_blob_binaryproto(argv[++arg_pos]);
- BlobProto sample_images_feature_blob_proto;
- ReadProtoFromBinaryFile(argv[++arg_pos], &sample_images_feature_blob_proto);
- shared_ptr<Blob<Dtype> > sample_images_feature_blob(new Blob<Dtype>());
- sample_images_feature_blob->FromProto(sample_images_feature_blob_proto);
-
- // Expected prototxt contains at least one data layer as the query images.
- /*
- layers {
- layer {
- name: "query_images"
- type: "data"
- source: "/path/to/your/images/to/extract/feature/and/retrieve/similar/images_leveldb"
- meanfile: "/path/to/your/image_mean.binaryproto"
- batchsize: 128
- cropsize: 115
- mirror: false
- }
- top: "query_images"
- top: "ground_truth_labels" // TODO: Add MultiLabelDataLayer support for image retrieval, annotations etc.
- }
- */
- string data_prototxt(argv[++arg_pos]);
- string data_layer_name(argv[++arg_pos]);
- NetParameter data_net_param;
- ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
- LayerParameter data_layer_param;
- int num_layer;
- for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
- if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
- data_layer_param = data_net_param.layers(num_layer).layer();
- break;
- }
- }
- if (num_layer = data_net_param.layers_size()) {
- LOG(ERROR) << "Unknow data layer name " << data_layer_name <<
- " in prototxt " << data_prototxt;
+ if (top_k_results > num_samples) {
+ top_k_results = num_samples;
}
- string save_feature_leveldb_name(argv[++arg_pos]);
- leveldb::DB* db;
- leveldb::Options options;
- options.error_if_exists = true;
- options.create_if_missing = true;
- options.write_buffer_size = 268435456;
- LOG(INFO) << "Opening leveldb " << argv[3];
- leveldb::Status status = leveldb::DB::Open(
- options, save_feature_leveldb_name.c_str(), &db);
- CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+ string query_images_feature_blob_binaryproto(argv[++arg_pos]);
+ BlobProtoVector query_images_features;
+ ReadProtoFromBinaryFile(query_images_feature_blob_binaryproto,
+ &query_images_features);
+ vector<shared_ptr<Blob<Dtype> > > query_binary_feature_blobs;
+ for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
+ shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
+ blob->FromProto(query_images_features.blobs(i));
+ query_binary_feature_blobs.push_back(blob);
+ }
string save_retrieval_result_filename(argv[++arg_pos]);
std::ofstream retrieval_result_ofs(save_retrieval_result_filename.c_str(),
std::ofstream::out);
- LOG(ERROR)<< "Extacting Features and retrieving images";
- DataLayer<Dtype> data_layer(data_layer_param);
- vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
- vector<Blob<Dtype>*> top_vec;
- data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
- int batch_index = 0;
- shared_ptr<Blob<Dtype> > feature_binary_codes;
+ LOG(ERROR)<< "Retrieving images";
shared_ptr<Blob<Dtype> > retrieval_results;
int query_image_index = 0;
- Datum datum;
- leveldb::WriteBatch* batch = new leveldb::WriteBatch();
- const int max_key_str_length = 100;
- char key_str[max_key_str_length];
int num_bytes_of_binary_code = sizeof(Dtype);
- int count_query_images = 0;
- while (top_vec.size()) { // data_layer still outputs data
- LOG(ERROR)<< "Batch " << batch_index << " feature extraction";
- feature_extraction_net->Forward(top_vec);
- const shared_ptr<Blob<Dtype> > feature_blob =
- feature_extraction_net->GetBlob(extract_feature_blob_name);
- feature_binary_codes.reset(new Blob<Dtype>());
- binarize<Dtype>(feature_blob, feature_binary_codes);
-
- LOG(ERROR) << "Batch " << batch_index << " save extracted features";
- const Dtype* retrieval_results_data = retrieval_results->cpu_data();
- int num_features = feature_binary_codes->num();
- int dim_features = feature_binary_codes->count() / num_features;
- for (int n = 0; n < num_features; ++n) {
- datum.set_height(dim_features);
- datum.set_width(1);
- datum.set_channels(1);
- datum.clear_data();
- datum.clear_float_data();
- string* datum_string = datum.mutable_data();
- for (int d = 0; d < dim_features; ++d) {
- const Dtype data = feature_binary_codes->data_at(n, d, 0, 0);
- const char* data_byte = reinterpret_cast<const char*>(&data);
- for(int i = 0; i < num_bytes_of_binary_code; ++i) {
- datum_string->push_back(data_byte[i]);
- }
- }
- string value;
- datum.SerializeToString(&value);
- snprintf(key_str, max_key_str_length, "%d", query_image_index);
- batch->Put(string(key_str), value);
- if (++count_query_images % 1000 == 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
- delete batch;
- batch = new leveldb::WriteBatch();
- }
- }
- // write the last batch
- if (count_query_images % 1000 != 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
- delete batch;
- batch = new leveldb::WriteBatch();
- }
-
- LOG(ERROR) << "Batch " << batch_index << " image retrieval";
- similarity_search<Dtype>(sample_images_feature_blob, feature_binary_codes,
+ int num_query_batches = query_binary_feature_blobs.size();
+ for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
+ LOG(ERROR)<< "Batch " << batch_index << " image retrieval";
+ similarity_search<Dtype>(sample_binary_feature_blobs,
+ query_binary_feature_blobs[batch_index],
top_k_results, retrieval_results);
LOG(ERROR) << "Batch " << batch_index << " save image retrieval results";
int num_results = retrieval_results->num();
- int dim_results = retrieval_results->count() / num_results;
+ const Dtype* retrieval_results_data = retrieval_results->cpu_data();
for (int i = 0; i < num_results; ++i) {
- retrieval_result_ofs << query_image_index;
- for (int k = 0; k < dim_results; ++k) {
- retrieval_result_ofs << " " << retrieval_results->data_at(i, k, 0, 0);
+ retrieval_result_ofs << ++query_image_index;
+ retrieval_results_data += retrieval_results->offset(i);
+ for (int j = 0; j < top_k_results; ++j) {
+ retrieval_result_ofs << " " << retrieval_results_data[j];
}
retrieval_result_ofs << "\n";
}
- ++query_image_index;
-
- data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
- ++batch_index;
- } // while (top_vec.size()) {
+ } // for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
- delete batch;
- delete db;
retrieval_result_ofs.close();
LOG(ERROR)<< "Successfully ended!";
return 0;
};
template<typename Dtype>
-void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
- const shared_ptr<Blob<Dtype> > query_image_feature,
- const int top_k_results,
- shared_ptr<Blob<Dtype> > retrieval_results) {
- int num_samples = sample_images_feature->num();
+void similarity_search(
+ const vector<shared_ptr<Blob<Dtype> > >& sample_images_feature_blobs,
+ const shared_ptr<Blob<Dtype> > query_image_feature, const int top_k_results,
+ shared_ptr<Blob<Dtype> > retrieval_results) {
int num_queries = query_image_feature->num();
int dim = query_image_feature->count() / num_queries;
- retrieval_results->Reshape(num_queries, std::min(num_samples, top_k_results), 1, 1);
- Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
int hamming_dist;
+ retrieval_results->Reshape(num_queries, top_k_results, 1, 1);
+ Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
for (int i = 0; i < num_queries; ++i) {
- std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
- MinHeapComparison> results;
- for (int j = 0; j < num_samples; ++j) {
- hamming_dist = caffe_hamming_distance(
- dim, query_image_feature->cpu_data() + query_image_feature->offset(i),
- sample_images_feature->cpu_data() + sample_images_feature->offset(j));
- if (results.empty()) {
- results.push(std::make_pair(-hamming_dist, j));
- } else if (-hamming_dist > results.top().first) { // smaller hamming dist
- results.push(std::make_pair(-hamming_dist, j));
- if (results.size() > top_k_results) {
+ std::priority_queue<std::pair<int, int>,
+ std::vector<std::pair<int, int> >, MinHeapComparison> results;
+ for (int num_sample_blob;
+ num_sample_blob < sample_images_feature_blobs.size();
+ ++num_sample_blob) {
+ shared_ptr<Blob<Dtype> > sample_images_feature =
+ sample_images_feature_blobs[num_sample_blob];
+ int num_samples = sample_images_feature->num();
+ for (int j = 0; j < num_samples; ++j) {
+ hamming_dist = caffe_hamming_distance(
+ dim,
+ query_image_feature->cpu_data() + query_image_feature->offset(i),
+ sample_images_feature->cpu_data()
+ + sample_images_feature->offset(j));
+ if (results.size() < top_k_results) {
+ results.push(std::make_pair(-hamming_dist, j));
+ } else if (-hamming_dist > results.top().first) { // smaller hamming dist
results.pop();
+ results.push(std::make_pair(-hamming_dist, j));
}
+ } // for (int j = 0; j < num_samples; ++j) {
+ retrieval_results_data += retrieval_results->offset(i);
+ for (int k = 0; k < results.size(); ++k) {
+ retrieval_results_data[k] = results.top().second;
+ results.pop();
}
- } // for (int j = 0; j < num_samples; ++j) {
- retrieval_results_data += retrieval_results->offset(i);
- for (int k = 0; k < results.size(); ++k) {
- retrieval_results_data[k] = results.top().second;
- results.pop();
- }
- } // for (int i = 0; i < num_queries; ++i) {
+ } // for(...; sample_images_feature_blobs.size(); ...)
+ } // for (int i = 0; i < num_queries; ++i) {
}