From 62e0c8559045cb2b5a12e0d6c41acd25d4122630 Mon Sep 17 00:00:00 2001
From: Shai
Date: Thu, 10 Aug 2017 10:07:19 +0300
Subject: [PATCH] upgrading Accuracy layer: (1) an efficient O(L) CPU
 implementation of top_k, with no need for a priority_queue; (2) a GPU
 implementation

---
 include/caffe/layers/accuracy_layer.hpp |   4 +
 src/caffe/layers/accuracy_layer.cpp     |  33 ++-
 src/caffe/layers/accuracy_layer.cu      | 147 +++++++++++++
 src/caffe/test/test_accuracy_layer.cpp  | 360 +++++++++++++++++---------------
 4 files changed, 364 insertions(+), 180 deletions(-)
 create mode 100644 src/caffe/layers/accuracy_layer.cu

diff --git a/include/caffe/layers/accuracy_layer.hpp b/include/caffe/layers/accuracy_layer.hpp
index a9ad322..dd2247b 100644
--- a/include/caffe/layers/accuracy_layer.hpp
+++ b/include/caffe/layers/accuracy_layer.hpp
@@ -68,6 +68,8 @@ class AccuracyLayer : public Layer<Dtype> {
    */
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
 
 
   /// @brief Not implemented -- AccuracyLayer cannot be used as a loss.
@@ -77,6 +79,8 @@ class AccuracyLayer : public Layer<Dtype> {
       if (propagate_down[i]) { NOT_IMPLEMENTED; }
     }
   }
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
   int label_axis_, outer_num_, inner_num_;
 
diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp
index 4eddbb5..392829e 100644
--- a/src/caffe/layers/accuracy_layer.cpp
+++ b/src/caffe/layers/accuracy_layer.cpp
@@ -52,8 +52,6 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_label = bottom[1]->cpu_data();
   const int dim = bottom[0]->count() / outer_num_;
   const int num_labels = bottom[0]->shape(label_axis_);
-  vector<Dtype> maxval(top_k_+1);
-  vector<int> max_id(top_k_+1);
   if (top.size() > 1) {
     caffe_set(nums_buffer_.count(), Dtype(0), nums_buffer_.mutable_cpu_data());
     caffe_set(top[1]->count(), Dtype(0), top[1]->mutable_cpu_data());
@@ -66,25 +64,22 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       if (has_ignore_label_ && label_value == ignore_label_) {
         continue;
       }
-      if (top.size() > 1) ++nums_buffer_.mutable_cpu_data()[label_value];
       DCHECK_GE(label_value, 0);
       DCHECK_LT(label_value, num_labels);
+      if (top.size() > 1) ++nums_buffer_.mutable_cpu_data()[label_value];
+      const Dtype prob_of_true_class = bottom_data[i * dim
+                                                   + label_value * inner_num_
+                                                   + j];
+      int num_better_predictions = -1;  // true_class also counts as "better"
       // Top-k accuracy
-      std::vector<std::pair<Dtype, int> > bottom_data_vector;
-      for (int k = 0; k < num_labels; ++k) {
-        bottom_data_vector.push_back(std::make_pair(
-            bottom_data[i * dim + k * inner_num_ + j], k));
+      for (int k = 0; k < num_labels && num_better_predictions < top_k_; ++k) {
+        num_better_predictions +=
+          (bottom_data[i * dim + k * inner_num_ + j] >= prob_of_true_class);
       }
-      std::partial_sort(
-          bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
-          bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
-      // check if true label is in top k predictions
-      for (int k = 0; k < top_k_; k++) {
-        if (bottom_data_vector[k].second == label_value) {
-          ++accuracy;
-          if (top.size() > 1) ++top[1]->mutable_cpu_data()[label_value];
-          break;
-        }
+      // check if there are fewer than top_k_ better predictions
+      if (num_better_predictions < top_k_) {
+        ++accuracy;
+        if (top.size() > 1) ++top[1]->mutable_cpu_data()[label_value];
       }
       ++count;
     }
@@ -102,6 +97,10 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   // Accuracy layer should not be used as a loss function.
 }
 
+#ifdef CPU_ONLY
+STUB_GPU(AccuracyLayer);
+#endif
+
 INSTANTIATE_CLASS(AccuracyLayer);
 REGISTER_LAYER_CLASS(Accuracy);
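Note on the CPU change above: the true label lands in the top-k predictions
iff fewer than top_k other classes score at least as high, so one O(L)
counting pass over the L class scores replaces the old partial_sort-based
scan. A minimal standalone sketch of the idea (illustrative only, not part
of the patch; InTopK is a hypothetical name):

    // Returns true iff scores[label] is among the top_k largest of the
    // num_labels scores; a single pass, no sorting and no priority_queue.
    template <typename Dtype>
    bool InTopK(const Dtype* scores, int num_labels, int label, int top_k) {
      const Dtype true_score = scores[label];
      int num_better = -1;  // the true class matches >= against itself
      for (int k = 0; k < num_labels && num_better < top_k; ++k) {
        num_better += (scores[k] >= true_score);
      }
      return num_better < top_k;  // early exit once top_k betters are seen
    }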
diff --git a/src/caffe/layers/accuracy_layer.cu b/src/caffe/layers/accuracy_layer.cu
new file mode 100644
index 0000000..a8cff93
--- /dev/null
+++ b/src/caffe/layers/accuracy_layer.cu
@@ -0,0 +1,147 @@
+#include <vector>
+
+#include "caffe/layers/accuracy_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void AccuracyForwardGPU(const int nthreads,
+          const Dtype* bottom_data, const Dtype* label, Dtype* acc,
+          const int num, const int dim, const int spatial_dim,
+          const int num_labels, const int top_k,
+          const bool has_ignore_label_, const int ignore_label_,
+          Dtype* counts) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int n = index / spatial_dim;
+    const int s = index % spatial_dim;
+    const int label_value = static_cast<int>(label[n * spatial_dim + s]);
+    const Dtype prob_of_true_class = bottom_data[n * dim
+                                                 + label_value * spatial_dim
+                                                 + s];
+    int num_better_predictions = -1;  // true_class also counts as "better"
+    if (has_ignore_label_ && label_value == ignore_label_) {
+      acc[index] = 0;
+      counts[index] = 0;
+    } else {
+      for (int k = 0; k < num_labels && num_better_predictions < top_k; k++) {
+        num_better_predictions +=
+          (bottom_data[n * dim + k * spatial_dim + s] >= prob_of_true_class);
+      }
+      acc[index] = (num_better_predictions < top_k);
+      counts[index] = 1;
+    }
+  }
+}
+
+template <typename Dtype>
+__global__ void AccuracyForwardWithPerClassGPU(const int nthreads,
+          const Dtype* bottom_data, const Dtype* label,
+          Dtype* acc, Dtype* counts,
+          const int num, const int dim, const int spatial_dim,
+          const int num_labels, const int top_k,
+          const bool has_ignore_label_, const int ignore_label_) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int n = index / spatial_dim;
+    const int s = index % spatial_dim;
+    const int label_value = static_cast<int>(label[n * spatial_dim + s]);
+    const Dtype prob_of_true_class = bottom_data[n * dim
+                                                 + label_value * spatial_dim
+                                                 + s];
+    if (has_ignore_label_ && label_value == ignore_label_) {
+      // nothing to be done.
+    } else {
+      int num_better_predictions = -1;  // true_class also counts as "better"
+      for (int k = 0; k < num_labels && num_better_predictions < top_k; k++) {
+        num_better_predictions +=
+          (bottom_data[n * dim + k * spatial_dim + s] >= prob_of_true_class);
+      }
+      acc[label_value*nthreads + index] += (num_better_predictions < top_k);
+      counts[label_value*nthreads + index] = 1;
+    }
+  }
+}
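Note: both kernels assign one CUDA thread per prediction, i.e. per (sample,
spatial position) pair, and each thread runs the same O(num_labels) counting
scan as the CPU path. A hypothetical host-side sketch of that thread mapping
(illustrative only, not part of the patch):

    #include <cstdio>

    // One virtual thread per prediction; inner_num is the spatial size.
    void MapThreads(int outer_num, int inner_num) {
      const int nthreads = outer_num * inner_num;
      for (int index = 0; index < nthreads; ++index) {
        const int n = index / inner_num;  // sample within the batch
        const int s = index % inner_num;  // spatial position in the sample
        std::printf("thread %d -> sample %d, position %d\n", index, n, s);
      }
    }

Each thread writes a 0/1 hit into acc and a 0/1 validity flag into counts;
the per-class kernel uses num_labels rows of nthreads entries, one row per
class, so each row can later be reduced separately.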
+
+template <typename Dtype>
+void AccuracyLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const Dtype* bottom_label = bottom[1]->gpu_data();
+  const int dim = bottom[0]->count() / outer_num_;
+  const int num_labels = bottom[0]->shape(label_axis_);
+  const int nthreads = outer_num_ * inner_num_;
+  // Since this memory is not used for anything,
+  // we use it here to avoid having to allocate new GPU
+  // memory to accumulate intermediate results in the kernel.
+  Dtype* acc_data = bottom[0]->mutable_gpu_diff();
+  if (top.size() == 1) {
+    // simple case - report only global accuracy.
+
+    // Similarly, this memory is never used elsewhere, and thus we can use it
+    // to avoid having to allocate additional GPU memory.
+    Dtype* counts = bottom[1]->mutable_gpu_diff();
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    AccuracyForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
+        CAFFE_CUDA_NUM_THREADS>>>(nthreads, bottom_data, bottom_label,
+        acc_data, outer_num_, dim, inner_num_, num_labels, top_k_,
+        has_ignore_label_, ignore_label_, counts);
+    Dtype acc;
+    caffe_gpu_asum(nthreads, acc_data, &acc);
+    Dtype valid_count;
+    caffe_gpu_asum(nthreads, counts, &valid_count);
+    if (valid_count > 0) {
+      top[0]->mutable_cpu_data()[0] = acc / valid_count;
+    } else {
+      top[0]->mutable_cpu_data()[0] = 0;
+    }
+  } else {
+    // need to report per-class accuracy as well
+
+    // allocate space for more detailed "counts"
+    nums_buffer_.ReshapeLike(*bottom[0]);
+    Dtype* counts = nums_buffer_.mutable_gpu_data();
+
+    caffe_gpu_set(bottom[0]->count(), Dtype(0), acc_data);
+    caffe_gpu_set(nums_buffer_.count(), Dtype(0), counts);
+
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    AccuracyForwardWithPerClassGPU<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
+        CAFFE_CUDA_NUM_THREADS>>>(nthreads, bottom_data, bottom_label,
+        acc_data, counts, outer_num_, dim, inner_num_, num_labels, top_k_,
+        has_ignore_label_, ignore_label_);
+
+    // get the overall accuracy
+    Dtype acc;
+    caffe_gpu_asum(bottom[0]->count(), acc_data, &acc);
+    Dtype valid_count;
+    caffe_gpu_asum(nums_buffer_.count(), counts, &valid_count);
+    if (valid_count > 0) {
+      top[0]->mutable_cpu_data()[0] = acc / valid_count;
+    } else {
+      top[0]->mutable_cpu_data()[0] = 0;
+    }
+
+    // get per-class accuracy
+    Dtype* per_class_acc = top[1]->mutable_cpu_data();
+    for (int l = 0; l < num_labels; l++) {
+      caffe_gpu_asum(nthreads, acc_data + l*nthreads, per_class_acc+l);
+      caffe_gpu_asum(nthreads, counts + l*nthreads, &valid_count);
+      if (valid_count > 0) {
+        per_class_acc[l] /= valid_count;
+      } else {
+        per_class_acc[l] = 0;
+      }
+    }
+  }
+}
+
+
+template <typename Dtype>
+void AccuracyLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[1]) { NOT_IMPLEMENTED; }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(AccuracyLayer);
+}  // namespace caffe
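Note: on the host side the kernel outputs collapse to scalars with
caffe_gpu_asum: accuracy = asum(acc) / asum(counts), where counts already
excludes ignored predictions, and the valid_count > 0 guard keeps an
all-ignored batch from dividing by zero. A CPU stand-in for this reduction
(hypothetical name, illustrative only, not part of the patch):

    #include <numeric>
    #include <vector>

    // acc holds one 0/1 hit per prediction; counts holds 1 only for the
    // predictions that were not ignored.
    float ReduceAccuracy(const std::vector<float>& acc,
                         const std::vector<float>& counts) {
      const float hits = std::accumulate(acc.begin(), acc.end(), 0.0f);
      const float valid = std::accumulate(counts.begin(), counts.end(), 0.0f);
      return valid > 0 ? hits / valid : 0;  // guard all-ignored batches
    }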
diff --git a/src/caffe/test/test_accuracy_layer.cpp b/src/caffe/test/test_accuracy_layer.cpp
index 6fe808b..e5cc9d5 100644
--- a/src/caffe/test/test_accuracy_layer.cpp
+++ b/src/caffe/test/test_accuracy_layer.cpp
@@ -13,8 +13,10 @@
 
 namespace caffe {
 
-template <typename Dtype>
-class AccuracyLayerTest : public CPUDeviceTest<Dtype> {
+template <typename TypeParam>
+class AccuracyLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
  protected:
   AccuracyLayerTest()
       : blob_bottom_data_(new Blob<Dtype>()),
@@ -69,11 +71,12 @@ class AccuracyLayerTest : public CPUDeviceTest<Dtype> {
   int top_k_;
 };
 
-TYPED_TEST_CASE(AccuracyLayerTest, TestDtypes);
+TYPED_TEST_CASE(AccuracyLayerTest, TestDtypesAndDevices);
 
 TYPED_TEST(AccuracyLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_->num(), 1);
   EXPECT_EQ(this->blob_top_->channels(), 1);
@@ -82,11 +85,12 @@ TYPED_TEST(AccuracyLayerTest, TestSetup) {
 }
 
 TYPED_TEST(AccuracyLayerTest, TestSetupTopK) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
   AccuracyParameter* accuracy_param = layer_param.mutable_accuracy_param();
   accuracy_param->set_top_k(5);
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_->num(), 1);
   EXPECT_EQ(this->blob_top_->channels(), 1);
@@ -95,8 +99,9 @@ TYPED_TEST(AccuracyLayerTest, TestSetupTopK) {
 }
 
 TYPED_TEST(AccuracyLayerTest, TestSetupOutputPerClass) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
   EXPECT_EQ(this->blob_top_->num(), 1);
   EXPECT_EQ(this->blob_top_->channels(), 1);
@@ -108,33 +113,39 @@ TYPED_TEST(AccuracyLayerTest, TestSetupOutputPerClass) {
   EXPECT_EQ(this->blob_top_per_class_->width(), 1);
 }
 
-TYPED_TEST(AccuracyLayerTest, TestForwardCPU) {
+TYPED_TEST(AccuracyLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
-  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
-
-  TypeParam max_value;
-  int max_id;
-  int num_correct_labels = 0;
-  for (int i = 0; i < 100; ++i) {
-    max_value = -FLT_MAX;
-    max_id = 0;
-    for (int j = 0; j < 10; ++j) {
-      if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
-        max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
-        max_id = j;
+
+  // repeat the forward
+  for (int iter = 0; iter < 3; iter++) {
+    layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+    Dtype max_value;
+    int max_id;
+    int num_correct_labels = 0;
+    for (int i = 0; i < 100; ++i) {
+      max_value = -FLT_MAX;
+      max_id = 0;
+      for (int j = 0; j < 10; ++j) {
+        if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
+          max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
+          max_id = j;
+        }
+      }
+      if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+        ++num_correct_labels;
       }
     }
-    if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-      ++num_correct_labels;
-    }
+    EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
+                num_correct_labels / Dtype(100.0), 1e-4);
   }
-  EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
-              num_correct_labels / 100.0, 1e-4);
 }
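Note: the "repeat the forward" loops added throughout these tests pin down a
new invariant: Forward_gpu reuses the bottom blobs' diff buffers as scratch,
so repeated forward passes must keep producing the same result from
untouched input data. A hypothetical helper stating the invariant
(illustrative only, not part of the patch):

    #include <vector>
    #include "caffe/blob.hpp"
    #include "caffe/layer.hpp"
    #include "gtest/gtest.h"

    template <typename Dtype>
    void CheckForwardRepeatable(caffe::Layer<Dtype>* layer,
        const std::vector<caffe::Blob<Dtype>*>& bottom,
        const std::vector<caffe::Blob<Dtype>*>& top) {
      layer->Forward(bottom, top);
      const Dtype first = top[0]->cpu_data()[0];
      layer->Forward(bottom, top);  // second pass reuses the scratch diffs
      EXPECT_EQ(first, top[0]->cpu_data()[0]);
    }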
 
 TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) {
+  typedef typename TypeParam::Dtype Dtype;
   this->blob_bottom_data_->Reshape(2, 10, 4, 5);
   vector<int> label_shape(3);
   label_shape[0] = 2; label_shape[1] = 4; label_shape[2] = 5;
@@ -142,195 +153,218 @@ TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) {
   this->blob_bottom_label_->Reshape(label_shape);
   this->FillBottoms();
   LayerParameter layer_param;
   layer_param.mutable_accuracy_param()->set_axis(1);
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
-  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
-
-  TypeParam max_value;
-  const int num_labels = this->blob_bottom_label_->count();
-  int max_id;
-  int num_correct_labels = 0;
-  vector<int> label_offset(3);
-  for (int n = 0; n < this->blob_bottom_data_->num(); ++n) {
-    for (int h = 0; h < this->blob_bottom_data_->height(); ++h) {
-      for (int w = 0; w < this->blob_bottom_data_->width(); ++w) {
-        max_value = -FLT_MAX;
-        max_id = 0;
-        for (int c = 0; c < this->blob_bottom_data_->channels(); ++c) {
-          const TypeParam pred_value =
-              this->blob_bottom_data_->data_at(n, c, h, w);
-          if (pred_value > max_value) {
-            max_value = pred_value;
-            max_id = c;
+
+  // repeat the forward
+  for (int iter = 0; iter < 3; iter++) {
+    layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+    Dtype max_value;
+    const int num_labels = this->blob_bottom_label_->count();
+    int max_id;
+    int num_correct_labels = 0;
+    vector<int> label_offset(3);
+    for (int n = 0; n < this->blob_bottom_data_->num(); ++n) {
+      for (int h = 0; h < this->blob_bottom_data_->height(); ++h) {
+        for (int w = 0; w < this->blob_bottom_data_->width(); ++w) {
+          max_value = -FLT_MAX;
+          max_id = 0;
+          for (int c = 0; c < this->blob_bottom_data_->channels(); ++c) {
+            const Dtype pred_value =
+                this->blob_bottom_data_->data_at(n, c, h, w);
+            if (pred_value > max_value) {
+              max_value = pred_value;
+              max_id = c;
+            }
+          }
+          label_offset[0] = n; label_offset[1] = h; label_offset[2] = w;
+          const int correct_label =
+              static_cast<int>(this->blob_bottom_label_->data_at(label_offset));
+          if (max_id == correct_label) {
+            ++num_correct_labels;
           }
-        }
-        label_offset[0] = n; label_offset[1] = h; label_offset[2] = w;
-        const int correct_label =
-            static_cast<int>(this->blob_bottom_label_->data_at(label_offset));
-        if (max_id == correct_label) {
-          ++num_correct_labels;
         }
       }
     }
+    EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
+                num_correct_labels / Dtype(num_labels), 1e-4);
   }
-  EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
-              num_correct_labels / TypeParam(num_labels), 1e-4);
 }
 
 TYPED_TEST(AccuracyLayerTest, TestForwardIgnoreLabel) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  const TypeParam kIgnoreLabelValue = -1;
+  const Dtype kIgnoreLabelValue = -1;
   layer_param.mutable_accuracy_param()->set_ignore_label(kIgnoreLabelValue);
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   // Manually set some labels to the ignore label value (-1).
   this->blob_bottom_label_->mutable_cpu_data()[2] = kIgnoreLabelValue;
   this->blob_bottom_label_->mutable_cpu_data()[5] = kIgnoreLabelValue;
   this->blob_bottom_label_->mutable_cpu_data()[32] = kIgnoreLabelValue;
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
-  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
-
-  TypeParam max_value;
-  int max_id;
-  int num_correct_labels = 0;
-  int count = 0;
-  for (int i = 0; i < 100; ++i) {
-    if (kIgnoreLabelValue == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-      continue;
-    }
-    ++count;
-    max_value = -FLT_MAX;
-    max_id = 0;
-    for (int j = 0; j < 10; ++j) {
-      if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
-        max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
-        max_id = j;
+
+  // repeat the forward
+  for (int iter = 0; iter < 3; iter++) {
+    layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+    Dtype max_value;
+    int max_id;
+    int num_correct_labels = 0;
+    int count = 0;
+    for (int i = 0; i < 100; ++i) {
+      if (kIgnoreLabelValue == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+        continue;
+      }
+      ++count;
+      max_value = -FLT_MAX;
+      max_id = 0;
+      for (int j = 0; j < 10; ++j) {
+        if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
+          max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
+          max_id = j;
+        }
+      }
+      if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+        ++num_correct_labels;
       }
     }
-    if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-      ++num_correct_labels;
-    }
+    EXPECT_EQ(count, 97);  // We set 3 out of 100 labels to kIgnoreLabelValue.
+    EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
+                num_correct_labels / Dtype(count), 1e-4);
   }
-  EXPECT_EQ(count, 97);  // We set 3 out of 100 labels to kIgnoreLabelValue.
-  EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
-              num_correct_labels / TypeParam(count), 1e-4);
 }
 
-TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) {
+TYPED_TEST(AccuracyLayerTest, TestForwardTopK) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
   AccuracyParameter* accuracy_param = layer_param.mutable_accuracy_param();
   accuracy_param->set_top_k(this->top_k_);
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
-  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
-
-  TypeParam current_value;
-  int current_rank;
-  int num_correct_labels = 0;
-  for (int i = 0; i < 100; ++i) {
-    for (int j = 0; j < 10; ++j) {
-      current_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
-      current_rank = 0;
-      for (int k = 0; k < 10; ++k) {
-        if (this->blob_bottom_data_->data_at(i, k, 0, 0) > current_value) {
-          ++current_rank;
+
+  // repeat the forward
+  for (int iter = 0; iter < 3; iter++) {
+    layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+    Dtype current_value;
+    int current_rank;
+    int num_correct_labels = 0;
+    for (int i = 0; i < 100; ++i) {
+      for (int j = 0; j < 10; ++j) {
+        current_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
+        current_rank = 0;
+        for (int k = 0; k < 10; ++k) {
+          if (this->blob_bottom_data_->data_at(i, k, 0, 0) > current_value) {
+            ++current_rank;
+          }
+        }
+        if (current_rank < this->top_k_ &&
+            j == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+          ++num_correct_labels;
         }
-      }
-      if (current_rank < this->top_k_ &&
-          j == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-        ++num_correct_labels;
       }
     }
+    EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
+                num_correct_labels / Dtype(100.0), 1e-4);
   }
-  EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
-              num_correct_labels / 100.0, 1e-4);
 }
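Note: the reference check above computes a rank by counting strictly greater
scores, while the layer counts greater-or-equal ones, so an exact tie counts
against the true class in the layer but not in the test. The two rules only
diverge on exact ties, which the random fillers used by these tests make
practically impossible. A side-by-side sketch (hypothetical names,
illustrative only, not part of the patch):

    // Layer's rule: fewer than k rivals with score >= the true score.
    template <typename Dtype>
    bool LayerAccepts(const Dtype* s, int L, int label, int k) {
      int better = -1;  // >= also matches the true class itself
      for (int i = 0; i < L; ++i) better += (s[i] >= s[label]);
      return better < k;
    }

    // Test's rule: rank = number of strictly greater scores.
    template <typename Dtype>
    bool TestAccepts(const Dtype* s, int L, int label, int k) {
      int rank = 0;  // ties do not increase the rank
      for (int i = 0; i < L; ++i) rank += (s[i] > s[label]);
      return rank < k;
    }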
 
-TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) {
+TYPED_TEST(AccuracyLayerTest, TestForwardPerClass) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
-  layer.Forward(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
-
-  TypeParam max_value;
-  int max_id;
-  int num_correct_labels = 0;
-  const int num_class = this->blob_top_per_class_->num();
-  vector<int> correct_per_class(num_class, 0);
-  vector<int> num_per_class(num_class, 0);
-  for (int i = 0; i < 100; ++i) {
-    max_value = -FLT_MAX;
-    max_id = 0;
-    for (int j = 0; j < 10; ++j) {
-      if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
-        max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
-        max_id = j;
+  // repeat the forward
+  for (int iter = 0; iter < 3; iter++) {
+    layer.Forward(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
+
+    Dtype max_value;
+    int max_id;
+    int num_correct_labels = 0;
+    const int num_class = this->blob_top_per_class_->num();
+    vector<int> correct_per_class(num_class, 0);
+    vector<int> num_per_class(num_class, 0);
+    for (int i = 0; i < 100; ++i) {
+      max_value = -FLT_MAX;
+      max_id = 0;
+      for (int j = 0; j < 10; ++j) {
+        if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
+          max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
+          max_id = j;
+        }
+      }
+      ++num_per_class[this->blob_bottom_label_->data_at(i, 0, 0, 0)];
+      if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+        ++num_correct_labels;
+        ++correct_per_class[max_id];
       }
     }
-    ++num_per_class[this->blob_bottom_label_->data_at(i, 0, 0, 0)];
-    if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-      ++num_correct_labels;
-      ++correct_per_class[max_id];
+    EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
+                num_correct_labels / 100.0, 1e-4);
+    for (int i = 0; i < num_class; ++i) {
+      Dtype accuracy_per_class = (num_per_class[i] > 0 ?
+          static_cast<Dtype>(correct_per_class[i]) / num_per_class[i] : 0);
+      EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
+                  accuracy_per_class, 1e-4);
     }
   }
-  EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
-              num_correct_labels / 100.0, 1e-4);
-  for (int i = 0; i < num_class; ++i) {
-    TypeParam accuracy_per_class = (num_per_class[i] > 0 ?
-        static_cast<TypeParam>(correct_per_class[i]) / num_per_class[i] : 0);
-    EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
-                accuracy_per_class, 1e-4);
-  }
 }
 
-TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) {
+TYPED_TEST(AccuracyLayerTest, TestForwardPerClassWithIgnoreLabel) {
+  typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  const TypeParam kIgnoreLabelValue = -1;
+  const Dtype kIgnoreLabelValue = -1;
   layer_param.mutable_accuracy_param()->set_ignore_label(kIgnoreLabelValue);
-  AccuracyLayer<TypeParam> layer(layer_param);
+  AccuracyLayer<Dtype> layer(layer_param);
   // Manually set some labels to the ignore label value (-1).
   this->blob_bottom_label_->mutable_cpu_data()[2] = kIgnoreLabelValue;
   this->blob_bottom_label_->mutable_cpu_data()[5] = kIgnoreLabelValue;
   this->blob_bottom_label_->mutable_cpu_data()[32] = kIgnoreLabelValue;
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
-  layer.Forward(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
-
-  TypeParam max_value;
-  int max_id;
-  int num_correct_labels = 0;
-  const int num_class = this->blob_top_per_class_->num();
-  vector<int> correct_per_class(num_class, 0);
-  vector<int> num_per_class(num_class, 0);
-  int count = 0;
-  for (int i = 0; i < 100; ++i) {
-    if (kIgnoreLabelValue == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-      continue;
-    }
-    ++count;
-    max_value = -FLT_MAX;
-    max_id = 0;
-    for (int j = 0; j < 10; ++j) {
-      if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
-        max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
-        max_id = j;
+
+  // repeat the forward
+  for (int iter = 0; iter < 3; iter++) {
+    layer.Forward(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
+
+    Dtype max_value;
+    int max_id;
+    int num_correct_labels = 0;
+    const int num_class = this->blob_top_per_class_->num();
+    vector<int> correct_per_class(num_class, 0);
+    vector<int> num_per_class(num_class, 0);
+    int count = 0;
+    for (int i = 0; i < 100; ++i) {
+      if (kIgnoreLabelValue == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+        continue;
+      }
+      ++count;
+      max_value = -FLT_MAX;
+      max_id = 0;
+      for (int j = 0; j < 10; ++j) {
+        if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
+          max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
+          max_id = j;
+        }
+      }
+      ++num_per_class[this->blob_bottom_label_->data_at(i, 0, 0, 0)];
+      if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
+        ++num_correct_labels;
+        ++correct_per_class[max_id];
       }
     }
-    ++num_per_class[this->blob_bottom_label_->data_at(i, 0, 0, 0)];
-    if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
-      ++num_correct_labels;
-      ++correct_per_class[max_id];
+    EXPECT_EQ(count, 97);
+    EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
+                num_correct_labels / Dtype(count), 1e-4);
+    for (int i = 0; i < 10; ++i) {
+      Dtype accuracy_per_class = (num_per_class[i] > 0 ?
+          static_cast<Dtype>(correct_per_class[i]) / num_per_class[i] : 0);
+      EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
+                  accuracy_per_class, 1e-4);
+    }
   }
-  EXPECT_EQ(count, 97);
-  EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
-              num_correct_labels / TypeParam(count), 1e-4);
-  for (int i = 0; i < 10; ++i) {
-    TypeParam accuracy_per_class = (num_per_class[i] > 0 ?
-        static_cast<TypeParam>(correct_per_class[i]) / num_per_class[i] : 0);
-    EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
-                accuracy_per_class, 1e-4);
-  }
 }
 
 }  // namespace caffe
-- 
2.7.4