Snapshot model weights/solver state to HDF5 files.
authorEric Tzeng <etzeng@eecs.berkeley.edu>
Wed, 22 Jul 2015 23:17:01 +0000 (16:17 -0700)
committerEric Tzeng <etzeng@eecs.berkeley.edu>
Fri, 7 Aug 2015 21:56:38 +0000 (14:56 -0700)
Summary of changes:
- HDF5 helper functions were moved into a separate file util/hdf5.cpp
- hdf5_save_nd_dataset now saves n-d blobs, can save diffs instead of
  data
- Minor fix for memory leak in HDF5 functions (delete instead of
  delete[])
- Extra methods have been added to both Net/Solver enabling
  snapshotting and restoring from HDF5 files
- snapshot_format was added to SolverParameters, with possible values
  HDF5 or BINARYPROTO (default HDF5)
- kMaxBlobAxes was reduced to 32 to match the limitations of HDF5

14 files changed:
include/caffe/blob.hpp
include/caffe/net.hpp
include/caffe/solver.hpp
include/caffe/util/hdf5.hpp [new file with mode: 0644]
include/caffe/util/io.hpp
src/caffe/layers/hdf5_data_layer.cpp
src/caffe/layers/hdf5_output_layer.cpp
src/caffe/layers/hdf5_output_layer.cu
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp
src/caffe/test/test_hdf5_output_layer.cpp
src/caffe/util/hdf5.cpp [new file with mode: 0644]
src/caffe/util/io.cpp

index 472cc18..9b813e7 100644 (file)
@@ -10,7 +10,7 @@
 #include "caffe/syncedmem.hpp"
 #include "caffe/util/math_functions.hpp"
 
-const int kMaxBlobAxes = INT_MAX;
+const int kMaxBlobAxes = 32;
 
 namespace caffe {
 
index 5665df1..dfd2e55 100644 (file)
@@ -98,8 +98,12 @@ class Net {
    */
   void CopyTrainedLayersFrom(const NetParameter& param);
   void CopyTrainedLayersFrom(const string trained_filename);
+  void CopyTrainedLayersFromBinaryProto(const string trained_filename);
+  void CopyTrainedLayersFromHDF5(const string trained_filename);
   /// @brief Writes the net to a proto.
   void ToProto(NetParameter* param, bool write_diff = false) const;
+  /// @brief Writes the net to an HDF5 file.
+  void ToHDF5(const string& filename, bool write_diff = false) const;
 
   /// @brief returns the network name.
   inline const string& name() const { return name_; }
index c2ced48..703434b 100644 (file)
@@ -27,9 +27,9 @@ class Solver {
   virtual void Solve(const char* resume_file = NULL);
   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
   void Step(int iters);
-  // The Restore function implements how one should restore the solver to a
-  // previously snapshotted state. You should implement the RestoreSolverState()
-  // function that restores the state from a SolverState protocol buffer.
+  // The Restore method simply dispatches to one of the
+  // RestoreSolverStateFrom___ protected methods. You should implement these
+  // methods to restore the state from the appropriate snapshot type.
   void Restore(const char* resume_file);
   virtual ~Solver() {}
   inline shared_ptr<Net<Dtype> > net() { return net_; }
@@ -46,11 +46,15 @@ class Solver {
   // function that produces a SolverState protocol buffer that needs to be
   // written to disk together with the learned net.
   void Snapshot();
+  string SnapshotFilename(const string extension);
+  string SnapshotToBinaryProto();
+  string SnapshotToHDF5();
   // The test routine
   void TestAll();
   void Test(const int test_net_id = 0);
-  virtual void SnapshotSolverState(SolverState* state) = 0;
-  virtual void RestoreSolverState(const SolverState& state) = 0;
+  virtual void SnapshotSolverState(const string& model_filename) = 0;
+  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
+  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
   void DisplayOutputBlobs(const int net_id);
 
   SolverParameter param_;
@@ -85,8 +89,11 @@ class SGDSolver : public Solver<Dtype> {
   virtual void Regularize(int param_id);
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
   virtual void ClipGradients();
-  virtual void SnapshotSolverState(SolverState * state);
-  virtual void RestoreSolverState(const SolverState& state);
+  virtual void SnapshotSolverState(const string& model_filename);
+  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
+  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
+  virtual void RestoreSolverStateFromHDF5(const string& state_file);
+  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
   // history maintains the historical momentum data.
   // update maintains update related data and is not needed in snapshots.
   // temp maintains other information that might be needed in computation
diff --git a/include/caffe/util/hdf5.hpp b/include/caffe/util/hdf5.hpp
new file mode 100644 (file)
index 0000000..ce568c5
--- /dev/null
@@ -0,0 +1,39 @@
+#ifndef CAFFE_UTIL_HDF5_H_
+#define CAFFE_UTIL_HDF5_H_
+
+#include <string>
+
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
+#include "caffe/blob.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void hdf5_load_nd_dataset_helper(
+    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+    Blob<Dtype>* blob);
+
+template <typename Dtype>
+void hdf5_load_nd_dataset(
+    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+    Blob<Dtype>* blob);
+
+template <typename Dtype>
+void hdf5_save_nd_dataset(
+    const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob,
+    bool write_diff = false);
+
+int hdf5_load_int(hid_t loc_id, const string& dataset_name);
+void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i);
+string hdf5_load_string(hid_t loc_id, const string& dataset_name);
+void hdf5_save_string(hid_t loc_id, const string& dataset_name,
+                      const string& s);
+
+int hdf5_get_num_links(hid_t loc_id);
+string hdf5_get_name_by_idx(hid_t loc_id, int idx);
+
+}  // namespace caffe
+
+#endif   // CAFFE_UTIL_HDF5_H_
index 3a62c3c..c0938ad 100644 (file)
@@ -5,15 +5,11 @@
 #include <string>
 
 #include "google/protobuf/message.h"
-#include "hdf5.h"
-#include "hdf5_hl.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
 
-#define HDF5_NUM_DIMS 4
-
 namespace caffe {
 
 using ::google::protobuf::Message;
@@ -140,20 +136,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
 
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
 
-template <typename Dtype>
-void hdf5_load_nd_dataset_helper(
-    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob);
-
-template <typename Dtype>
-void hdf5_load_nd_dataset(
-    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob);
-
-template <typename Dtype>
-void hdf5_save_nd_dataset(
-    const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob);
-
 }  // namespace caffe
 
 #endif   // CAFFE_UTIL_IO_H_
index 8a782f7..8ced510 100644 (file)
@@ -16,7 +16,7 @@ TODO:
 
 #include "caffe/data_layers.hpp"
 #include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
+#include "caffe/util/hdf5.hpp"
 
 namespace caffe {
 
index f63375c..56788c2 100644 (file)
@@ -6,7 +6,7 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
+#include "caffe/util/hdf5.hpp"
 #include "caffe/vision_layers.hpp"
 
 namespace caffe {
index ae497c3..eb6d0e4 100644 (file)
@@ -6,7 +6,6 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
 #include "caffe/vision_layers.hpp"
 
 namespace caffe {
index a18ee63..0812b36 100644 (file)
@@ -5,12 +5,14 @@
 #include <utility>
 #include <vector>
 
+#include "hdf5.h"
+
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/hdf5.hpp"
 #include "caffe/util/insert_splits.hpp"
-#include "caffe/util/io.hpp"
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/upgrade_proto.hpp"
 
@@ -747,12 +749,73 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
 
 template <typename Dtype>
 void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
+  if (trained_filename.size() >= 3 &&
+      trained_filename.compare(trained_filename.size() - 3, 3, ".h5") == 0) {
+    CopyTrainedLayersFromHDF5(trained_filename);
+  } else {
+    CopyTrainedLayersFromBinaryProto(trained_filename);
+  }
+}
+
+template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFromBinaryProto(
+    const string trained_filename) {
   NetParameter param;
   ReadNetParamsFromBinaryFileOrDie(trained_filename, &param);
   CopyTrainedLayersFrom(param);
 }
 
 template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
+  hid_t file_hid = H5Fopen(trained_filename.c_str(), H5F_ACC_RDONLY,
+                           H5P_DEFAULT);
+  CHECK_GE(file_hid, 0) << "Couldn't open " << trained_filename;
+  hid_t data_hid = H5Gopen2(file_hid, "data", H5P_DEFAULT);
+  CHECK_GE(data_hid, 0) << "Error reading weights from " << trained_filename;
+  int num_layers = hdf5_get_num_links(data_hid);
+  for (int i = 0; i < num_layers; ++i) {
+    string source_layer_name = hdf5_get_name_by_idx(data_hid, i);
+    if (!layer_names_index_.count(source_layer_name)) {
+      DLOG(INFO) << "Ignoring source layer " << source_layer_name;
+      continue;
+    }
+    int target_layer_id = layer_names_index_[source_layer_name];
+    DLOG(INFO) << "Copying source layer " << source_layer_name;
+    vector<shared_ptr<Blob<Dtype> > >& target_blobs =
+        layers_[target_layer_id]->blobs();
+    hid_t layer_hid = H5Gopen2(data_hid, source_layer_name.c_str(),
+        H5P_DEFAULT);
+    CHECK_GE(layer_hid, 0)
+        << "Error reading weights from " << trained_filename;
+    // Check that source layer doesn't have more params than target layer
+    int num_source_params = hdf5_get_num_links(layer_hid);
+    CHECK_LE(num_source_params, target_blobs.size())
+        << "Incompatible number of blobs for layer " << source_layer_name;
+    for (int j = 0; j < target_blobs.size(); ++j) {
+      ostringstream oss;
+      oss << j;
+      string dataset_name = oss.str();
+      int target_net_param_id = param_id_vecs_[target_layer_id][j];
+      if (!H5Lexists(layer_hid, dataset_name.c_str(), H5P_DEFAULT)) {
+        // Target param doesn't exist in source weights...
+        if (param_owners_[target_net_param_id] != -1) {
+          // ...but it's weight-shared in target, so that's fine.
+          continue;
+        } else {
+          LOG(FATAL) << "Incompatible number of blobs for layer "
+              << source_layer_name;
+        }
+      }
+      hdf5_load_nd_dataset(layer_hid, dataset_name.c_str(), 0, kMaxBlobAxes,
+          target_blobs[j].get());
+    }
+    H5Gclose(layer_hid);
+  }
+  H5Gclose(data_hid);
+  H5Fclose(file_hid);
+}
+
+template <typename Dtype>
 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
   param->Clear();
   param->set_name(name_);
@@ -774,6 +837,63 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
 }
 
 template <typename Dtype>
+void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
+  hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
+      H5P_DEFAULT);
+  CHECK_GE(file_hid, 0)
+      << "Couldn't open " << filename << " to save weights.";
+  hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
+      H5P_DEFAULT);
+  CHECK_GE(data_hid, 0) << "Error saving weights to " << filename << ".";
+  hid_t diff_hid = -1;
+  if (write_diff) {
+    diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
+        H5P_DEFAULT);
+    CHECK_GE(diff_hid, 0) << "Error saving weights to " << filename << ".";
+  }
+  for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
+    const LayerParameter& layer_param = layers_[layer_id]->layer_param();
+    string layer_name = layer_param.name();
+    hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
+        H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
+    CHECK_GE(layer_data_hid, 0)
+        << "Error saving weights to " << filename << ".";
+    hid_t layer_diff_hid = -1;
+    if (write_diff) {
+      layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
+          H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
+      CHECK_GE(layer_diff_hid, 0)
+          << "Error saving weights to " << filename << ".";
+    }
+    int num_params = layers_[layer_id]->blobs().size();
+    for (int param_id = 0; param_id < num_params; ++param_id) {
+      ostringstream dataset_name;
+      dataset_name << param_id;
+      const int net_param_id = param_id_vecs_[layer_id][param_id];
+      if (param_owners_[net_param_id] == -1) {
+        // Only save params that own themselves
+        hdf5_save_nd_dataset<Dtype>(layer_data_hid, dataset_name.str(),
+            *params_[net_param_id]);
+      }
+      if (write_diff) {
+        // Write diffs regardless of weight-sharing
+        hdf5_save_nd_dataset<Dtype>(layer_diff_hid, dataset_name.str(),
+            *params_[net_param_id], true);
+      }
+    }
+    H5Gclose(layer_data_hid);
+    if (write_diff) {
+      H5Gclose(layer_diff_hid);
+    }
+  }
+  H5Gclose(data_hid);
+  if (write_diff) {
+    H5Gclose(diff_hid);
+  }
+  H5Fclose(file_hid);
+}
+
+template <typename Dtype>
 void Net<Dtype>::Update() {
   // First, accumulate the diffs of any shared parameters into their owner's
   // diff. (Assumes that the learning rate, weight decay, etc. have already been
index 03daa80..96e975b 100644 (file)
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 37 (last added: iter_size)
+// SolverParameter next available ID: 38 (last added: snapshot_format)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -175,6 +175,11 @@ message SolverParameter {
   // whether to snapshot diff in the results or not. Snapshotting diff will help
   // debugging but the final protocol buffer size will be much larger.
   optional bool snapshot_diff = 16 [default = false];
+  enum SnapshotFormat {
+    HDF5 = 0;
+    BINARYPROTO = 1;
+  }
+  optional SnapshotFormat snapshot_format = 37 [default = HDF5];
   // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
   enum SolverMode {
     CPU = 0;
index aabe0ed..7527113 100644 (file)
@@ -4,9 +4,13 @@
 #include <string>
 #include <vector>
 
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
 #include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/solver.hpp"
+#include "caffe/util/hdf5.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/upgrade_proto.hpp"
@@ -348,42 +352,58 @@ void Solver<Dtype>::Test(const int test_net_id) {
 
 template <typename Dtype>
 void Solver<Dtype>::Snapshot() {
-  NetParameter net_param;
-  // For intermediate results, we will also dump the gradient values.
-  net_->ToProto(&net_param, param_.snapshot_diff());
+  string model_filename;
+  switch (param_.snapshot_format()) {
+    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+      model_filename = SnapshotToBinaryProto();
+      break;
+    case caffe::SolverParameter_SnapshotFormat_HDF5:
+      model_filename = SnapshotToHDF5();
+      break;
+    default:
+      LOG(FATAL) << "Unsupported snapshot format.";
+  }
+
+  SnapshotSolverState(model_filename);
+}
+
+template <typename Dtype>
+string Solver<Dtype>::SnapshotFilename(const string extension) {
   string filename(param_.snapshot_prefix());
-  string model_filename, snapshot_filename;
   const int kBufferSize = 20;
   char iter_str_buffer[kBufferSize];
   snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_);
-  filename += iter_str_buffer;
-  model_filename = filename + ".caffemodel";
-  LOG(INFO) << "Snapshotting to " << model_filename;
-  WriteProtoToBinaryFile(net_param, model_filename.c_str());
-  SolverState state;
-  SnapshotSolverState(&state);
-  state.set_iter(iter_);
-  state.set_learned_net(model_filename);
-  state.set_current_step(current_step_);
-  snapshot_filename = filename + ".solverstate";
-  LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
-  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+  return filename + iter_str_buffer + extension;
 }
 
 template <typename Dtype>
-void Solver<Dtype>::Restore(const char* state_file) {
-  SolverState state;
+string Solver<Dtype>::SnapshotToBinaryProto() {
+  string model_filename = SnapshotFilename(".caffemodel");
+  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
   NetParameter net_param;
-  ReadProtoFromBinaryFile(state_file, &state);
-  if (state.has_learned_net()) {
-    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
-    net_->CopyTrainedLayersFrom(net_param);
-  }
-  iter_ = state.iter();
-  current_step_ = state.current_step();
-  RestoreSolverState(state);
+  net_->ToProto(&net_param, param_.snapshot_diff());
+  WriteProtoToBinaryFile(net_param, model_filename);
+  return model_filename;
+}
+
+template <typename Dtype>
+string Solver<Dtype>::SnapshotToHDF5() {
+  string model_filename = SnapshotFilename(".caffemodel.h5");
+  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
+  net_->ToHDF5(model_filename, param_.snapshot_diff());
+  return model_filename;
 }
 
+template <typename Dtype>
+void Solver<Dtype>::Restore(const char* state_file) {
+  string state_filename(state_file);
+  if (state_filename.size() >= 3 &&
+      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
+    RestoreSolverStateFromHDF5(state_filename);
+  } else {
+    RestoreSolverStateFromBinaryProto(state_filename);
+  }
+}
 
 // Return the current learning rate. The currently implemented learning rate
 // policies are as follows:
@@ -618,17 +638,76 @@ void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
 }
 
 template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
-  state->clear_history();
+void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
+  switch (this->param_.snapshot_format()) {
+    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+      SnapshotSolverStateToBinaryProto(model_filename);
+      break;
+    case caffe::SolverParameter_SnapshotFormat_HDF5:
+      SnapshotSolverStateToHDF5(model_filename);
+      break;
+    default:
+      LOG(FATAL) << "Unsupported snapshot format.";
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
+    const string& model_filename) {
+  SolverState state;
+  state.set_iter(this->iter_);
+  state.set_learned_net(model_filename);
+  state.set_current_step(this->current_step_);
+  state.clear_history();
   for (int i = 0; i < history_.size(); ++i) {
     // Add history
-    BlobProto* history_blob = state->add_history();
+    BlobProto* history_blob = state.add_history();
     history_[i]->ToProto(history_blob);
   }
+  string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
+  LOG(INFO)
+    << "Snapshotting solver state to binary proto file" << snapshot_filename;
+  WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
+    const string& model_filename) {
+  string snapshot_filename =
+      Solver<Dtype>::SnapshotFilename(".solverstate.h5");
+  LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
+  hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
+      H5P_DEFAULT, H5P_DEFAULT);
+  CHECK_GE(file_hid, 0)
+      << "Couldn't open " << snapshot_filename << " to save solver state.";
+  hdf5_save_int(file_hid, "iter", this->iter_);
+  hdf5_save_string(file_hid, "learned_net", model_filename);
+  hdf5_save_int(file_hid, "current_step", this->current_step_);
+  hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
+      H5P_DEFAULT);
+  CHECK_GE(history_hid, 0)
+      << "Error saving solver state to " << snapshot_filename << ".";
+  for (int i = 0; i < history_.size(); ++i) {
+    ostringstream oss;
+    oss << i;
+    hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
+  }
+  H5Gclose(history_hid);
+  H5Fclose(file_hid);
 }
 
 template <typename Dtype>
-void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
+void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
+    const string& state_file) {
+  SolverState state;
+  ReadProtoFromBinaryFile(state_file, &state);
+  this->iter_ = state.iter();
+  if (state.has_learned_net()) {
+    NetParameter net_param;
+    ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
+    this->net_->CopyTrainedLayersFrom(net_param);
+  }
+  this->current_step_ = state.current_step();
   CHECK_EQ(state.history_size(), history_.size())
       << "Incorrect length of history blobs.";
   LOG(INFO) << "SGDSolver: restoring history";
@@ -638,6 +717,31 @@ void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
 }
 
 template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+  hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
+  CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
+  this->iter_ = hdf5_load_int(file_hid, "iter");
+  if (H5LTfind_dataset(file_hid, "learned_net")) {
+    string learned_net = hdf5_load_string(file_hid, "learned_net");
+    this->net_->CopyTrainedLayersFrom(learned_net);
+  }
+  this->current_step_ = hdf5_load_int(file_hid, "current_step");
+  hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
+  CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
+  int state_history_size = hdf5_get_num_links(history_hid);
+  CHECK_EQ(state_history_size, history_.size())
+      << "Incorrect length of history blobs.";
+  for (int i = 0; i < history_.size(); ++i) {
+    ostringstream oss;
+    oss << i;
+    hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
+                                kMaxBlobAxes, history_[i].get());
+  }
+  H5Gclose(history_hid);
+  H5Fclose(file_hid);
+}
+
+template <typename Dtype>
 void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
   const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   const vector<float>& net_params_lr = this->net_->params_lr();
index a23034f..b56277b 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/hdf5.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/vision_layers.hpp"
 
diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp
new file mode 100644 (file)
index 0000000..d0d05f7
--- /dev/null
@@ -0,0 +1,160 @@
+#include "caffe/util/hdf5.hpp"
+
+#include <string>
+#include <vector>
+
+namespace caffe {
+
+// Verifies format of data stored in HDF5 file and reshapes blob accordingly.
+template <typename Dtype>
+void hdf5_load_nd_dataset_helper(
+    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+    Blob<Dtype>* blob) {
+  // Verify that the dataset exists.
+  CHECK(H5LTfind_dataset(file_id, dataset_name_))
+      << "Failed to find HDF5 dataset " << dataset_name_;
+  // Verify that the number of dimensions is in the accepted range.
+  herr_t status;
+  int ndims;
+  status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
+  CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
+  CHECK_GE(ndims, min_dim);
+  CHECK_LE(ndims, max_dim);
+
+  // Verify that the data format is what we expect: float or double.
+  std::vector<hsize_t> dims(ndims);
+  H5T_class_t class_;
+  status = H5LTget_dataset_info(
+      file_id, dataset_name_, dims.data(), &class_, NULL);
+  CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
+  CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
+
+  vector<int> blob_dims(dims.size());
+  for (int i = 0; i < dims.size(); ++i) {
+    blob_dims[i] = dims[i];
+  }
+  blob->Reshape(blob_dims);
+}
+
+template <>
+void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
+        int min_dim, int max_dim, Blob<float>* blob) {
+  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+  herr_t status = H5LTread_dataset_float(
+    file_id, dataset_name_, blob->mutable_cpu_data());
+  CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
+}
+
+template <>
+void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
+        int min_dim, int max_dim, Blob<double>* blob) {
+  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+  herr_t status = H5LTread_dataset_double(
+    file_id, dataset_name_, blob->mutable_cpu_data());
+  CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
+}
+
+template <>
+void hdf5_save_nd_dataset<float>(
+    const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
+    bool write_diff) {
+  int num_axes = blob.num_axes();
+  hsize_t *dims = new hsize_t[num_axes];
+  for (int i = 0; i < num_axes; ++i) {
+    dims[i] = blob.shape(i);
+  }
+  const float* data;
+  if (write_diff) {
+    data = blob.cpu_diff();
+  } else {
+    data = blob.cpu_data();
+  }
+  herr_t status = H5LTmake_dataset_float(
+      file_id, dataset_name.c_str(), num_axes, dims, data);
+  CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
+  delete[] dims;
+}
+
+template <>
+void hdf5_save_nd_dataset<double>(
+    hid_t file_id, const string& dataset_name, const Blob<double>& blob,
+    bool write_diff) {
+  int num_axes = blob.num_axes();
+  hsize_t *dims = new hsize_t[num_axes];
+  for (int i = 0; i < num_axes; ++i) {
+    dims[i] = blob.shape(i);
+  }
+  const double* data;
+  if (write_diff) {
+    data = blob.cpu_diff();
+  } else {
+    data = blob.cpu_data();
+  }
+  herr_t status = H5LTmake_dataset_double(
+      file_id, dataset_name.c_str(), num_axes, dims, data);
+  CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
+  delete[] dims;
+}
+
+string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
+  // Get size of dataset
+  size_t size;
+  H5T_class_t class_;
+  herr_t status = \
+    H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
+  CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
+  char *buf = new char[size];
+  status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
+  CHECK_GE(status, 0)
+    << "Failed to load int dataset with name " << dataset_name;
+  string val(buf);
+  delete[] buf;
+  return val;
+}
+
+void hdf5_save_string(hid_t loc_id, const string& dataset_name,
+                      const string& s) {
+  herr_t status = \
+    H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
+  CHECK_GE(status, 0)
+    << "Failed to save string dataset with name " << dataset_name;
+}
+
+int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
+  int val;
+  herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
+  CHECK_GE(status, 0)
+    << "Failed to load int dataset with name " << dataset_name;
+  return val;
+}
+
+void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
+  hsize_t one = 1;
+  herr_t status = \
+    H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
+  CHECK_GE(status, 0)
+    << "Failed to save int dataset with name " << dataset_name;
+}
+
+int hdf5_get_num_links(hid_t loc_id) {
+  H5G_info_t info;
+  herr_t status = H5Gget_info(loc_id, &info);
+  CHECK_GE(status, 0) << "Error while counting HDF5 links.";
+  return info.nlinks;
+}
+
+string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
+  ssize_t str_size = H5Lget_name_by_idx(
+      loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
+  CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
+  char *c_str = new char[str_size+1];
+  ssize_t status = H5Lget_name_by_idx(
+      loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
+      H5P_DEFAULT);
+  CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
+  string result(c_str);
+  delete[] c_str;
+  return result;
+}
+
+}  // namespace caffe
index 77ef7f2..6f03314 100644 (file)
@@ -228,79 +228,5 @@ void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
   datum->set_data(buffer);
 }
 
-// Verifies format of data stored in HDF5 file and reshapes blob accordingly.
-template <typename Dtype>
-void hdf5_load_nd_dataset_helper(
-    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
-    Blob<Dtype>* blob) {
-  // Verify that the dataset exists.
-  CHECK(H5LTfind_dataset(file_id, dataset_name_))
-      << "Failed to find HDF5 dataset " << dataset_name_;
-  // Verify that the number of dimensions is in the accepted range.
-  herr_t status;
-  int ndims;
-  status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
-  CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
-  CHECK_GE(ndims, min_dim);
-  CHECK_LE(ndims, max_dim);
-
-  // Verify that the data format is what we expect: float or double.
-  std::vector<hsize_t> dims(ndims);
-  H5T_class_t class_;
-  status = H5LTget_dataset_info(
-      file_id, dataset_name_, dims.data(), &class_, NULL);
-  CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
-  CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
-
-  vector<int> blob_dims(dims.size());
-  for (int i = 0; i < dims.size(); ++i) {
-    blob_dims[i] = dims[i];
-  }
-  blob->Reshape(blob_dims);
-}
-
-template <>
-void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
-        int min_dim, int max_dim, Blob<float>* blob) {
-  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
-  herr_t status = H5LTread_dataset_float(
-    file_id, dataset_name_, blob->mutable_cpu_data());
-  CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
-}
-
-template <>
-void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
-        int min_dim, int max_dim, Blob<double>* blob) {
-  hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
-  herr_t status = H5LTread_dataset_double(
-    file_id, dataset_name_, blob->mutable_cpu_data());
-  CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
-}
-
-template <>
-void hdf5_save_nd_dataset<float>(
-    const hid_t file_id, const string& dataset_name, const Blob<float>& blob) {
-  hsize_t dims[HDF5_NUM_DIMS];
-  dims[0] = blob.num();
-  dims[1] = blob.channels();
-  dims[2] = blob.height();
-  dims[3] = blob.width();
-  herr_t status = H5LTmake_dataset_float(
-      file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
-  CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
-}
-
-template <>
-void hdf5_save_nd_dataset<double>(
-    const hid_t file_id, const string& dataset_name, const Blob<double>& blob) {
-  hsize_t dims[HDF5_NUM_DIMS];
-  dims[0] = blob.num();
-  dims[1] = blob.channels();
-  dims[2] = blob.height();
-  dims[3] = blob.width();
-  herr_t status = H5LTmake_dataset_double(
-      file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
-  CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
-}
 
 }  // namespace caffe