"} ";
const string input_proto_train = "state: { phase: TRAIN } " + input_proto;
const string input_proto_test = "state: { phase: TEST } " + input_proto;
- const string& output_proto_train =
- "state: { phase: TRAIN } "
+ const string output_proto_train =
"name: 'LeNet' "
"layers { "
" name: 'mnist' "
" top: 'loss' "
"} ";
const string& output_proto_test =
- "state: { phase: TEST } "
"name: 'LeNet' "
"layers { "
" name: 'mnist' "
" bottom: 'label' "
" top: 'loss' "
"} ";
- this->RunFilterNetTest(input_proto_train, output_proto_train);
- this->RunFilterNetTest(input_proto_test, output_proto_test);
+ const string output_proto_train_explicit =
+ output_proto_train + " state: { phase: TRAIN } ";
+ const string output_proto_test_explicit =
+ output_proto_test + " state: { phase: TEST } ";
+ this->RunFilterNetTest(input_proto_train, output_proto_train_explicit);
+ this->RunFilterNetTest(input_proto_test, output_proto_test_explicit);
+
+ // Also check that nets are filtered according to the Caffe singleton phase,
+ // if not explicitly specified in the input proto.
+ Caffe::set_phase(Caffe::TRAIN);
+ this->RunFilterNetTest(input_proto, output_proto_train);
+ Caffe::set_phase(Caffe::TEST);
+ this->RunFilterNetTest(input_proto, output_proto_test);
+
+ // Finally, check that the current Caffe singleton phase is ignored if the
+ // phase is explicitly specified in the input proto.
+ Caffe::set_phase(Caffe::TEST);
+ this->RunFilterNetTest(input_proto_train, output_proto_train_explicit);
+ Caffe::set_phase(Caffe::TRAIN);
+ this->RunFilterNetTest(input_proto_test, output_proto_test_explicit);
}
TEST_F(FilterNetTest, TestFilterOutByStage) {