Fix speedtest (and possibly other tools) by setting the net phase to the
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 2 Aug 2014 20:17:15 +0000 (13:17 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 2 Aug 2014 20:28:49 +0000 (13:28 -0700)
current Caffe::phase() unless explicitly specified in the state.

src/caffe/net.cpp

index a4d1f23..b80bbba 100644 (file)
@@ -174,7 +174,21 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
 template <typename Dtype>
 void Net<Dtype>::FilterNet(const NetParameter& param,
     NetParameter* param_filtered) {
-  const NetState& net_state = param.state();
+  NetState net_state(param.state());
+  // Let the phase of the net be the current global phase provided in the Caffe
+  // singleton, unless explicitly provided by the state.
+  if (!net_state.has_phase()) {
+    switch (Caffe::phase()) {
+      case Caffe::TRAIN:
+        net_state.set_phase(TRAIN);
+        break;
+      case Caffe::TEST:
+        net_state.set_phase(TEST);
+        break;
+      default:
+        LOG(FATAL) << "Unknown phase: " << Caffe::phase();
+    }
+  }
   param_filtered->CopyFrom(param);
   param_filtered->clear_layers();
   for (int i = 0; i < param.layers_size(); ++i) {