Save and load data correctly in feature extraction, binarization and IR demo
author Kai Li <kaili_kloud@163.com>
Wed, 26 Feb 2014 01:30:35 +0000 (09:30 +0800)
committer Kai Li <kaili_kloud@163.com>
Wed, 19 Mar 2014 15:04:42 +0000 (23:04 +0800)
examples/demo_binarize_features.cpp
examples/demo_extract_features.cpp
examples/demo_retrieve_images.cpp

index 5345a26..74a389c 100644 (file)
@@ -97,6 +97,7 @@ int features_binarization_pipeline(int argc, char** argv) {
   }
   shared_ptr<Blob<Dtype> > feature_binary_codes(new Blob<Dtype>());
   binarize<Dtype>(feature_blob_vector, feature_binary_codes);
+
   BlobProto blob_proto;
   feature_binary_codes->ToProto(&blob_proto);
   WriteProtoToBinaryFile(blob_proto, save_binarized_feature_binaryproto_file);
@@ -125,13 +126,14 @@ void binarize(const vector<shared_ptr<Blob<Dtype> > >& feature_blob_vector,
   int size_of_code = sizeof(Dtype) * 8;
   binary_codes->Reshape(num_features, (dim + size_of_code - 1) / size_of_code,
                         1, 1);
-  Dtype* binary_data = binary_codes->mutable_cpu_data();
-  int offset;
   uint64_t code;
+  count = 0;
   for (int i = 0; i < feature_blob_vector.size(); ++i) {
-    const Dtype* data = feature_blob_vector[i]->cpu_data();
     for (int j = 0; j < feature_blob_vector[i]->num(); ++j) {
-      offset = j * dim;
+      const Dtype* data = feature_blob_vector[i]->cpu_data()
+          + feature_blob_vector[i]->offset(j);
+      Dtype* binary_data = binary_codes->mutable_cpu_data()
+          + binary_codes->offset(count++);
       code = 0;
       int k;
       for (k = 0; k < dim;) {
index 088cc28..32bb728 100644 (file)
@@ -131,13 +131,14 @@ int feature_extraction_pipeline(int argc, char** argv) {
         ->GetBlob(extract_feature_blob_name);
     int num_features = feature_blob->num();
     int dim_features = feature_blob->count() / num_features;
+    Dtype* feature_blob_data;
     for (int n = 0; n < num_features; ++n) {
       datum.set_height(dim_features);
       datum.set_width(1);
       datum.set_channels(1);
       datum.clear_data();
       datum.clear_float_data();
-      const Dtype* feature_blob_data = feature_blob->cpu_data();
+      feature_blob_data = feature_blob->mutable_cpu_data() + feature_blob->offset(n);
       for (int d = 0; d < dim_features; ++d) {
         datum.add_float_data(feature_blob_data[d]);
       }
@@ -152,7 +153,7 @@ int feature_extraction_pipeline(int argc, char** argv) {
         delete batch;
         batch = new leveldb::WriteBatch();
       }
-    }
+    }  // for (int n = 0; n < num_features; ++n)
   }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
   // write the last batch
   if (image_index % 1000 != 0) {
index 2c16824..f339981 100644 (file)
@@ -17,9 +17,10 @@ using namespace caffe;
 
 template<typename Dtype>
 void similarity_search(
-    const vector<shared_ptr<Blob<Dtype> > >& sample_binary_feature_blobs,
+    const shared_ptr<Blob<Dtype> > sample_binary_feature_blobs,
     const shared_ptr<Blob<Dtype> > query_binary_feature,
-    const int top_k_results, vector<vector<Dtype> >* retrieval_results);
+    const int top_k_results,
+    vector<vector<std::pair<int, int> > >* retrieval_results);
 
 template<typename Dtype>
 int image_retrieval_pipeline(int argc, char** argv);
@@ -71,32 +72,23 @@ int image_retrieval_pipeline(int argc, char** argv) {
 
   LOG(ERROR)<< "Loading sample binary features";
   string sample_binary_features_binaryproto_file(argv[++arg_pos]);
-  BlobProtoVector sample_binary_features;
+  BlobProto sample_binary_features;
   ReadProtoFromBinaryFile(sample_binary_features_binaryproto_file,
                           &sample_binary_features);
-  vector<shared_ptr<Blob<Dtype> > > sample_binary_feature_blobs;
-  int num_samples;
-  for (int i = 0; i < sample_binary_features.blobs_size(); ++i) {
-    shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
-    blob->FromProto(sample_binary_features.blobs(i));
-    sample_binary_feature_blobs.push_back(blob);
-    num_samples += blob->num();
-  }
+  shared_ptr<Blob<Dtype> > sample_binary_feature_blob(new Blob<Dtype>());
+  sample_binary_feature_blob->FromProto(sample_binary_features);
+  int num_samples = sample_binary_feature_blob->num();
   if (top_k_results > num_samples) {
     top_k_results = num_samples;
   }
 
   LOG(ERROR)<< "Loading query binary features";
   string query_images_feature_blob_binaryproto(argv[++arg_pos]);
-  BlobProtoVector query_images_features;
+  BlobProto query_images_features;
   ReadProtoFromBinaryFile(query_images_feature_blob_binaryproto,
                           &query_images_features);
-  vector<shared_ptr<Blob<Dtype> > > query_binary_feature_blobs;
-  for (int i = 0; i < query_images_features.blobs_size(); ++i) {
-    shared_ptr<Blob<Dtype> > blob(new Blob<Dtype>());
-    blob->FromProto(query_images_features.blobs(i));
-    query_binary_feature_blobs.push_back(blob);
-  }
+  shared_ptr<Blob<Dtype> > query_binary_feature_blob(new Blob<Dtype>());
+  query_binary_feature_blob->FromProto(query_images_features);
 
   string save_retrieval_result_filename(argv[++arg_pos]);
   LOG(ERROR)<< "Opening result file " << save_retrieval_result_filename;
@@ -104,26 +96,24 @@ int image_retrieval_pipeline(int argc, char** argv) {
                                      std::ofstream::out);
 
   LOG(ERROR)<< "Retrieving images";
-  vector<vector<Dtype> > retrieval_results;
+  vector<vector<std::pair<int, int> > > retrieval_results;
   int query_image_index = 0;
 
-  int num_query_batches = query_binary_feature_blobs.size();
-  for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
-    similarity_search<Dtype>(sample_binary_feature_blobs,
-                             query_binary_feature_blobs[batch_index],
-                             top_k_results, &retrieval_results);
-    int num_results = retrieval_results.size();
-    for (int i = 0; i < num_results; ++i) {
-      retrieval_result_ofs << query_image_index++;
-      for (int j = 0; j < retrieval_results[i].size(); ++j) {
-        retrieval_result_ofs << " " << retrieval_results[i][j];
-      }
-      retrieval_result_ofs << "\n";
+  similarity_search<Dtype>(sample_binary_feature_blob,
+                           query_binary_feature_blob, top_k_results,
+                           &retrieval_results);
+  int num_results = retrieval_results.size();
+  for (int i = 0; i < num_results; ++i) {
+    retrieval_result_ofs << query_image_index++;
+    for (int j = 0; j < retrieval_results[i].size(); ++j) {
+      retrieval_result_ofs << " " << retrieval_results[i][j].first << ":"
+                           << retrieval_results[i][j].second;
     }
-  }  //  for (int batch_index = 0; batch_index < num_query_batches; ++batch_index) {
+    retrieval_result_ofs << "\n";
+  }
 
   retrieval_result_ofs.close();
-  LOG(ERROR)<< "Successfully retrieved similar images for " << query_image_index << " queries!";
+  LOG(ERROR)<< "Successfully retrieved similar images for " << num_results << " queries!";
   return 0;
 }
 
@@ -137,12 +127,16 @@ class MinHeapComparison {
 
 template<typename Dtype>
 void similarity_search(
-    const vector<shared_ptr<Blob<Dtype> > >& sample_images_feature_blobs,
-    const shared_ptr<Blob<Dtype> > query_image_feature, const int top_k_results,
-    vector<vector<Dtype> >* retrieval_results) {
-  int num_queries = query_image_feature->num();
-  int dim = query_image_feature->count() / num_queries;
+    const shared_ptr<Blob<Dtype> > sample_images_feature_blob,
+    const shared_ptr<Blob<Dtype> > query_binary_feature_blob,
+    const int top_k_results,
+    vector<vector<std::pair<int, int> > >* retrieval_results) {
+  int num_samples = sample_images_feature_blob->num();
+  int num_queries = query_binary_feature_blob->num();
+  int dim = query_binary_feature_blob->count() / num_queries;
+  LOG(ERROR)<< "num_samples " << num_samples << ", num_queries " << num_queries << ", dim " << dim;
   int hamming_dist;
+  int neighbor_index;
   retrieval_results->resize(num_queries);
   std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int> >,
       MinHeapComparison> results;
@@ -150,25 +144,24 @@ void similarity_search(
     while (!results.empty()) {
       results.pop();
     }
-    for (int j = 0; j < sample_images_feature_blobs.size(); ++j) {
-      int num_samples = sample_images_feature_blobs[j]->num();
-      for (int k = 0; k < num_samples; ++k) {
-        hamming_dist = caffe_hamming_distance(
-            dim,
-            query_image_feature->cpu_data() + query_image_feature->offset(i),
-            sample_images_feature_blobs[j]->cpu_data()
-                + sample_images_feature_blobs[j]->offset(k));
-        if (results.size() < top_k_results) {
-          results.push(std::make_pair(-hamming_dist, k));
-        } else if (-hamming_dist > results.top().first) {  // smaller hamming dist
-          results.pop();
-          results.push(std::make_pair(-hamming_dist, k));
-        }
-      }  // for (int k = 0; k < num_samples; ++k) {
-    }  // for (int j = 0; j < sample_images_feature_blobs.size(); ++j)
+    const Dtype* query_data = query_binary_feature_blob->cpu_data()
+        + query_binary_feature_blob->offset(i);
+    for (int k = 0; k < num_samples; ++k) {
+      const Dtype* sample_data = sample_images_feature_blob->cpu_data()
+          + sample_images_feature_blob->offset(k);
+      hamming_dist = caffe_hamming_distance(dim, query_data, sample_data);
+      if (results.size() < top_k_results) {
+        results.push(std::make_pair(-hamming_dist, k));
+      } else if (-hamming_dist > results.top().first) {  // smaller hamming dist, nearer neighbor
+        results.pop();
+        results.push(std::make_pair(-hamming_dist, k));
+      }
+    }  // for (int k = 0; k < num_samples; ++k) {
     retrieval_results->at(i).resize(results.size());
     for (int k = results.size() - 1; k >= 0; --k) {
-      retrieval_results->at(i)[k] = results.top().second;
+      hamming_dist = -results.top().first;
+      neighbor_index = results.top().second;
+      retrieval_results->at(i)[k] = std::make_pair<int, int>(neighbor_index, hamming_dist);
       results.pop();
     }
   }  // for (int i = 0; i < num_queries; ++i) {