From: Kai Li Date: Sun, 23 Feb 2014 11:56:57 +0000 (+0800) Subject: Add image retrieval example X-Git-Tag: submit/tizen/20180823.020014~692^2~88^2~21 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e76f7dc00cc7ad119b6ae8558ec80dc6f36edb95;p=platform%2Fupstream%2Fcaffeonacl.git Add image retrieval example --- diff --git a/examples/demo_retrieve_images.cpp b/examples/demo_retrieve_images.cpp new file mode 100644 index 0000000..5cfbdea --- /dev/null +++ b/examples/demo_retrieve_images.cpp @@ -0,0 +1,331 @@ +// Copyright 2014 kloudkl@github +// +// This program takes in a trained network and an input blob, and then +// extract features of the input blobs produced by the net to retrieve similar images. +// Usage: +// retrieve_image pretrained_net_param input_blob output_filename top_k_results [CPU/GPU] [DEVICE_ID=0] + +#include // for snprintf +#include // for std::ofstream +#include // for std::priority_queue +#include +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/vision_layers.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" + +using namespace caffe; + +template +inline int sign(const Dtype val) { + return (Dtype(0) < val) - (val < Dtype(0)); +} + +template +void binarize(const int n, const Dtype* real_valued_feature, + Dtype* binary_code); + +template +void binarize(const shared_ptr > real_valued_features, + shared_ptr > binary_codes); + +template +void similarity_search(const shared_ptr > sample_images_feature, + const shared_ptr > query_image_feature, + const int top_k_results, + shared_ptr > retrieval_results); + +template +int image_retrieval_pipeline(int argc, char** argv); + +int main(int argc, char** argv) { + return image_retrieval_pipeline(argc, argv); +// return image_retrieval_pipeline(argc, argv); +} + +template +int image_retrieval_pipeline(int argc, char** argv) { + const int num_required_args = 7; + if (argc < num_required_args) { + LOG(ERROR)<< + "retrieve_image pretrained_net_param extract__feature_blob_name" + " sample_images_feature_blob_binaryproto data_prototxt data_layer_name" + " save_feature_leveldb_name save_retrieval_result_filename" + " [top_k_results=1] [CPU/GPU] [DEVICE_ID=0]"; + return 1; + } + int arg_pos = num_required_args; + + int top_k_results; + if (argc <= num_required_args) { + top_k_results = 1; + } else { + top_k_results = atoi(argv[arg_pos]); + CHECK_GE(top_k_results, 0); + } + + arg_pos = num_required_args + 1; + if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) { + LOG(ERROR)<< "Using GPU"; + uint device_id = 0; + if (argc > arg_pos) { + device_id = atoi(argv[arg_pos]); + } + LOG(ERROR) << "Using Device_id=" << device_id; + Caffe::SetDevice(device_id); + Caffe::set_mode(Caffe::GPU); + } else { + LOG(ERROR) << "Using CPU"; + Caffe::set_mode(Caffe::CPU); + } + Caffe::set_phase(Caffe::TEST); + + NetParameter pretrained_net_param; + + arg_pos = 0; // the name of the executable + // We directly load the net param from trained file + string pretrained_binary_proto(argv[++arg_pos]); + ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(), + &pretrained_net_param); + shared_ptr > feature_extraction_net( + new Net(pretrained_net_param)); + + string extract_feature_blob_name(argv[++arg_pos]); + if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) { + LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name << + " in trained network " << pretrained_binary_proto; + return 1; + } + + string sample_images_feature_blob_binaryproto(argv[++arg_pos]); + BlobProto sample_images_feature_blob_proto; + ReadProtoFromBinaryFile(argv[++arg_pos], &sample_images_feature_blob_proto); + shared_ptr > sample_images_feature_blob(new Blob()); + sample_images_feature_blob->FromProto(sample_images_feature_blob_proto); + + // Expected prototxt contains at least one data layer as the query images. + /* + layers { + layer { + name: "query_images" + type: "data" + source: "/path/to/your/images/to/extract/feature/and/retrieve/similar/images_leveldb" + meanfile: "/path/to/your/image_mean.binaryproto" + batchsize: 128 + cropsize: 115 + mirror: false + } + top: "query_images" + top: "ground_truth_labels" // TODO: Add MultiLabelDataLayer support for image retrieval, annotations etc. + } + */ + string data_prototxt(argv[++arg_pos]); + string data_layer_name(argv[++arg_pos]); + NetParameter data_net_param; + ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param); + LayerParameter data_layer_param; + int num_layer; + for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) { + if (data_layer_name == data_net_param.layers(num_layer).layer().name()) { + data_layer_param = data_net_param.layers(num_layer).layer(); + break; + } + } + if (num_layer = data_net_param.layers_size()) { + LOG(ERROR) << "Unknow data layer name " << data_layer_name << + " in prototxt " << data_prototxt; + } + + string save_feature_leveldb_name(argv[++arg_pos]); + leveldb::DB* db; + leveldb::Options options; + options.error_if_exists = true; + options.create_if_missing = true; + options.write_buffer_size = 268435456; + LOG(INFO) << "Opening leveldb " << argv[3]; + leveldb::Status status = leveldb::DB::Open( + options, save_feature_leveldb_name.c_str(), &db); + CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name; + + string save_retrieval_result_filename(argv[++arg_pos]); + std::ofstream retrieval_result_ofs(save_retrieval_result_filename.c_str(), + std::ofstream::out); + + LOG(ERROR)<< "Extacting Features and retrieving images"; + DataLayer data_layer(data_layer_param); + vector*> bottom_vec_that_data_layer_does_not_need_; + vector*> top_vec; + data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec); + int batch_index = 0; + shared_ptr > feature_binary_codes; + shared_ptr > retrieval_results; + int query_image_index = 0; + + Datum datum; + leveldb::WriteBatch* batch = new leveldb::WriteBatch(); + const int max_key_str_length = 100; + char key_str[max_key_str_length]; + int num_bytes_of_binary_code = sizeof(Dtype); + int count_query_images = 0; + while (top_vec.size()) { // data_layer still outputs data + LOG(ERROR)<< "Batch " << batch_index << " feature extraction"; + feature_extraction_net->Forward(top_vec); + const shared_ptr > feature_blob = + feature_extraction_net->GetBlob(extract_feature_blob_name); + feature_binary_codes.reset(new Blob()); + binarize(feature_blob, feature_binary_codes); + + LOG(ERROR) << "Batch " << batch_index << " save extracted features"; + const Dtype* retrieval_results_data = retrieval_results->cpu_data(); + int num_features = feature_binary_codes->num(); + int dim_features = feature_binary_codes->count() / num_features; + for (int n = 0; n < num_features; ++n) { + datum.set_height(dim_features); + datum.set_width(1); + datum.set_channels(1); + datum.clear_data(); + datum.clear_float_data(); + string* datum_string = datum.mutable_data(); + for (int d = 0; d < dim_features; ++d) { + const Dtype data = feature_binary_codes->data_at(n, d, 0, 0); + const char* data_byte = reinterpret_cast(&data); + for(int i = 0; i < num_bytes_of_binary_code; ++i) { + datum_string->push_back(data_byte[i]); + } + } + string value; + datum.SerializeToString(&value); + snprintf(key_str, max_key_str_length, "%d", query_image_index); + batch->Put(string(key_str), value); + if (++count_query_images % 1000 == 0) { + db->Write(leveldb::WriteOptions(), batch); + LOG(ERROR) << "Extracted features of " << count_query_images << " query images."; + delete batch; + batch = new leveldb::WriteBatch(); + } + } + // write the last batch + if (count_query_images % 1000 != 0) { + db->Write(leveldb::WriteOptions(), batch); + LOG(ERROR) << "Extracted features of " << count_query_images << " query images."; + delete batch; + batch = new leveldb::WriteBatch(); + } + + LOG(ERROR) << "Batch " << batch_index << " image retrieval"; + similarity_search(sample_images_feature_blob, feature_binary_codes, + top_k_results, retrieval_results); + + LOG(ERROR) << "Batch " << batch_index << " save image retrieval results"; + int num_results = retrieval_results->num(); + int dim_results = retrieval_results->count() / num_results; + for (int i = 0; i < num_results; ++i) { + retrieval_result_ofs << query_image_index; + for (int k = 0; k < dim_results; ++k) { + retrieval_result_ofs << " " << retrieval_results->data_at(i, k, 0, 0); + } + retrieval_result_ofs << "\n"; + } + ++query_image_index; + + data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec); + ++batch_index; + } // while (top_vec.size()) { + + delete batch; + delete db; + retrieval_result_ofs.close(); + LOG(ERROR)<< "Successfully ended!"; + return 0; +} + +template +void binarize(const int n, const Dtype* real_valued_feature, + Dtype* binary_codes) { + // TODO: more advanced binarization algorithm such as bilinear projection + // Yunchao Gong, Sanjiv Kumar, Henry A. Rowley, and Svetlana Lazebnik. + // Learning Binary Codes for High-Dimensional Data Using Bilinear Projections. + // In IEEE International Conference on Computer Vision and Pattern Recognition (CVPR), 2013. + // http://www.unc.edu/~yunchao/bpbc.htm + int size_of_code = sizeof(Dtype) * 8; + CHECK_EQ(n % size_of_code, 0); + int num_binary_codes = n / size_of_code; + uint64_t code; + int offset; + for (int i = 0; i < num_binary_codes; ++i) { + code = 0; + offset = i * size_of_code; + for (int j = 0; j < size_of_code; ++j) { + code |= sign(real_valued_feature[offset + j]); + code << 1; + } + binary_codes[i] = static_cast(code); + } +} + +template +void binarize(const shared_ptr > real_valued_features, + shared_ptr > binary_codes) { + int num = real_valued_features->num(); + int dim = real_valued_features->count() / num; + int size_of_code = sizeof(Dtype) * 8; + CHECK_EQ(dim % size_of_code, 0); + binary_codes->Reshape(num, dim / size_of_code, 1, 1); + const Dtype* real_valued_features_data = real_valued_features->cpu_data(); + Dtype* binary_codes_data = binary_codes->mutable_cpu_data(); + for (int n = 0; n < num; ++n) { + binarize(dim, + real_valued_features_data + real_valued_features->offset(n), + binary_codes_data + binary_codes->offset(n)); + } +} + +class MinHeapComparison { + public: + bool operator()(const std::pair& lhs, + const std::pair&rhs) const { + return (lhs.first > rhs.first); + } +}; + +template +void similarity_search(const shared_ptr > sample_images_feature, + const shared_ptr > query_image_feature, + const int top_k_results, + shared_ptr > retrieval_results) { + int num_samples = sample_images_feature->num(); + int num_queries = query_image_feature->num(); + int dim = query_image_feature->count() / num_queries; + retrieval_results->Reshape(num_queries, std::min(num_samples, top_k_results), 1, 1); + Dtype* retrieval_results_data = retrieval_results->mutable_cpu_data(); + int hamming_dist; + for (int i = 0; i < num_queries; ++i) { + std::priority_queue, std::vector >, + MinHeapComparison> results; + for (int j = 0; j < num_samples; ++j) { + hamming_dist = caffe_hamming_distance( + dim, query_image_feature->cpu_data() + query_image_feature->offset(i), + sample_images_feature->cpu_data() + sample_images_feature->offset(j)); + if (results.empty()) { + results.push(std::make_pair(-hamming_dist, j)); + } else if (-hamming_dist > results.top().first) { // smaller hamming dist + results.push(std::make_pair(-hamming_dist, j)); + if (results.size() > top_k_results) { + results.pop(); + } + } + } // for (int j = 0; j < num_samples; ++j) { + retrieval_results_data += retrieval_results->offset(i); + for (int k = 0; k < results.size(); ++k) { + retrieval_results_data[k] = results.top().second; + results.pop(); + } + } // for (int i = 0; i < num_queries; ++i) { +}