make all tools backwards compatible with v0 net param
authorJeff Donahue <jeff.donahue@gmail.com>
Sun, 23 Mar 2014 21:26:31 +0000 (14:26 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 28 Mar 2014 06:42:29 +0000 (23:42 -0700)
include/caffe/net.hpp
include/caffe/util/io.hpp
src/caffe/layers/data_layer.cpp
src/caffe/layers/window_data_layer.cpp
src/caffe/net.cpp
src/caffe/util/io.cpp
tools/dump_network.cpp
tools/extract_features.cpp
tools/net_speed_benchmark.cpp
tools/test_net.cpp

index b97c1cc..7fe2c5b 100644 (file)
@@ -58,6 +58,9 @@ class Net {
   // trained layers from another net parameter instance.
   void CopyTrainedLayersFrom(const NetParameter& param);
   void CopyTrainedLayersFrom(const string trained_filename);
+  // Read parameters from a file into a NetParameter proto message.
+  void ReadParamsFromTextFile(const string& param_file, NetParameter* param);
+  void ReadParamsFromBinaryFile(const string& param_file, NetParameter* param);
   // Writes the net to a proto.
   void ToProto(NetParameter* param, bool write_diff = false);
 
index 11201c1..056b573 100644 (file)
@@ -38,13 +38,22 @@ inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
   WriteProtoToTextFile(proto, filename.c_str());
 }
 
-void ReadProtoFromBinaryFile(const char* filename,
-    Message* proto);
-inline void ReadProtoFromBinaryFile(const string& filename,
-    Message* proto) {
-  ReadProtoFromBinaryFile(filename.c_str(), proto);
+bool ReadProtoFromBinaryFile(const char* filename, Message* proto);
+
+inline bool ReadProtoFromBinaryFile(const string& filename, Message* proto) {
+  return ReadProtoFromBinaryFile(filename.c_str(), proto);
+}
+
+inline void ReadProtoFromBinaryFileOrDie(const char* filename, Message* proto) {
+  CHECK(ReadProtoFromBinaryFile(filename, proto));
 }
 
+inline void ReadProtoFromBinaryFileOrDie(const string& filename,
+                                         Message* proto) {
+  ReadProtoFromBinaryFileOrDie(filename.c_str(), proto);
+}
+
+
 void WriteProtoToBinaryFile(const Message& proto, const char* filename);
 inline void WriteProtoToBinaryFile(
     const Message& proto, const string& filename) {
index 237ce37..399f771 100644 (file)
@@ -190,11 +190,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_GT(datum_width_, crop_size);
   // check if we want to have mean
   if (this->layer_param_.data_param().has_mean_file()) {
+    const string& mean_file = this->layer_param_.data_param().mean_file();
+    LOG(INFO) << "Loading mean file from" << mean_file;
     BlobProto blob_proto;
-    LOG(INFO) << "Loading mean file from"
-              << this->layer_param_.data_param().mean_file();
-    ReadProtoFromBinaryFile(this->layer_param_.data_param().mean_file().c_str(),
-                            &blob_proto);
+    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
     data_mean_.FromProto(blob_proto);
     CHECK_EQ(data_mean_.num(), 1);
     CHECK_EQ(data_mean_.channels(), datum_channels_);
index 6f60175..7c3e24b 100644 (file)
@@ -388,7 +388,7 @@ void WindowDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
         this->layer_param_.window_data_param().mean_file();
     LOG(INFO) << "Loading mean file from" << mean_file;
     BlobProto blob_proto;
-    ReadProtoFromBinaryFile(mean_file, &blob_proto);
+    ReadProtoFromBinaryFileOrDie(mean_file, &blob_proto);
     data_mean_.FromProto(blob_proto);
     CHECK_EQ(data_mean_.num(), 1);
     CHECK_EQ(data_mean_.width(), data_mean_.height());
index 3483509..405cf1b 100644 (file)
@@ -28,21 +28,7 @@ Net<Dtype>::Net(const NetParameter& param) {
 template <typename Dtype>
 Net<Dtype>::Net(const string& param_file) {
   NetParameter param;
-  if (!ReadProtoFromTextFile(param_file, &param)) {
-    // Failed to parse file as NetParameter; try to parse as a V0NetParameter
-    // instead.
-    V0NetParameter v0_param;
-    CHECK(ReadProtoFromTextFile(param_file, &v0_param))
-        << "Failed to parse NetParameter file: " << param_file;
-    LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file;
-    LOG(ERROR) << "Note that future Caffe releases will not support "
-        << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade "
-        << "this and any other network proto files to the new format.";
-    if (!UpgradeV0Net(v0_param, &param)) {
-      LOG(ERROR) << "Warning: had one or more problems upgrading "
-          << "V0NetParameter to NetParameter (see above); continuing anyway.";
-    }
-  }
+  ReadParamsFromTextFile(param_file, &param);
   Init(param);
 }
 
@@ -317,11 +303,51 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
 template <typename Dtype>
 void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
   NetParameter param;
-  ReadProtoFromBinaryFile(trained_filename, &param);
+  ReadParamsFromBinaryFile(trained_filename, &param);
   CopyTrainedLayersFrom(param);
 }
 
 template <typename Dtype>
+void Net<Dtype>::ReadParamsFromTextFile(const string& param_file,
+                                        NetParameter* param) {
+  if (!ReadProtoFromTextFile(param_file, param)) {
+    // Failed to parse file as NetParameter; try to parse as a V0NetParameter
+    // instead.
+    V0NetParameter v0_param;
+    CHECK(ReadProtoFromTextFile(param_file, &v0_param))
+        << "Failed to parse NetParameter file: " << param_file;
+    LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file;
+    LOG(ERROR) << "Note that future Caffe releases will not support "
+        << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade "
+        << "this and any other network proto files to the new format.";
+    if (!UpgradeV0Net(v0_param, param)) {
+      LOG(ERROR) << "Warning: had one or more problems upgrading "
+          << "V0NetParameter to NetParameter (see above); continuing anyway.";
+    }
+  }
+}
+
+template <typename Dtype>
+void Net<Dtype>::ReadParamsFromBinaryFile(const string& param_file,
+                                          NetParameter* param) {
+  if (!ReadProtoFromBinaryFile(param_file, param)) {
+    // Failed to parse file as NetParameter; try to parse as a V0NetParameter
+    // instead.
+    V0NetParameter v0_param;
+    CHECK(ReadProtoFromBinaryFile(param_file, &v0_param))
+        << "Failed to parse NetParameter file: " << param_file;
+    LOG(ERROR) << "Parsed file as V0NetParameter: " << param_file;
+    LOG(ERROR) << "Note that future Caffe releases will not support "
+        << "V0NetParameter; use ./build/tools/upgrade_net_proto.bin to upgrade "
+        << "this and any other network proto files to the new format.";
+    if (!UpgradeV0Net(v0_param, param)) {
+      LOG(ERROR) << "Warning: had one or more problems upgrading "
+          << "V0NetParameter to NetParameter (see above); continuing anyway.";
+    }
+  }
+}
+
+template <typename Dtype>
 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   param->Clear();
   param->set_name(name_);
index fe45267..e1e3c3a 100644 (file)
@@ -51,18 +51,19 @@ void WriteProtoToTextFile(const Message& proto, const char* filename) {
   close(fd);
 }
 
-void ReadProtoFromBinaryFile(const char* filename, Message* proto) {
+bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
   int fd = open(filename, O_RDONLY);
   CHECK_NE(fd, -1) << "File not found: " << filename;
   ZeroCopyInputStream* raw_input = new FileInputStream(fd);
   CodedInputStream* coded_input = new CodedInputStream(raw_input);
   coded_input->SetTotalBytesLimit(536870912, 268435456);
 
-  CHECK(proto->ParseFromCodedStream(coded_input));
+  bool success = proto->ParseFromCodedStream(coded_input);
 
   delete coded_input;
   delete raw_input;
   close(fd);
+  return success;
 }
 
 void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
index cee8bc8..f29e150 100644 (file)
@@ -31,16 +31,14 @@ int main(int argc, char** argv) {
   Caffe::set_mode(Caffe::GPU);
   Caffe::set_phase(Caffe::TEST);
 
-  NetParameter net_param;
-  NetParameter trained_net_param;
-
+  shared_ptr<Net<float> > caffe_net;
   if (strcmp(argv[1], "none") == 0) {
     // We directly load the net param from trained file
-    ReadProtoFromBinaryFile(argv[2], &net_param);
+    caffe_net.reset(new Net<float>(argv[2]));
   } else {
-    ReadProtoFromTextFileOrDie(argv[1], &net_param);
+    caffe_net.reset(new Net<float>(argv[1]));
   }
-  ReadProtoFromBinaryFile(argv[2], &trained_net_param);
+  caffe_net->CopyTrainedLayersFrom(argv[2]);
 
   vector<Blob<float>* > input_vec;
   shared_ptr<Blob<float> > input_blob(new Blob<float>());
@@ -51,9 +49,6 @@ int main(int argc, char** argv) {
     input_vec.push_back(input_blob.get());
   }
 
-  shared_ptr<Net<float> > caffe_net(new Net<float>(net_param));
-  caffe_net->CopyTrainedLayersFrom(trained_net_param);
-
   string output_prefix(argv[4]);
   // Run the network without training.
   LOG(ERROR) << "Performing Forward";
index 4274db4..0393826 100644 (file)
@@ -56,12 +56,8 @@ int feature_extraction_pipeline(int argc, char** argv) {
   }
   Caffe::set_phase(Caffe::TEST);
 
-  NetParameter pretrained_net_param;
-
   arg_pos = 0;  // the name of the executable
   string pretrained_binary_proto(argv[++arg_pos]);
-  ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
-                          &pretrained_net_param);
 
   // Expected prototxt contains at least one data layer such as
   //  the layer data_layer_name and one feature blob such as the
@@ -90,13 +86,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
    top: "fc7"
    }
    */
-  NetParameter feature_extraction_net_param;
   string feature_extraction_proto(argv[++arg_pos]);
-  ReadProtoFromTextFile(feature_extraction_proto,
-                        &feature_extraction_net_param);
   shared_ptr<Net<Dtype> > feature_extraction_net(
-      new Net<Dtype>(feature_extraction_net_param));
-  feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);
+      new Net<Dtype>(feature_extraction_proto));
+  feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);
 
   string extract_feature_blob_name(argv[++arg_pos]);
   CHECK(feature_extraction_net->has_blob(extract_feature_blob_name))
index d6c0cdb..36a0077 100644 (file)
@@ -49,10 +49,7 @@ int main(int argc, char** argv) {
   }
 
   Caffe::set_phase(Caffe::TRAIN);
-  NetParameter net_param;
-  ReadProtoFromTextFileOrDie(argv[1],
-      &net_param);
-  Net<float> caffe_net(net_param);
+  Net<float> caffe_net(argv[1]);
 
   // Run the network without training.
   LOG(ERROR) << "Performing Forward";
index 0b49e54..559fa73 100644 (file)
@@ -34,9 +34,7 @@ int main(int argc, char** argv) {
     Caffe::set_mode(Caffe::CPU);
   }
 
-  NetParameter test_net_param;
-  ReadProtoFromTextFileOrDie(argv[1], &test_net_param);
-  Net<float> caffe_test_net(test_net_param);
+  Net<float> caffe_test_net(argv[1]);
   NetParameter trained_net_param;
   ReadProtoFromBinaryFile(argv[2], &trained_net_param);
   caffe_test_net.CopyTrainedLayersFrom(trained_net_param);