template<typename Dtype>
void similarity_search(
- const vector<shared_ptr<Blob<Dtype> > >& sample_binary_feature_blobs,
+ const shared_ptr<Blob<Dtype> > sample_binary_feature_blobs,
const shared_ptr<Blob<Dtype> > query_binary_feature,
- const int top_k_results, vector<vector<Dtype> >* retrieval_results);
+ const int top_k_results,
+ vector<vector<std::pair<int, int> > >* retrieval_results);
template<typename Dtype>
int image_retrieval_pipeline(int argc, char** argv);
LOG(ERROR)<< "Loading sample binary features";
string sample_binary_features_binaryproto_file(argv[++arg_pos]);
- BlobProtoVector sample_binary_features;
+ BlobProto sample_binary_features;
ReadProtoFromBinaryFile(sample_binary_features_binaryproto_file,
&sample_binary_features);
- vector<shared_ptr<Blob<Dtype> > > sample_binary_feature_blobs;
- int num_samples;
- for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
- shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
- blob->FromProto(sample_binary_features.blobs(i));
- sample_binary_feature_blobs.push_back(blob);
- num_samples += blob->num();
- }
+ shared_ptr<Blob<Dtype> > sample_binary_feature_blob(new Blob<Dtype>());
+ sample_binary_feature_blob->FromProto(sample_binary_features);
+ int num_samples = sample_binary_feature_blob->num();
if (top_k_results > num_samples) {
top_k_results = num_samples;
}
LOG(ERROR)<< "Loading query binary features";
string query_images_feature_blob_binaryproto(argv[++arg_pos]);
- BlobProtoVector query_images_features;
+ BlobProto query_images_features;
ReadProtoFromBinaryFile(query_images_feature_blob_binaryproto,
&query_images_features);
- vector<shared_ptr<Blob<Dtype> > > query_binary_feature_blobs;
- for (int i = 0; i < query_images_features.blobs_size(); ++i) {
- shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
- blob->FromProto(query_images_features.blobs(i));
- query_binary_feature_blobs.push_back(blob);
- }
+ shared_ptr<Blob<Dtype> > query_binary_feature_blob(new Blob<Dtype>());
+ query_binary_feature_blob->FromProto(query_images_features);
string save_retrieval_result_filename(argv[++arg_pos]);
LOG(ERROR)<< "Opening result file " << save_retrieval_result_filename;
std::ofstream::out);
LOG(ERROR)<< "Retrieving images";
- vector<vector<Dtype> > retrieval_results;
+ vector<vector<std::pair<int, int> > > retrieval_results;
int query_image_index = 0;
- int num_query_batches = query_binary_feature_blobs.size();
- for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
- similarity_search<Dtype>(sample_binary_feature_blobs,
- query_binary_feature_blobs[batch_index],
- top_k_results, &retrieval_results);
- int num_results = retrieval_results.size();
- for (int i = 0; i < num_results; ++i) {
- retrieval_result_ofs << query_image_index++;
- for (int j = 0; j < retrieval_results[i].size(); ++j) {
- retrieval_result_ofs << " " << retrieval_results[i][j];
- }
- retrieval_result_ofs << "\n";
+ similarity_search<Dtype>(sample_binary_feature_blob,
+ query_binary_feature_blob, top_k_results,
+ &retrieval_results);
+ int num_results = retrieval_results.size();
+ for (int i = 0; i < num_results; ++i) {
+ retrieval_result_ofs << query_image_index++;
+ for (int j = 0; j < retrieval_results[i].size(); ++j) {
+ retrieval_result_ofs << " " << retrieval_results[i][j].first << ":"
+ << retrieval_results[i][j].second;
}
- } // for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
+ retrieval_result_ofs << "\n";
+ }
retrieval_result_ofs.close();
- LOG(ERROR)<< "Successfully retrieved similar images for " << query_image_index << " queries!";
+ LOG(ERROR)<< "Successfully retrieved similar images for " << num_results << " queries!";
return 0;
}
template<typename Dtype>
void similarity_search(
- const vector<shared_ptr<Blob<Dtype> > >& sample_images_feature_blobs,
- const shared_ptr<Blob<Dtype> > query_image_feature, const int top_k_results,
- vector<vector<Dtype> >* retrieval_results) {
- int num_queries = query_image_feature->num();
- int dim = query_image_feature->count() / num_queries;
+ const shared_ptr<Blob<Dtype> > sample_images_feature_blob,
+ const shared_ptr<Blob<Dtype> > query_binary_feature_blob,
+ const int top_k_results,
+ vector<vector<std::pair<int, int> > >* retrieval_results) {
+ int num_samples = sample_images_feature_blob->num();
+ int num_queries = query_binary_feature_blob->num();
+ int dim = query_binary_feature_blob->count() / num_queries;
+ LOG(ERROR)<< "num_samples " << num_samples << ", num_queries " << num_queries << ", dim " << dim;
int hamming_dist;
+ int neighbor_index;
retrieval_results->resize(num_queries);
std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
MinHeapComparison> results;
while (!results.empty()) {
results.pop();
}
- for (int j = 0; j < sample_images_feature_blobs.size(); ++j) {
- int num_samples = sample_images_feature_blobs[j]->num();
- for (int k = 0; k < num_samples; ++k) {
- hamming_dist = caffe_hamming_distance(
- dim,
- query_image_feature->cpu_data() + query_image_feature->offset(i),
- sample_images_feature_blobs[j]->cpu_data()
- + sample_images_feature_blobs[j]->offset(k));
- if (results.size() < top_k_results) {
- results.push(std::make_pair(-hamming_dist, k));
- } else if (-hamming_dist > results.top().first) { // smaller hamming dist
- results.pop();
- results.push(std::make_pair(-hamming_dist, k));
- }
- } // for (int k = 0; k < num_samples; ++k) {
- } // for (int j = 0; j < sample_images_feature_blobs.size(); ++j)
+ const Dtype* query_data = query_binary_feature_blob->cpu_data()
+ + query_binary_feature_blob->offset(i);
+ for (int k = 0; k < num_samples; ++k) {
+ const Dtype* sample_data = sample_images_feature_blob->cpu_data()
+ + sample_images_feature_blob->offset(k);
+ hamming_dist = caffe_hamming_distance(dim, query_data, sample_data);
+ if (results.size() < top_k_results) {
+ results.push(std::make_pair(-hamming_dist, k));
+ } else if (-hamming_dist > results.top().first) { // smaller hamming dist, nearer neighbor
+ results.pop();
+ results.push(std::make_pair(-hamming_dist, k));
+ }
+ } // for (int k = 0; k < num_samples; ++k) {
retrieval_results->at(i).resize(results.size());
for (int k = results.size() - 1; k >= 0; --k) {
- retrieval_results->at(i)[k] = results.top().second;
+ hamming_dist = -results.top().first;
+ neighbor_index = results.top().second;
+ retrieval_results->at(i)[k] = std::make_pair<int, int>(neighbor_index, hamming_dist);
results.pop();
}
} // for (int i = 0; i < num_queries; ++i) {