From fc740a3e7ec90829c142f5a7a9f409b5c849cd00 Mon Sep 17 00:00:00 2001
From: Kai Li
Date: Sun, 23 Feb 2014 22:26:35 +0800
Subject: [PATCH] Simplify image retrieval example to use binary features
 directly

---
 examples/demo_retrieve_images.cpp | 264 +++++++++++---------------------------
 1 file changed, 78 insertions(+), 186 deletions(-)

diff --git a/examples/demo_retrieve_images.cpp b/examples/demo_retrieve_images.cpp
index 5cfbdea..e12ad36 100644
--- a/examples/demo_retrieve_images.cpp
+++ b/examples/demo_retrieve_images.cpp
@@ -1,17 +1,9 @@
 // Copyright 2014 kloudkl@github
-//
-// This program takes in a trained network and an input blob, and then
-// extract features of the input blobs produced by the net to retrieve similar images.
-// Usage:
-//    retrieve_image pretrained_net_param input_blob output_filename top_k_results [CPU/GPU] [DEVICE_ID=0]
-#include <stdio.h>  // for snprintf
 #include <fstream>  // for std::ofstream
 #include <queue>  // for std::priority_queue
 #include <string>
 #include <vector>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
@@ -24,23 +16,10 @@ using namespace caffe;
 template <typename Dtype>
-inline int sign(const Dtype val) {
-  return (Dtype(0) < val) - (val < Dtype(0));
-}
-
-template <typename Dtype>
-void binarize(const int n, const Dtype* real_valued_feature,
-              Dtype* binary_code);
-
-template <typename Dtype>
-void binarize(const shared_ptr<Blob<Dtype> > real_valued_features,
-              shared_ptr<Blob<Dtype> > binary_codes);
-
-template <typename Dtype>
-void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
-                       const shared_ptr<Blob<Dtype> > query_image_feature,
-                       const int top_k_results,
-                       shared_ptr<Blob<Dtype> > retrieval_results);
+void similarity_search(
+    const vector<shared_ptr<Blob<Dtype> > >& sample_binary_feature_blobs,
+    const shared_ptr<Blob<Dtype> > query_binary_feature,
+    const int top_k_results, shared_ptr<Blob<Dtype> > retrieval_results);
 
 template <typename Dtype>
 int image_retrieval_pipeline(int argc, char** argv);
 
 int main(int argc, char** argv) {
@@ -52,13 +31,14 @@ int main(int argc, char** argv) {
 
 template <typename Dtype>
 int image_retrieval_pipeline(int argc, char** argv) {
-  const int num_required_args = 7;
+  const int num_required_args = 4;
   if (argc < num_required_args) {
     LOG(ERROR)<<
-    "retrieve_image pretrained_net_param extract__feature_blob_name"
-    " sample_images_feature_blob_binaryproto data_prototxt data_layer_name"
-    " save_feature_leveldb_name save_retrieval_result_filename"
-    " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]";
+    "This program takes in binarized features of query images and sample images"
+    " extracted by Caffe to retrieve similar images."
+ "Usage: demo_retrieve_images sample_binary_features_binaryproto_file" + " query_binary_features_binaryproto_file save_retrieval_result_filename" + " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]"; return 1; } int arg_pos = num_required_args; @@ -75,8 +55,8 @@ int image_retrieval_pipeline(int argc, char** argv) { if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) { LOG(ERROR)<< "Using GPU"; uint device_id = 0; - if (argc > arg_pos) { - device_id = atoi(argv[arg_pos]); + if (argc > arg_pos + 1) { + device_id = atoi(argv[arg_pos + 1]); } LOG(ERROR) << "Using Device_id=" << device_id; Caffe::SetDevice(device_id); @@ -90,157 +70,63 @@ int image_retrieval_pipeline(int argc, char** argv) { NetParameter pretrained_net_param; arg_pos = 0; // the name of the executable - // We directly load the net param from trained file - string pretrained_binary_proto(argv[++arg_pos]); - ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(), - &pretrained_net_param); - shared_ptr > feature_extraction_net( - new Net(pretrained_net_param)); - string extract_feature_blob_name(argv[++arg_pos]); - if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) { - LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name << - " in trained network " << pretrained_binary_proto; - return 1; + string sample_binary_features_binaryproto_file(argv[++arg_pos]); + BlobProtoVector sample_binary_features; + ReadProtoFromBinaryFile(sample_binary_features_binaryproto_file, + &sample_binary_features); + vector > > sample_binary_feature_blobs; + int num_samples; + for (int i = 0; i < sample_binary_features.blobs_size(); ++i) { + shared_ptr > blob(new Blob()); + blob->FromProto(sample_binary_features.blobs(i)); + sample_binary_feature_blobs.push_back(blob); + num_samples += blob->num(); } - - string sample_images_feature_blob_binaryproto(argv[++arg_pos]); - BlobProto sample_images_feature_blob_proto; - ReadProtoFromBinaryFile(argv[++arg_pos], &sample_images_feature_blob_proto); - shared_ptr > sample_images_feature_blob(new Blob()); - sample_images_feature_blob->FromProto(sample_images_feature_blob_proto); - - // Expected prototxt contains at least one data layer as the query images. - /* - layers { - layer { - name: "query_images" - type: "data" - source: "/path/to/your/images/to/extract/feature/and/retrieve/similar/images_leveldb" - meanfile: "/path/to/your/image_mean.binaryproto" - batchsize: 128 - cropsize: 115 - mirror: false - } - top: "query_images" - top: "ground_truth_labels" // TODO: Add MultiLabelDataLayer support for image retrieval, annotations etc. 
-    }
-  */
-  string data_prototxt(argv[++arg_pos]);
-  string data_layer_name(argv[++arg_pos]);
-  NetParameter data_net_param;
-  ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
-  LayerParameter data_layer_param;
-  int num_layer;
-  for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
-    if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
-      data_layer_param = data_net_param.layers(num_layer).layer();
-      break;
-    }
-  }
-  if (num_layer = data_net_param.layers_size()) {
-    LOG(ERROR) << "Unknow data layer name " << data_layer_name <<
-    " in prototxt " << data_prototxt;
+  if (top_k_results > num_samples) {
+    top_k_results = num_samples;
   }
-  string save_feature_leveldb_name(argv[++arg_pos]);
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  LOG(INFO) << "Opening leveldb " << argv[3];
-  leveldb::Status status = leveldb::DB::Open(
-      options, save_feature_leveldb_name.c_str(), &db);
-  CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;
+  string query_images_feature_blob_binaryproto(argv[++arg_pos]);
+  BlobProtoVector query_images_features;
+  ReadProtoFromBinaryFile(query_images_feature_blob_binaryproto.c_str(),
+                          &query_images_features);
+  vector<shared_ptr<Blob<Dtype> > > query_binary_feature_blobs;
+  for (int i = 0; i < query_images_features.blobs_size(); ++i) {
+    shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
+    blob->FromProto(query_images_features.blobs(i));
+    query_binary_feature_blobs.push_back(blob);
+  }
 
   string save_retrieval_result_filename(argv[++arg_pos]);
   std::ofstream retrieval_result_ofs(save_retrieval_result_filename.c_str(),
                                      std::ofstream::out);
-  LOG(ERROR)<< "Extacting Features and retrieving images";
-  DataLayer<Dtype> data_layer(data_layer_param);
-  vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
-  vector<Blob<Dtype>*> top_vec;
-  data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
-  int batch_index = 0;
-  shared_ptr<Blob<Dtype> > feature_binary_codes;
-  shared_ptr<Blob<Dtype> > retrieval_results;
+  LOG(ERROR)<< "Retrieving images";
+  shared_ptr<Blob<Dtype> > retrieval_results(new Blob<Dtype>());
   int query_image_index = 0;
-  Datum datum;
-  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
-  const int max_key_str_length = 100;
-  char key_str[max_key_str_length];
   int num_bytes_of_binary_code = sizeof(Dtype);
-  int count_query_images = 0;
-  while (top_vec.size()) {  // data_layer still outputs data
-    LOG(ERROR)<< "Batch " << batch_index << " feature extraction";
-    feature_extraction_net->Forward(top_vec);
-    const shared_ptr<Blob<Dtype> > feature_blob =
-        feature_extraction_net->GetBlob(extract_feature_blob_name);
-    feature_binary_codes.reset(new Blob<Dtype>());
-    binarize(feature_blob, feature_binary_codes);
-
-    LOG(ERROR) << "Batch " << batch_index << " save extracted features";
-    const Dtype* retrieval_results_data = retrieval_results->cpu_data();
-    int num_features = feature_binary_codes->num();
-    int dim_features = feature_binary_codes->count() / num_features;
-    for (int n = 0; n < num_features; ++n) {
-      datum.set_height(dim_features);
-      datum.set_width(1);
-      datum.set_channels(1);
-      datum.clear_data();
-      datum.clear_float_data();
-      string* datum_string = datum.mutable_data();
-      for (int d = 0; d < dim_features; ++d) {
-        const Dtype data = feature_binary_codes->data_at(n, d, 0, 0);
-        const char* data_byte = reinterpret_cast<const char*>(&data);
-        for(int i = 0; i < num_bytes_of_binary_code; ++i) {
-          datum_string->push_back(data_byte[i]);
-        }
-      }
-      string value;
-      datum.SerializeToString(&value);
-      snprintf(key_str, max_key_str_length, "%d",
-               query_image_index);
-      batch->Put(string(key_str), value);
-      if (++count_query_images % 1000 == 0) {
-        db->Write(leveldb::WriteOptions(), batch);
-        LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
-        delete batch;
-        batch = new leveldb::WriteBatch();
-      }
-    }
-    // write the last batch
-    if (count_query_images % 1000 != 0) {
-      db->Write(leveldb::WriteOptions(), batch);
-      LOG(ERROR) << "Extracted features of " << count_query_images << " query images.";
-      delete batch;
-      batch = new leveldb::WriteBatch();
-    }
-
-    LOG(ERROR) << "Batch " << batch_index << " image retrieval";
-    similarity_search(sample_images_feature_blob, feature_binary_codes,
+  int num_query_batches = query_binary_feature_blobs.size();
+  for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
+    LOG(ERROR)<< "Batch " << batch_index << " image retrieval";
+    similarity_search(sample_binary_feature_blobs,
+                      query_binary_feature_blobs[batch_index],
                       top_k_results, retrieval_results);
 
     LOG(ERROR) << "Batch " << batch_index << " save image retrieval results";
     int num_results = retrieval_results->num();
-    int dim_results = retrieval_results->count() / num_results;
+    const Dtype* retrieval_results_data = retrieval_results->cpu_data();
     for (int i = 0; i < num_results; ++i) {
-      retrieval_result_ofs << query_image_index;
-      for (int k = 0; k < dim_results; ++k) {
-        retrieval_result_ofs << " " << retrieval_results->data_at(i, k, 0, 0);
+      retrieval_result_ofs << ++query_image_index;
+      const Dtype* result_row =
+          retrieval_results_data + retrieval_results->offset(i);
+      for (int j = 0; j < top_k_results; ++j) {
+        retrieval_result_ofs << " " << result_row[j];
       }
       retrieval_result_ofs << "\n";
     }
-    ++query_image_index;
-
-    data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
-    ++batch_index;
-  }  // while (top_vec.size()) {
+  }  // for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
 
-  delete batch;
-  delete db;
   retrieval_result_ofs.close();
   LOG(ERROR)<< "Successfully ended!";
   return 0;
@@ -296,36 +182,42 @@ class MinHeapComparison {
 };
 
 template <typename Dtype>
-void similarity_search(const shared_ptr<Blob<Dtype> > sample_images_feature,
-                       const shared_ptr<Blob<Dtype> > query_image_feature,
-                       const int top_k_results,
-                       shared_ptr<Blob<Dtype> > retrieval_results) {
-  int num_samples = sample_images_feature->num();
+void similarity_search(
+    const vector<shared_ptr<Blob<Dtype> > >& sample_images_feature_blobs,
+    const shared_ptr<Blob<Dtype> > query_image_feature, const int top_k_results,
+    shared_ptr<Blob<Dtype> > retrieval_results) {
   int num_queries = query_image_feature->num();
   int dim = query_image_feature->count() / num_queries;
-  retrieval_results->Reshape(num_queries, std::min(num_samples, top_k_results), 1, 1);
-  Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
   int hamming_dist;
+  retrieval_results->Reshape(num_queries, top_k_results, 1, 1);
+  Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data();
   for (int i = 0; i < num_queries; ++i) {
-    std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
-        MinHeapComparison<Dtype> > results;
-    for (int j = 0; j < num_samples; ++j) {
-      hamming_dist = caffe_hamming_distance(
-          dim, query_image_feature->cpu_data() + query_image_feature->offset(i),
-          sample_images_feature->cpu_data() + sample_images_feature->offset(j));
-      if (results.empty()) {
-        results.push(std::make_pair(-hamming_dist, j));
-      } else if (-hamming_dist > results.top().first) {  // smaller hamming dist
-        results.push(std::make_pair(-hamming_dist, j));
-        if (results.size() > top_k_results) {
-          results.pop();
-        }
-      }
-    }  // for (int j = 0; j < num_samples; ++j) {
-    retrieval_results_data += retrieval_results->offset(i);
-    for (int k = 0; k < results.size(); ++k) {
-      retrieval_results_data[k] = results.top().second;
-      results.pop();
-    }
+    std::priority_queue<std::pair<int, int>,
+        std::vector<std::pair<int, int> >, MinHeapComparison<Dtype> > results;
+    // Offset of the current sample blob inside the whole sample set, so that
+    // the stored indices are global across all sample blobs.
+    int sample_index_offset = 0;
+    for (int num_sample_blob = 0;
+         num_sample_blob < sample_images_feature_blobs.size();
+         ++num_sample_blob) {
+      shared_ptr<Blob<Dtype> > sample_images_feature =
+          sample_images_feature_blobs[num_sample_blob];
+      int num_samples = sample_images_feature->num();
+      for (int j = 0; j < num_samples; ++j) {
+        hamming_dist = caffe_hamming_distance(
+            dim,
+            query_image_feature->cpu_data() + query_image_feature->offset(i),
+            sample_images_feature->cpu_data() +
+                sample_images_feature->offset(j));
+        if (results.size() < top_k_results) {
+          results.push(std::make_pair(-hamming_dist, sample_index_offset + j));
+        } else if (-hamming_dist > results.top().first) {  // smaller hamming dist
+          results.pop();
+          results.push(std::make_pair(-hamming_dist, sample_index_offset + j));
        }
+      }  // for (int j = 0; j < num_samples; ++j) {
+      sample_index_offset += num_samples;
+    }  // for (...; num_sample_blob < sample_images_feature_blobs.size(); ...)
+    // The min-heap pops the worst match first, so fill each row back to front
+    // to store the top k results in ascending order of hamming distance.
+    Dtype* result_row = retrieval_results_data + retrieval_results->offset(i);
+    for (int k = results.size() - 1; k >= 0; --k) {
+      result_row[k] = results.top().second;
+      results.pop();
+    }
   }  // for (int i = 0; i < num_queries; ++i) {
 }
-- 
2.7.4
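Note: the two input files named in the usage string above are BlobProtoVector binaryproto files that already hold binarized features, one BlobProto per batch. The following is a minimal, hypothetical sketch of how such a file might be produced, assuming the write-side counterparts (Blob::ToProto, WriteProtoToBinaryFile) of the calls the patch already uses (Blob::FromProto, ReadProtoFromBinaryFile); the file name, blob shape, and all-zero codes are placeholder assumptions, not part of the patch.

// write_binary_features.cpp -- hypothetical helper, for illustration only.
// Packs one batch of binary codes into the BlobProtoVector format that
// demo_retrieve_images reads for both its sample and query inputs.
#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

using namespace caffe;

int main() {
  const int num_images = 2;  // assumed batch size
  const int code_dim = 4;    // assumed number of Dtype words per binary code
  Blob<float> codes(num_images, code_dim, 1, 1);
  float* data = codes.mutable_cpu_data();
  for (int i = 0; i < codes.count(); ++i) {
    data[i] = 0;  // placeholder; real binarized feature values go here
  }
  // One BlobProto per batch; the demo loops over every blob in the vector.
  BlobProtoVector blobs;
  codes.ToProto(blobs.add_blobs());
  WriteProtoToBinaryFile(blobs, "sample_binary_features.binaryproto");
  return 0;
}

With a sample file and a query file written this way, the demo would be invoked roughly as demo_retrieve_images sample_binary_features.binaryproto query_binary_features.binaryproto retrieval_results.txt 5, matching the usage string in the patch.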