From be8d83603c559cb46170442c3c0cbd663492db05 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sat, 15 Mar 2014 02:33:40 -0700 Subject: [PATCH] make ReadProtoFromTextFile not die on parse failure; add ReadProtoFromTextFileOrDie which has the old functionality --- include/caffe/util/io.hpp | 17 ++++++++++++----- include/caffe/util/upgrade_proto.hpp | 1 - src/caffe/net.cpp | 2 +- src/caffe/solver.cpp | 4 ++-- src/caffe/util/io.cpp | 7 ++++--- tools/dump_network.cpp | 2 +- tools/finetune_net.cpp | 2 +- tools/net_speed_benchmark.cpp | 2 +- tools/test_net.cpp | 2 +- tools/train_net.cpp | 2 +- 10 files changed, 24 insertions(+), 17 deletions(-) diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 89f9c18..11201c1 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -19,11 +19,18 @@ using ::google::protobuf::Message; namespace caffe { -void ReadProtoFromTextFile(const char* filename, - Message* proto); -inline void ReadProtoFromTextFile(const string& filename, - Message* proto) { - ReadProtoFromTextFile(filename.c_str(), proto); +bool ReadProtoFromTextFile(const char* filename, Message* proto); + +inline bool ReadProtoFromTextFile(const string& filename, Message* proto) { + return ReadProtoFromTextFile(filename.c_str(), proto); +} + +inline void ReadProtoFromTextFileOrDie(const char* filename, Message* proto) { + CHECK(ReadProtoFromTextFile(filename, proto)); +} + +inline void ReadProtoFromTextFileOrDie(const string& filename, Message* proto) { + ReadProtoFromTextFileOrDie(filename.c_str(), proto); } void WriteProtoToTextFile(const Message& proto, const char* filename); diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp index 8b93829..ed9a829 100644 --- a/include/caffe/util/upgrade_proto.hpp +++ b/include/caffe/util/upgrade_proto.hpp @@ -13,7 +13,6 @@ #include "caffe/blob.hpp" using std::string; -// using ::google::protobuf::Message; namespace caffe { diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 8a95a3d..7310a1a 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -25,7 +25,7 @@ Net::Net(const NetParameter& param) { template Net::Net(const string& param_file) { NetParameter param; - ReadProtoFromTextFile(param_file, ¶m); + ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); } diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 5fa150c..6d31290 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -22,13 +22,13 @@ Solver::Solver(const SolverParameter& param) : param_(param), net_(), test_net_() { // Scaffolding code NetParameter train_net_param; - ReadProtoFromTextFile(param_.train_net(), &train_net_param); + ReadProtoFromTextFileOrDie(param_.train_net(), &train_net_param); LOG(INFO) << "Creating training net."; net_.reset(new Net(train_net_param)); if (param_.has_test_net()) { LOG(INFO) << "Creating testing net."; NetParameter test_net_param; - ReadProtoFromTextFile(param_.test_net(), &test_net_param); + ReadProtoFromTextFileOrDie(param_.test_net(), &test_net_param); test_net_.reset(new Net(test_net_param)); CHECK_GT(param_.test_iter(), 0); CHECK_GT(param_.test_interval(), 0); diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index fdad21d..fe45267 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -29,17 +29,18 @@ using google::protobuf::io::ZeroCopyInputStream; using google::protobuf::io::CodedInputStream; using google::protobuf::io::ZeroCopyOutputStream; using google::protobuf::io::CodedOutputStream; +using google::protobuf::Message; namespace caffe { -void ReadProtoFromTextFile(const char* filename, - ::google::protobuf::Message* proto) { +bool ReadProtoFromTextFile(const char* filename, Message* proto) { int fd = open(filename, O_RDONLY); CHECK_NE(fd, -1) << "File not found: " << filename; FileInputStream* input = new FileInputStream(fd); - CHECK(google::protobuf::TextFormat::Parse(input, proto)); + bool success = google::protobuf::TextFormat::Parse(input, proto); delete input; close(fd); + return success; } void WriteProtoToTextFile(const Message& proto, const char* filename) { diff --git a/tools/dump_network.cpp b/tools/dump_network.cpp index f5c3682..cee8bc8 100644 --- a/tools/dump_network.cpp +++ b/tools/dump_network.cpp @@ -38,7 +38,7 @@ int main(int argc, char** argv) { // We directly load the net param from trained file ReadProtoFromBinaryFile(argv[2], &net_param); } else { - ReadProtoFromTextFile(argv[1], &net_param); + ReadProtoFromTextFileOrDie(argv[1], &net_param); } ReadProtoFromBinaryFile(argv[2], &trained_net_param); diff --git a/tools/finetune_net.cpp b/tools/finetune_net.cpp index db96b02..c1cd788 100644 --- a/tools/finetune_net.cpp +++ b/tools/finetune_net.cpp @@ -20,7 +20,7 @@ int main(int argc, char** argv) { } SolverParameter solver_param; - ReadProtoFromTextFile(argv[1], &solver_param); + ReadProtoFromTextFileOrDie(argv[1], &solver_param); LOG(INFO) << "Starting Optimization"; SGDSolver solver(solver_param); diff --git a/tools/net_speed_benchmark.cpp b/tools/net_speed_benchmark.cpp index f52aac9..d6c0cdb 100644 --- a/tools/net_speed_benchmark.cpp +++ b/tools/net_speed_benchmark.cpp @@ -50,7 +50,7 @@ int main(int argc, char** argv) { Caffe::set_phase(Caffe::TRAIN); NetParameter net_param; - ReadProtoFromTextFile(argv[1], + ReadProtoFromTextFileOrDie(argv[1], &net_param); Net caffe_net(net_param); diff --git a/tools/test_net.cpp b/tools/test_net.cpp index 0abfbf6..0b49e54 100644 --- a/tools/test_net.cpp +++ b/tools/test_net.cpp @@ -35,7 +35,7 @@ int main(int argc, char** argv) { } NetParameter test_net_param; - ReadProtoFromTextFile(argv[1], &test_net_param); + ReadProtoFromTextFileOrDie(argv[1], &test_net_param); Net caffe_test_net(test_net_param); NetParameter trained_net_param; ReadProtoFromBinaryFile(argv[2], &trained_net_param); diff --git a/tools/train_net.cpp b/tools/train_net.cpp index 751a704..7c6f23e 100644 --- a/tools/train_net.cpp +++ b/tools/train_net.cpp @@ -21,7 +21,7 @@ int main(int argc, char** argv) { } SolverParameter solver_param; - ReadProtoFromTextFile(argv[1], &solver_param); + ReadProtoFromTextFileOrDie(argv[1], &solver_param); LOG(INFO) << "Starting Optimization"; SGDSolver solver(solver_param); -- 2.7.4