CHECK_EQ(bottom[0]->num(), bottom[1]->num())
<< "The data and label should have the same number.";
CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
- << "top_k must be less than the number of classes.";
+ << "top_k must be less than or equal to the number of classes.";
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
const Dtype* bottom_label = bottom[1]->cpu_data();
int num = bottom[0]->num();
int dim = bottom[0]->count() / bottom[0]->num();
+ Dtype* maxval = new Dtype[top_k_+1];
+ int* max_id = new int[top_k_+1];
for (int i = 0; i < num; ++i) {
- // Accuracy
- Dtype maxval = -FLT_MAX;
- int max_id = 0;
- for (int j = 0; j < dim; ++j) {
- if (bottom_data[i * dim + j] > maxval) {
- maxval = bottom_data[i * dim + j];
- max_id = j;
+ // Top-k accuracy
+ std::fill_n(maxval, top_k_, -FLT_MAX);
+ std::fill_n(max_id, top_k_, 0);
+ for (int j = 0, k; j < dim; ++j) {
+ // insert into (reverse-)sorted top-k array
+ Dtype val = bottom_data[i * dim + j];
+ for (k = top_k_; k > 0 && maxval[k-1] < val; k--) {
+ maxval[k] = maxval[k-1];
+ max_id[k] = max_id[k-1];
}
+ maxval[k] = val;
+ max_id[k] = j;
}
- if (max_id == static_cast<int>(bottom_label[i])) {
- ++accuracy;
- }
+ // check if true label is in top k predictions
+ for (int k = 0; k < top_k_; k++)
+ if (max_id[k] == static_cast<int>(bottom_label[i])) {
+ ++accuracy;
+ break;
+ }
}
+ delete[] maxval;
+ delete[] max_id;
+
// LOG(INFO) << "Accuracy: " << accuracy;
(*top)[0]->mutable_cpu_data()[0] = accuracy / num;