// trained layers from another net parameter instance.
void CopyTrainedLayersFrom(const NetParameter& param);
void CopyTrainedLayersFrom(const string trained_filename);
+ // Read parameters from a file into a NetParameter proto message.
+ void ReadParamsFromTextFile(const string& param_file, NetParameter* param);
+ void ReadParamsFromBinaryFile(const string& param_file, NetParameter* param);
// Writes the net to a proto.
void ToProto(NetParameter* param, bool write_diff = false);
WriteProtoToTextFile(proto, filename.c_str());
}
-void ReadProtoFromBinaryFile(const char* filename,
- Message* proto);
-inline void ReadProtoFromBinaryFile(const string& filename,
- Message* proto) {
- ReadProtoFromBinaryFile(filename.c_str(), proto);
+bool ReadProtoFromBinaryFile(const char* filename, Message* proto);
+
+inline bool ReadProtoFromBinaryFile(const string& filename, Message* proto) {
+ return ReadProtoFromBinaryFile(filename.c_str(), proto);
+}
+
+inline void ReadProtoFromBinaryFileOrDie(const char* filename, Message* proto) {
+ CHECK(ReadProtoFromBinaryFile(filename, proto));
}
+inline void ReadProtoFromBinaryFileOrDie(const string& filename,
+ Message* proto) {
+ ReadProtoFromBinaryFileOrDie(filename.c_str(), proto);
+}
+
+
void WriteProtoToBinaryFile(const Message& proto, const char* filename);
inline void WriteProtoToBinaryFile(
const Message& proto, const string& filename) {
CHECK_GT(datum_width_, crop_size);
// check if we want to have mean
if (this->layer_param_.data_param().has_mean_file()) {
+ const string& mean_file = this->layer_param_.data_param().mean_file();
+ LOG(INFO) << "Loading mean file from" << mean_file;
BlobProto blob_proto;
- LOG(INFO) << "Loading mean file from"
- << this->layer_param_.data_param().mean_file();
- ReadProtoFromBinaryFile(this->layer_param_.data_param().mean_file().c_str(),
- &blob_proto);
+ ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
data_mean_.FromProto(blob_proto);
CHECK_EQ(data_mean_.num(), 1);
CHECK_EQ(data_mean_.channels(), datum_channels_);
this->layer_param_.window_data_param().mean_file();
LOG(INFO) << "Loading mean file from" << mean_file;
BlobProto blob_proto;
- ReadProtoFromBinaryFile(mean_file, &blob_proto);
+ ReadProtoFromBinaryFileOrDie(mean_file, &blob_proto);
data_mean_.FromProto(blob_proto);
CHECK_EQ(data_mean_.num(), 1);
CHECK_EQ(data_mean_.width(), data_mean_.height());
template <typename Dtype>
Net<Dtype>::Net(const string& param_file) {
NetParameter param;
- if (!ReadProtoFromTextFile(param_file, ¶m)) {
- // Failed to parse file as NetParameter; try to parse as a V0NetParameter
- // instead.
- V0NetParameter v0_param;
- CHECK(ReadProtoFromTextFile(param_file, &v0_param))
- << "Failed to parse NetParameter file: " << param_file;
- LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file;
- LOG(ERROR) << "Note that future Caffe releases will not support "
- << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade "
- << "this and any other network proto files to the new format.";
- if (!UpgradeV0Net(v0_param, ¶m)) {
- LOG(ERROR) << "Warning: had one or more problems upgrading "
- << "V0NetParameter to NetParameter (see above); continuing anyway.";
- }
- }
+ ReadParamsFromTextFile(param_file, ¶m);
Init(param);
}
template <typename Dtype>
void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
NetParameter param;
- ReadProtoFromBinaryFile(trained_filename, ¶m);
+ ReadParamsFromBinaryFile(trained_filename, ¶m);
CopyTrainedLayersFrom(param);
}
template <typename Dtype>
+void Net<Dtype>::ReadParamsFromTextFile(const string& param_file,
+ NetParameter* param) {
+ if (!ReadProtoFromTextFile(param_file, param)) {
+ // Failed to parse file as NetParameter; try to parse as a V0NetParameter
+ // instead.
+ V0NetParameter v0_param;
+ CHECK(ReadProtoFromTextFile(param_file, &v0_param))
+ << "Failed to parse NetParameter file: " << param_file;
+ LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file;
+ LOG(ERROR) << "Note that future Caffe releases will not support "
+ << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade "
+ << "this and any other network proto files to the new format.";
+ if (!UpgradeV0Net(v0_param, param)) {
+ LOG(ERROR) << "Warning: had one or more problems upgrading "
+ << "V0NetParameter to NetParameter (see above); continuing anyway.";
+ }
+ }
+}
+
+template <typename Dtype>
+void Net<Dtype>::ReadParamsFromBinaryFile(const string& param_file,
+ NetParameter* param) {
+ if (!ReadProtoFromBinaryFile(param_file, param)) {
+ // Failed to parse file as NetParameter; try to parse as a V0NetParameter
+ // instead.
+ V0NetParameter v0_param;
+ CHECK(ReadProtoFromBinaryFile(param_file, &v0_param))
+ << "Failed to parse NetParameter file: " << param_file;
+ LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file;
+ LOG(ERROR) << "Note that future Caffe releases will not support "
+ << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade "
+ << "this and any other network proto files to the new format.";
+ if (!UpgradeV0Net(v0_param, param)) {
+ LOG(ERROR) << "Warning: had one or more problems upgrading "
+ << "V0NetParameter to NetParameter (see above); continuing anyway.";
+ }
+ }
+}
+
+template <typename Dtype>
void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
param->Clear();
param->set_name(name_);
close(fd);
}
-void ReadProtoFromBinaryFile(const char* filename, Message* proto) {
+bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
int fd = open(filename, O_RDONLY);
CHECK_NE(fd, -1) << "File not found: " << filename;
ZeroCopyInputStream* raw_input = new FileInputStream(fd);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(536870912, 268435456);
- CHECK(proto->ParseFromCodedStream(coded_input));
+ bool success = proto->ParseFromCodedStream(coded_input);
delete coded_input;
delete raw_input;
close(fd);
+ return success;
}
void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
Caffe::set_mode(Caffe::GPU);
Caffe::set_phase(Caffe::TEST);
- NetParameter net_param;
- NetParameter trained_net_param;
-
+ shared_ptr<Net<float> > caffe_net;
if (strcmp(argv[1], "none") == 0) {
// We directly load the net param from trained file
- ReadProtoFromBinaryFile(argv[2], &net_param);
+ caffe_net.reset(new Net<float>(argv[2]));
} else {
- ReadProtoFromTextFileOrDie(argv[1], &net_param);
+ caffe_net.reset(new Net<float>(argv[1]));
}
- ReadProtoFromBinaryFile(argv[2], &trained_net_param);
+ caffe_net->CopyTrainedLayersFrom(argv[2]);
vector<Blob<float>* > input_vec;
shared_ptr<Blob<float> > input_blob(new Blob<float>());
input_vec.push_back(input_blob.get());
}
- shared_ptr<Net<float> > caffe_net(new Net<float>(net_param));
- caffe_net->CopyTrainedLayersFrom(trained_net_param);
-
string output_prefix(argv[4]);
// Run the network without training.
LOG(ERROR) << "Performing Forward";
}
Caffe::set_phase(Caffe::TEST);
- NetParameter pretrained_net_param;
-
arg_pos = 0; // the name of the executable
string pretrained_binary_proto(argv[++arg_pos]);
- ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
- &pretrained_net_param);
// Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
top: "fc7"
}
*/
- NetParameter feature_extraction_net_param;
string feature_extraction_proto(argv[++arg_pos]);
- ReadProtoFromTextFile(feature_extraction_proto,
- &feature_extraction_net_param);
shared_ptr<Net<Dtype> > feature_extraction_net(
- new Net<Dtype>(feature_extraction_net_param));
- feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
+ new Net<Dtype>(feature_extraction_proto));
+ feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
string extract_feature_blob_name(argv[++arg_pos]);
CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
}
Caffe::set_phase(Caffe::TRAIN);
- NetParameter net_param;
- ReadProtoFromTextFileOrDie(argv[1],
- &net_param);
- Net<float> caffe_net(net_param);
+ Net<float> caffe_net(argv[1]);
// Run the network without training.
LOG(ERROR) << "Performing Forward";
Caffe::set_mode(Caffe::CPU);
}
- NetParameter test_net_param;
- ReadProtoFromTextFileOrDie(argv[1], &test_net_param);
- Net<float> caffe_test_net(test_net_param);
+ Net<float> caffe_test_net(argv[1]);
NetParameter trained_net_param;
ReadProtoFromBinaryFile(argv[2], &trained_net_param);
caffe_test_net.CopyTrainedLayersFrom(trained_net_param);