template <typename Dtype>
void Net<Dtype>::FilterNet(const NetParameter& param,
NetParameter* param_filtered) {
- const NetState& net_state = param.state();
+ NetState net_state(param.state());
+ // Let the phase of the net be the current global phase provided in the Caffe
+ // singleton, unless explicitly provided by the state.
+ if (!net_state.has_phase()) {
+ switch (Caffe::phase()) {
+ case Caffe::TRAIN:
+ net_state.set_phase(TRAIN);
+ break;
+ case Caffe::TEST:
+ net_state.set_phase(TEST);
+ break;
+ default:
+ LOG(FATAL) << "Unknown phase: " << Caffe::phase();
+ }
+ }
param_filtered->CopyFrom(param);
param_filtered->clear_layers();
for (int i = 0; i < param.layers_size(); ++i) {