solver: added snapshotting ability
authorYangqing Jia <jiayq84@gmail.com>
Tue, 15 Oct 2013 20:46:34 +0000 (13:46 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Tue, 15 Oct 2013 20:46:34 +0000 (13:46 -0700)
src/caffe/optimization/solver.cpp
src/caffe/optimization/solver.hpp
src/caffe/proto/caffe.proto
src/caffe/pyutil/drawnet.py

index 73c69c0..0c68330 100644 (file)
@@ -38,8 +38,8 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
     net_->Update();
 
     // Check if we need to do snapshot
-    if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
-      Snapshot(false);
+    if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
+      Snapshot();
     }
     if (param_.display() && iter_ % param_.display() == 0) {
       LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
@@ -50,18 +50,14 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
 
 
 template <typename Dtype>
-void Solver<Dtype>::Snapshot(bool is_final) {
+void Solver<Dtype>::Snapshot() {
   NetParameter net_param;
   // For intermediate results, we will also dump the gradient values.
-  net_->ToProto(&net_param, !is_final);
+  net_->ToProto(&net_param, param_.snapshot_diff());
   string filename(param_.snapshot_prefix());
-  if (is_final) {
-    filename += "_final";
-  } else {
-    char iter_str_buffer[20];
-    sprintf(iter_str_buffer, "_iter_%d", iter_);
-    filename += iter_str_buffer;
-  }
+  char iter_str_buffer[20];
+  sprintf(iter_str_buffer, "_iter_%d", iter_);
+  filename += iter_str_buffer;
   LOG(INFO) << "Snapshotting to " << filename;
   WriteProtoToBinaryFile(net_param, filename.c_str());
   SolverState state;
index a5ea612..98c872d 100644 (file)
@@ -27,7 +27,7 @@ class Solver {
   // that stores the learned net. You should implement the SnapshotSolverState()
   // function that produces a SolverState protocol buffer that needs to be
   // written to disk together with the learned net.
-  void Snapshot(bool is_final = false);
+  void Snapshot();
   virtual void SnapshotSolverState(SolverState* state) = 0;
   // The Restore function implements how one should restore the solver to a
   // previously snapshotted state. You should implement the RestoreSolverState()
index 4be9696..8eb39b3 100644 (file)
@@ -105,7 +105,9 @@ message SolverParameter {
   optional float stepsize = 12; // the stepsize for learning rate policy "step"
 
   optional string snapshot_prefix = 13; // The prefix for the snapshot.
-
+  // whether to snapshot diff in the results or not. Snapshotting diff will help
+  // debugging but the final protocol buffer size will be much larger.
+  optional bool snapshot_diff = 14 [ default = false];
   // Adagrad solver parameters
   // For Adagrad, we will first run normal sgd using the sgd parameters above
   // for adagrad_skip iterations, and then kick in the adagrad algorithm, with
@@ -114,8 +116,8 @@ message SolverParameter {
   // the layer parameter specifications, as it will adjust the learning rate
   // of individual parameters in a data-dependent way.
   //    WORK IN PROGRESS: not actually implemented yet.
-  optional float adagrad_gamma = 14; // adagrad learning rate multiplier
-  optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
+  optional float adagrad_gamma = 15; // adagrad learning rate multiplier
+  optional float adagrad_skip = 16; // the steps to skip before adagrad kicks in
 }
 
 // A message that stores the solver snapshots
index f958c90..bce3dc4 100644 (file)
@@ -11,14 +11,7 @@ NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90',
 BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
         'style': 'filled'}
 
-def draw_net(caffe_net, ext='png'):
-  """Draws a caffe net and returns the image string encoded using the given
-  extension.
-
-  Input:
-    caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
-    ext: the image extension. Default 'png'.
-  """
+def get_pydot_graph(caffe_net):
   pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
   pydot_nodes = {}
   pydot_edges = []
@@ -47,11 +40,22 @@ def draw_net(caffe_net, ext='png'):
   for edge in pydot_edges:
     pydot_graph.add_edge(
         pydot.Edge(pydot_nodes[edge[0]], pydot_nodes[edge[1]]))
-  return pydot_graph.create(format=ext)
+  return pydot_graph
+
+def draw_net(caffe_net, ext='png'):
+  """Draws a caffe net and returns the image string encoded using the given
+  extension.
+
+  Input:
+    caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
+    ext: the image extension. Default 'png'.
+  """
+  return get_pydot_graph(caffe_net).create(format=ext)
 
 def draw_net_to_file(caffe_net, filename):
   """Draws a caffe net, and saves it to file using the format given as the
-  file extension.
+  file extension. Use '.raw' to output raw text that you can manually feed
+  to graphviz to draw graphs.
   """
   ext = filename[filename.rfind('.')+1:]
   with open(filename, 'w') as fid: