From 5a887fc30fba742501ad395165d3f0968738db04 Mon Sep 17 00:00:00 2001 From: Rob Hess Date: Mon, 16 Jun 2014 13:35:46 -0700 Subject: [PATCH] Compute top-k accuracy in AccuracyLayer. --- src/caffe/layers/accuracy_layer.cpp | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp index 55f620d..a284f6b 100644 --- a/src/caffe/layers/accuracy_layer.cpp +++ b/src/caffe/layers/accuracy_layer.cpp @@ -22,7 +22,7 @@ void AccuracyLayer::SetUp( CHECK_EQ(bottom[0]->num(), bottom[1]->num()) << "The data and label should have the same number."; CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) - << "top_k must be less than the number of classes."; + << "top_k must be less than or equal to the number of classes."; CHECK_EQ(bottom[1]->channels(), 1); CHECK_EQ(bottom[1]->height(), 1); CHECK_EQ(bottom[1]->width(), 1); @@ -37,20 +37,32 @@ Dtype AccuracyLayer::Forward_cpu(const vector*>& bottom, const Dtype* bottom_label = bottom[1]->cpu_data(); int num = bottom[0]->num(); int dim = bottom[0]->count() / bottom[0]->num(); + Dtype* maxval = new Dtype[top_k_+1]; + int* max_id = new int[top_k_+1]; for (int i = 0; i < num; ++i) { - // Accuracy - Dtype maxval = -FLT_MAX; - int max_id = 0; - for (int j = 0; j < dim; ++j) { - if (bottom_data[i * dim + j] > maxval) { - maxval = bottom_data[i * dim + j]; - max_id = j; + // Top-k accuracy + std::fill_n(maxval, top_k_, -FLT_MAX); + std::fill_n(max_id, top_k_, 0); + for (int j = 0, k; j < dim; ++j) { + // insert into (reverse-)sorted top-k array + Dtype val = bottom_data[i * dim + j]; + for (k = top_k_; k > 0 && maxval[k-1] < val; k--) { + maxval[k] = maxval[k-1]; + max_id[k] = max_id[k-1]; } + maxval[k] = val; + max_id[k] = j; } - if (max_id == static_cast(bottom_label[i])) { - ++accuracy; - } + // check if true label is in top k predictions + for (int k = 0; k < top_k_; k++) + if (max_id[k] == static_cast(bottom_label[i])) { + ++accuracy; + break; + } } + delete[] maxval; + delete[] max_id; + // LOG(INFO) << "Accuracy: " << accuracy; (*top)[0]->mutable_cpu_data()[0] = accuracy / num; -- 2.7.4