'caffe test' prints all scores and their names
authorJeff Donahue <jeff.donahue@gmail.com>
Mon, 25 Aug 2014 18:59:20 +0000 (11:59 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Mon, 25 Aug 2014 19:00:35 +0000 (12:00 -0700)
tools/caffe.cpp

index 77e031e..5b3ad0b 100644 (file)
@@ -136,14 +136,47 @@ int test() {
   caffe_net.CopyTrainedLayersFrom(FLAGS_weights);
   LOG(INFO) << "Running for " << FLAGS_iterations << " iterations.";
 
-  double test_score = 0;
+  vector<Blob<float>* > bottom_vec;
+  vector<int> test_score_output_id;
+  vector<float> test_score;
+  float loss = 0;
   for (int i = 0; i < FLAGS_iterations; ++i) {
-    const vector<Blob<float>*>& result = caffe_net.ForwardPrefilled();
-    test_score += result[0]->cpu_data()[0];
-    LOG(INFO) << "Batch " << i << ", score: " << result[0]->cpu_data()[0];
+    float iter_loss;
+    const vector<Blob<float>*>& result =
+        caffe_net.Forward(bottom_vec, &iter_loss);
+    loss += iter_loss;
+    int idx = 0;
+    for (int j = 0; j < result.size(); ++j) {
+      const float* result_vec = result[j]->cpu_data();
+      for (int k = 0; k < result[j]->count(); ++k, ++idx) {
+        const float score = result_vec[k];
+        if (i == 0) {
+          test_score.push_back(score);
+          test_score_output_id.push_back(j);
+        } else {
+          test_score[idx] += score;
+        }
+        const std::string& output_name = caffe_net.blob_names()[
+            caffe_net.output_blob_indices()[j]];
+        LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score;
+      }
+    }
+  }
+  loss /= FLAGS_iterations;
+  LOG(INFO) << "Loss: " << loss;
+  for (int i = 0; i < test_score.size(); ++i) {
+    const std::string& output_name = caffe_net.blob_names()[
+        caffe_net.output_blob_indices()[test_score_output_id[i]]];
+    const float loss_weight =
+        caffe_net.blob_loss_weights()[caffe_net.output_blob_indices()[i]];
+    std::ostringstream loss_msg_stream;
+    const float mean_score = test_score[i] / FLAGS_iterations;
+    if (loss_weight) {
+      loss_msg_stream << " (* " << loss_weight
+                      << " = " << loss_weight * mean_score << " loss)";
+    }
+    LOG(INFO) << output_name << " = " << mean_score << loss_msg_stream.str();
   }
-  test_score /= FLAGS_iterations;
-  LOG(INFO) << "Score: " << test_score;
 
   return 0;
 }