From 17a59c30485d3a3832533bf606a32fb827fb7ea9 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sun, 23 Mar 2014 14:26:31 -0700 Subject: [PATCH] make all tools backwards compatible with v0 net param --- include/caffe/net.hpp | 3 ++ include/caffe/util/io.hpp | 19 ++++++++--- src/caffe/layers/data_layer.cpp | 7 ++-- src/caffe/layers/window_data_layer.cpp | 2 +- src/caffe/net.cpp | 58 ++++++++++++++++++++++++---------- src/caffe/util/io.cpp | 5 +-- tools/dump_network.cpp | 13 +++----- tools/extract_features.cpp | 11 ++----- tools/net_speed_benchmark.cpp | 5 +-- tools/test_net.cpp | 4 +-- 10 files changed, 74 insertions(+), 53 deletions(-) diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index b97c1cc..7fe2c5b 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -58,6 +58,9 @@ class Net { // trained layers from another net parameter instance. void CopyTrainedLayersFrom(const NetParameter& param); void CopyTrainedLayersFrom(const string trained_filename); + // Read parameters from a file into a NetParameter proto message. + void ReadParamsFromTextFile(const string& param_file, NetParameter* param); + void ReadParamsFromBinaryFile(const string& param_file, NetParameter* param); // Writes the net to a proto. void ToProto(NetParameter* param, bool write_diff = false); diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 11201c1..056b573 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -38,13 +38,22 @@ inline void WriteProtoToTextFile(const Message& proto, const string& filename) { WriteProtoToTextFile(proto, filename.c_str()); } -void ReadProtoFromBinaryFile(const char* filename, - Message* proto); -inline void ReadProtoFromBinaryFile(const string& filename, - Message* proto) { - ReadProtoFromBinaryFile(filename.c_str(), proto); +bool ReadProtoFromBinaryFile(const char* filename, Message* proto); + +inline bool ReadProtoFromBinaryFile(const string& filename, Message* proto) { + return ReadProtoFromBinaryFile(filename.c_str(), proto); +} + +inline void ReadProtoFromBinaryFileOrDie(const char* filename, Message* proto) { + CHECK(ReadProtoFromBinaryFile(filename, proto)); } +inline void ReadProtoFromBinaryFileOrDie(const string& filename, + Message* proto) { + ReadProtoFromBinaryFileOrDie(filename.c_str(), proto); +} + + void WriteProtoToBinaryFile(const Message& proto, const char* filename); inline void WriteProtoToBinaryFile( const Message& proto, const string& filename) { diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 237ce37..399f771 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -190,11 +190,10 @@ void DataLayer::SetUp(const vector*>& bottom, CHECK_GT(datum_width_, crop_size); // check if we want to have mean if (this->layer_param_.data_param().has_mean_file()) { + const string& mean_file = this->layer_param_.data_param().mean_file(); + LOG(INFO) << "Loading mean file from" << mean_file; BlobProto blob_proto; - LOG(INFO) << "Loading mean file from" - << this->layer_param_.data_param().mean_file(); - ReadProtoFromBinaryFile(this->layer_param_.data_param().mean_file().c_str(), - &blob_proto); + ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); data_mean_.FromProto(blob_proto); CHECK_EQ(data_mean_.num(), 1); CHECK_EQ(data_mean_.channels(), datum_channels_); diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index 6f60175..7c3e24b 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -388,7 +388,7 @@ void WindowDataLayer::SetUp(const vector*>& bottom, this->layer_param_.window_data_param().mean_file(); LOG(INFO) << "Loading mean file from" << mean_file; BlobProto blob_proto; - ReadProtoFromBinaryFile(mean_file, &blob_proto); + ReadProtoFromBinaryFileOrDie(mean_file, &blob_proto); data_mean_.FromProto(blob_proto); CHECK_EQ(data_mean_.num(), 1); CHECK_EQ(data_mean_.width(), data_mean_.height()); diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 3483509..405cf1b 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -28,21 +28,7 @@ Net::Net(const NetParameter& param) { template Net::Net(const string& param_file) { NetParameter param; - if (!ReadProtoFromTextFile(param_file, ¶m)) { - // Failed to parse file as NetParameter; try to parse as a V0NetParameter - // instead. - V0NetParameter v0_param; - CHECK(ReadProtoFromTextFile(param_file, &v0_param)) - << "Failed to parse NetParameter file: " << param_file; - LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file; - LOG(ERROR) << "Note that future Caffe releases will not support " - << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade " - << "this and any other network proto files to the new format."; - if (!UpgradeV0Net(v0_param, ¶m)) { - LOG(ERROR) << "Warning: had one or more problems upgrading " - << "V0NetParameter to NetParameter (see above); continuing anyway."; - } - } + ReadParamsFromTextFile(param_file, ¶m); Init(param); } @@ -317,11 +303,51 @@ void Net::CopyTrainedLayersFrom(const NetParameter& param) { template void Net::CopyTrainedLayersFrom(const string trained_filename) { NetParameter param; - ReadProtoFromBinaryFile(trained_filename, ¶m); + ReadParamsFromBinaryFile(trained_filename, ¶m); CopyTrainedLayersFrom(param); } template +void Net::ReadParamsFromTextFile(const string& param_file, + NetParameter* param) { + if (!ReadProtoFromTextFile(param_file, param)) { + // Failed to parse file as NetParameter; try to parse as a V0NetParameter + // instead. + V0NetParameter v0_param; + CHECK(ReadProtoFromTextFile(param_file, &v0_param)) + << "Failed to parse NetParameter file: " << param_file; + LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file; + LOG(ERROR) << "Note that future Caffe releases will not support " + << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade " + << "this and any other network proto files to the new format."; + if (!UpgradeV0Net(v0_param, param)) { + LOG(ERROR) << "Warning: had one or more problems upgrading " + << "V0NetParameter to NetParameter (see above); continuing anyway."; + } + } +} + +template +void Net::ReadParamsFromBinaryFile(const string& param_file, + NetParameter* param) { + if (!ReadProtoFromBinaryFile(param_file, param)) { + // Failed to parse file as NetParameter; try to parse as a V0NetParameter + // instead. + V0NetParameter v0_param; + CHECK(ReadProtoFromBinaryFile(param_file, &v0_param)) + << "Failed to parse NetParameter file: " << param_file; + LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file; + LOG(ERROR) << "Note that future Caffe releases will not support " + << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade " + << "this and any other network proto files to the new format."; + if (!UpgradeV0Net(v0_param, param)) { + LOG(ERROR) << "Warning: had one or more problems upgrading " + << "V0NetParameter to NetParameter (see above); continuing anyway."; + } + } +} + +template void Net::ToProto(NetParameter* param, bool write_diff) { param->Clear(); param->set_name(name_); diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index fe45267..e1e3c3a 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -51,18 +51,19 @@ void WriteProtoToTextFile(const Message& proto, const char* filename) { close(fd); } -void ReadProtoFromBinaryFile(const char* filename, Message* proto) { +bool ReadProtoFromBinaryFile(const char* filename, Message* proto) { int fd = open(filename, O_RDONLY); CHECK_NE(fd, -1) << "File not found: " << filename; ZeroCopyInputStream* raw_input = new FileInputStream(fd); CodedInputStream* coded_input = new CodedInputStream(raw_input); coded_input->SetTotalBytesLimit(536870912, 268435456); - CHECK(proto->ParseFromCodedStream(coded_input)); + bool success = proto->ParseFromCodedStream(coded_input); delete coded_input; delete raw_input; close(fd); + return success; } void WriteProtoToBinaryFile(const Message& proto, const char* filename) { diff --git a/tools/dump_network.cpp b/tools/dump_network.cpp index cee8bc8..f29e150 100644 --- a/tools/dump_network.cpp +++ b/tools/dump_network.cpp @@ -31,16 +31,14 @@ int main(int argc, char** argv) { Caffe::set_mode(Caffe::GPU); Caffe::set_phase(Caffe::TEST); - NetParameter net_param; - NetParameter trained_net_param; - + shared_ptr > caffe_net; if (strcmp(argv[1], "none") == 0) { // We directly load the net param from trained file - ReadProtoFromBinaryFile(argv[2], &net_param); + caffe_net.reset(new Net(argv[2])); } else { - ReadProtoFromTextFileOrDie(argv[1], &net_param); + caffe_net.reset(new Net(argv[1])); } - ReadProtoFromBinaryFile(argv[2], &trained_net_param); + caffe_net->CopyTrainedLayersFrom(argv[2]); vector* > input_vec; shared_ptr > input_blob(new Blob()); @@ -51,9 +49,6 @@ int main(int argc, char** argv) { input_vec.push_back(input_blob.get()); } - shared_ptr > caffe_net(new Net(net_param)); - caffe_net->CopyTrainedLayersFrom(trained_net_param); - string output_prefix(argv[4]); // Run the network without training. LOG(ERROR) << "Performing Forward"; diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 4274db4..0393826 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -56,12 +56,8 @@ int feature_extraction_pipeline(int argc, char** argv) { } Caffe::set_phase(Caffe::TEST); - NetParameter pretrained_net_param; - arg_pos = 0; // the name of the executable string pretrained_binary_proto(argv[++arg_pos]); - ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(), - &pretrained_net_param); // Expected prototxt contains at least one data layer such as // the layer data_layer_name and one feature blob such as the @@ -90,13 +86,10 @@ int feature_extraction_pipeline(int argc, char** argv) { top: "fc7" } */ - NetParameter feature_extraction_net_param; string feature_extraction_proto(argv[++arg_pos]); - ReadProtoFromTextFile(feature_extraction_proto, - &feature_extraction_net_param); shared_ptr > feature_extraction_net( - new Net(feature_extraction_net_param)); - feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param); + new Net(feature_extraction_proto)); + feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); string extract_feature_blob_name(argv[++arg_pos]); CHECK(feature_extraction_net->has_blob(extract_feature_blob_name)) diff --git a/tools/net_speed_benchmark.cpp b/tools/net_speed_benchmark.cpp index d6c0cdb..36a0077 100644 --- a/tools/net_speed_benchmark.cpp +++ b/tools/net_speed_benchmark.cpp @@ -49,10 +49,7 @@ int main(int argc, char** argv) { } Caffe::set_phase(Caffe::TRAIN); - NetParameter net_param; - ReadProtoFromTextFileOrDie(argv[1], - &net_param); - Net caffe_net(net_param); + Net caffe_net(argv[1]); // Run the network without training. LOG(ERROR) << "Performing Forward"; diff --git a/tools/test_net.cpp b/tools/test_net.cpp index 0b49e54..559fa73 100644 --- a/tools/test_net.cpp +++ b/tools/test_net.cpp @@ -34,9 +34,7 @@ int main(int argc, char** argv) { Caffe::set_mode(Caffe::CPU); } - NetParameter test_net_param; - ReadProtoFromTextFileOrDie(argv[1], &test_net_param); - Net caffe_test_net(test_net_param); + Net caffe_test_net(argv[1]); NetParameter trained_net_param; ReadProtoFromBinaryFile(argv[2], &trained_net_param); caffe_test_net.CopyTrainedLayersFrom(trained_net_param); -- 2.7.4