Incorporate top_k param into AccuracyLayer and check it's value.
authorRob Hess <hess@yahoo-inc.com>
Sat, 14 Jun 2014 02:07:38 +0000 (19:07 -0700)
committerRob Hess <hess@yahoo-inc.com>
Fri, 27 Jun 2014 18:17:46 +0000 (11:17 -0700)
include/caffe/loss_layers.hpp
src/caffe/layers/accuracy_layer.cpp
src/caffe/proto/caffe.proto

index bb03f63..3a4d416 100644 (file)
@@ -251,6 +251,8 @@ class AccuracyLayer : public Layer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
     NOT_IMPLEMENTED;
   }
+
+  int top_k_;
 };
 
 }  // namespace caffe
index 899750f..55f620d 100644 (file)
@@ -18,8 +18,11 @@ template <typename Dtype>
 void AccuracyLayer<Dtype>::SetUp(
   const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   Layer<Dtype>::SetUp(bottom, top);
+  top_k_ = this->layer_param_.accuracy_param().top_k();
   CHECK_EQ(bottom[0]->num(), bottom[1]->num())
       << "The data and label should have the same number.";
+  CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
+      << "top_k must be less than the number of classes.";
   CHECK_EQ(bottom[1]->channels(), 1);
   CHECK_EQ(bottom[1]->height(), 1);
   CHECK_EQ(bottom[1]->width(), 1);
index 48d6fe2..936a215 100644 (file)
@@ -220,7 +220,7 @@ message AccuracyParameter {
   // When computing accuracy, count as correct by comparing the true label to
   // the top k scoring classes.  By default, only compare to the top scoring
   // class (i.e. argmax).
-  optional uint32 compare_to_top_k = 1 [default = 1];
+  optional uint32 top_k = 1 [default = 1];
 }
 
 // Message that stores parameters used by ArgMaxLayer