Fix AccuracyLayerTest for per-class accuracy.
authorRonghang Hu <huronghang@hotmail.com>
Fri, 4 Sep 2015 04:44:45 +0000 (21:44 -0700)
committerRonghang Hu <huronghang@hotmail.com>
Fri, 4 Sep 2015 04:47:53 +0000 (21:47 -0700)
Fix AccuracyLayerTest for per-class accuracy. Previously in #2935, it crashes since the test accuracy is nan (0/0) when a class never appear.

src/caffe/test/test_accuracy_layer.cpp

index 94e529b..ef0e57a 100644 (file)
@@ -250,7 +250,6 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) {
 
 TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) {
   LayerParameter layer_param;
-  Caffe::set_mode(Caffe::CPU);
   AccuracyLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
@@ -279,16 +278,16 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) {
   EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
               num_correct_labels / 100.0, 1e-4);
   for (int i = 0; i < num_class; ++i) {
+    TypeParam accuracy_per_class = (num_per_class[i] > 0 ?
+       static_cast<TypeParam>(correct_per_class[i]) / num_per_class[i] : 0);
     EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
-                static_cast<float>(correct_per_class[i]) / num_per_class[i],
-                1e-4);
+                accuracy_per_class, 1e-4);
   }
 }
 
 
 TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) {
   LayerParameter layer_param;
-  Caffe::set_mode(Caffe::CPU);
   const TypeParam kIgnoreLabelValue = -1;
   layer_param.mutable_accuracy_param()->set_ignore_label(kIgnoreLabelValue);
   AccuracyLayer<TypeParam> layer(layer_param);
@@ -329,9 +328,10 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) {
   EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
               num_correct_labels / TypeParam(count), 1e-4);
   for (int i = 0; i < 10; ++i) {
+    TypeParam accuracy_per_class = (num_per_class[i] > 0 ?
+       static_cast<TypeParam>(correct_per_class[i]) / num_per_class[i] : 0);
     EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
-                TypeParam(correct_per_class[i]) / num_per_class[i],
-                1e-4);
+                accuracy_per_class, 1e-4);
   }
 }