Add tests for phase filtering according to Caffe singleton phase
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 2 Aug 2014 20:10:46 +0000 (13:10 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 2 Aug 2014 20:28:49 +0000 (13:28 -0700)
(currently failing as FilterNet ignores the singleton phase).

src/caffe/test/test_net.cpp

index c418134..a0a3879 100644 (file)
@@ -961,8 +961,7 @@ TEST_F(FilterNetTest, TestFilterLeNetTrainTest) {
       "} ";
   const string input_proto_train = "state: { phase: TRAIN } " + input_proto;
   const string input_proto_test = "state: { phase: TEST } " + input_proto;
-  const string& output_proto_train =
-      "state: { phase: TRAIN } "
+  const string output_proto_train =
       "name: 'LeNet' "
       "layers { "
       "  name: 'mnist' "
@@ -1020,7 +1019,6 @@ TEST_F(FilterNetTest, TestFilterLeNetTrainTest) {
       "  top: 'loss' "
       "} ";
   const string& output_proto_test =
-      "state: { phase: TEST } "
       "name: 'LeNet' "
       "layers { "
       "  name: 'mnist' "
@@ -1085,8 +1083,26 @@ TEST_F(FilterNetTest, TestFilterLeNetTrainTest) {
       "  bottom: 'label' "
       "  top: 'loss' "
       "} ";
-  this->RunFilterNetTest(input_proto_train, output_proto_train);
-  this->RunFilterNetTest(input_proto_test, output_proto_test);
+  const string output_proto_train_explicit =
+      output_proto_train + " state: { phase: TRAIN } ";
+  const string output_proto_test_explicit =
+      output_proto_test + " state: { phase: TEST } ";
+  this->RunFilterNetTest(input_proto_train, output_proto_train_explicit);
+  this->RunFilterNetTest(input_proto_test, output_proto_test_explicit);
+
+  // Also check that nets are filtered according to the Caffe singleton phase,
+  // if not explicitly specified in the input proto.
+  Caffe::set_phase(Caffe::TRAIN);
+  this->RunFilterNetTest(input_proto, output_proto_train);
+  Caffe::set_phase(Caffe::TEST);
+  this->RunFilterNetTest(input_proto, output_proto_test);
+
+  // Finally, check that the current Caffe singleton phase is ignored if the
+  // phase is explicitly specified in the input proto.
+  Caffe::set_phase(Caffe::TEST);
+  this->RunFilterNetTest(input_proto_train, output_proto_train_explicit);
+  Caffe::set_phase(Caffe::TRAIN);
+  this->RunFilterNetTest(input_proto_test, output_proto_test_explicit);
 }
 
 TEST_F(FilterNetTest, TestFilterOutByStage) {