Refactor the accuracy layer with std::partial_sort
authorKai Li <kaili_kloud@163.com>
Thu, 10 Jul 2014 01:37:34 +0000 (09:37 +0800)
committerKai Li <kaili_kloud@163.com>
Sat, 19 Jul 2014 16:26:41 +0000 (00:26 +0800)
src/caffe/layers/accuracy_layer.cpp

index ddfe38a..86ded06 100644 (file)
@@ -30,6 +30,12 @@ void AccuracyLayer<Dtype>::SetUp(
   (*top)[0]->Reshape(1, 1, 1, 1);
 }
 
+template<typename Dtype>
+bool int_Dtype_pair_greater(std::pair<int, Dtype> a,
+                            std::pair<int, Dtype> b) {
+  return a.second > b.second;
+}
+
 template <typename Dtype>
 Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
@@ -42,21 +48,17 @@ Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   vector<int> max_id(top_k_+1);
   for (int i = 0; i < num; ++i) {
     // Top-k accuracy
-    std::fill_n(maxval.begin(), top_k_, -FLT_MAX);
-    std::fill_n(max_id.begin(), top_k_, 0);
-    for (int j = 0, k; j < dim; ++j) {
-      // insert into (reverse-)sorted top-k array
-      Dtype val = bottom_data[i * dim + j];
-      for (k = top_k_; k > 0 && maxval[k-1] < val; k--) {
-        maxval[k] = maxval[k-1];
-        max_id[k] = max_id[k-1];
-      }
-      maxval[k] = val;
-      max_id[k] = j;
+    std::vector<std::pair<int, Dtype> > bottom_data_vector;
+    for (int j = 0; j < dim; ++j) {
+      bottom_data_vector.push_back(
+          std::make_pair(j, bottom_data[i * dim + j]));
     }
+    std::partial_sort(
+        bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
+        bottom_data_vector.end(), int_Dtype_pair_greater<Dtype>);
     // check if true label is in top k predictions
     for (int k = 0; k < top_k_; k++)
-      if (max_id[k] == static_cast<int>(bottom_label[i])) {
+      if (bottom_data_vector[k].first == static_cast<int>(bottom_label[i])) {
         ++accuracy;
         break;
       }