From 1c640c976b113839514ca5b70a28f6a921cf9eb2 Mon Sep 17 00:00:00 2001 From: Rob Hess Date: Fri, 13 Jun 2014 19:07:38 -0700 Subject: [PATCH] Incorporate top_k param into AccuracyLayer and check it's value. --- include/caffe/loss_layers.hpp | 2 ++ src/caffe/layers/accuracy_layer.cpp | 3 +++ src/caffe/proto/caffe.proto | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp index bb03f63..3a4d416 100644 --- a/include/caffe/loss_layers.hpp +++ b/include/caffe/loss_layers.hpp @@ -251,6 +251,8 @@ class AccuracyLayer : public Layer { const vector& propagate_down, vector*>* bottom) { NOT_IMPLEMENTED; } + + int top_k_; }; } // namespace caffe diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp index 899750f..55f620d 100644 --- a/src/caffe/layers/accuracy_layer.cpp +++ b/src/caffe/layers/accuracy_layer.cpp @@ -18,8 +18,11 @@ template void AccuracyLayer::SetUp( const vector*>& bottom, vector*>* top) { Layer::SetUp(bottom, top); + top_k_ = this->layer_param_.accuracy_param().top_k(); CHECK_EQ(bottom[0]->num(), bottom[1]->num()) << "The data and label should have the same number."; + CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) + << "top_k must be less than the number of classes."; CHECK_EQ(bottom[1]->channels(), 1); CHECK_EQ(bottom[1]->height(), 1); CHECK_EQ(bottom[1]->width(), 1); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 48d6fe2..936a215 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -220,7 +220,7 @@ message AccuracyParameter { // When computing accuracy, count as correct by comparing the true label to // the top k scoring classes. By default, only compare to the top scoring // class (i.e. argmax). - optional uint32 compare_to_top_k = 1 [default = 1]; + optional uint32 top_k = 1 [default = 1]; } // Message that stores parameters used by ArgMaxLayer -- 2.7.4