Compute top-k accuracy in AccuracyLayer.
authorRob Hess <hess@yahoo-inc.com>
Mon, 16 Jun 2014 20:35:46 +0000 (13:35 -0700)
committerRob Hess <hess@yahoo-inc.com>
Fri, 27 Jun 2014 18:19:49 +0000 (11:19 -0700)
src/caffe/layers/accuracy_layer.cpp

index 55f620d..a284f6b 100644 (file)
@@ -22,7 +22,7 @@ void AccuracyLayer<Dtype>::SetUp(
   CHECK_EQ(bottom[0]->num(), bottom[1]->num())
       << "The data and label should have the same number.";
   CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
-      << "top_k must be less than the number of classes.";
+      << "top_k must be less than or equal to the number of classes.";
   CHECK_EQ(bottom[1]->channels(), 1);
   CHECK_EQ(bottom[1]->height(), 1);
   CHECK_EQ(bottom[1]->width(), 1);
@@ -37,20 +37,32 @@ Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_label = bottom[1]->cpu_data();
   int num = bottom[0]->num();
   int dim = bottom[0]->count() / bottom[0]->num();
+  Dtype* maxval = new Dtype[top_k_+1];
+  int* max_id = new int[top_k_+1];
   for (int i = 0; i < num; ++i) {
-    // Accuracy
-    Dtype maxval = -FLT_MAX;
-    int max_id = 0;
-    for (int j = 0; j < dim; ++j) {
-      if (bottom_data[i * dim + j] > maxval) {
-        maxval = bottom_data[i * dim + j];
-        max_id = j;
+    // Top-k accuracy
+    std::fill_n(maxval, top_k_, -FLT_MAX);
+    std::fill_n(max_id, top_k_, 0);
+    for (int j = 0, k; j < dim; ++j) {
+      // insert into (reverse-)sorted top-k array
+      Dtype val = bottom_data[i * dim + j];
+      for (k = top_k_; k > 0 && maxval[k-1] < val; k--) {
+        maxval[k] = maxval[k-1];
+        max_id[k] = max_id[k-1];
       }
+      maxval[k] = val;
+      max_id[k] = j;
     }
-    if (max_id == static_cast<int>(bottom_label[i])) {
-      ++accuracy;
-    }
+    // check if true label is in top k predictions
+    for (int k = 0; k < top_k_; k++)
+      if (max_id[k] == static_cast<int>(bottom_label[i])) {
+        ++accuracy;
+        break;
+      }
   }
+  delete[] maxval;
+  delete[] max_id;
+
   // LOG(INFO) << "Accuracy: " << accuracy;
   (*top)[0]->mutable_cpu_data()[0] = accuracy / num;