void AccuracyLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
Layer<Dtype>::SetUp(bottom, top);
+ top_k_ = this->layer_param_.accuracy_param().top_k();
CHECK_EQ(bottom[0]->num(), bottom[1]->num())
<< "The data and label should have the same number.";
+ CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
+ << "top_k must be less than the number of classes.";
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
// When computing accuracy, count as correct by comparing the true label to
// the top k scoring classes. By default, only compare to the top scoring
// class (i.e. argmax).
- optional uint32 compare_to_top_k = 1 [default = 1];
+ optional uint32 top_k = 1 [default = 1];
}
// Message that stores parameters used by ArgMaxLayer