net_->Update();
// Check if we need to do snapshot
- if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
- Snapshot(false);
+ if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
+ Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
template <typename Dtype>
-void Solver<Dtype>::Snapshot(bool is_final) {
+void Solver<Dtype>::Snapshot() {
NetParameter net_param;
// For intermediate results, we will also dump the gradient values.
- net_->ToProto(&net_param, !is_final);
+ net_->ToProto(&net_param, param_.snapshot_diff());
string filename(param_.snapshot_prefix());
- if (is_final) {
- filename += "_final";
- } else {
- char iter_str_buffer[20];
- sprintf(iter_str_buffer, "_iter_%d", iter_);
- filename += iter_str_buffer;
- }
+ char iter_str_buffer[20];
+ sprintf(iter_str_buffer, "_iter_%d", iter_);
+ filename += iter_str_buffer;
LOG(INFO) << "Snapshotting to " << filename;
WriteProtoToBinaryFile(net_param, filename.c_str());
SolverState state;
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
- void Snapshot(bool is_final = false);
+ void Snapshot();
virtual void SnapshotSolverState(SolverState* state) = 0;
// The Restore function implements how one should restore the solver to a
// previously snapshotted state. You should implement the RestoreSolverState()
optional float stepsize = 12; // the stepsize for learning rate policy "step"
optional string snapshot_prefix = 13; // The prefix for the snapshot.
-
+ // whether to snapshot diff in the results or not. Snapshotting diff will help
+ // debugging but the final protocol buffer size will be much larger.
+ optional bool snapshot_diff = 14 [ default = false];
// Adagrad solver parameters
// For Adagrad, we will first run normal sgd using the sgd parameters above
// for adagrad_skip iterations, and then kick in the adagrad algorithm, with
// the layer parameter specifications, as it will adjust the learning rate
// of individual parameters in a data-dependent way.
// WORK IN PROGRESS: not actually implemented yet.
- optional float adagrad_gamma = 14; // adagrad learning rate multiplier
- optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
+ optional float adagrad_gamma = 15; // adagrad learning rate multiplier
+ optional float adagrad_skip = 16; // the steps to skip before adagrad kicks in
}
// A message that stores the solver snapshots
BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
'style': 'filled'}
-def draw_net(caffe_net, ext='png'):
- """Draws a caffe net and returns the image string encoded using the given
- extension.
-
- Input:
- caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
- ext: the image extension. Default 'png'.
- """
+def get_pydot_graph(caffe_net):
pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
pydot_nodes = {}
pydot_edges = []
for edge in pydot_edges:
pydot_graph.add_edge(
pydot.Edge(pydot_nodes[edge[0]], pydot_nodes[edge[1]]))
- return pydot_graph.create(format=ext)
+ return pydot_graph
+
+def draw_net(caffe_net, ext='png'):
+ """Draws a caffe net and returns the image string encoded using the given
+ extension.
+
+ Input:
+ caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
+ ext: the image extension. Default 'png'.
+ """
+ return get_pydot_graph(caffe_net).create(format=ext)
def draw_net_to_file(caffe_net, filename):
"""Draws a caffe net, and saves it to file using the format given as the
- file extension.
+ file extension. Use '.raw' to output raw text that you can manually feed
+ to graphviz to draw graphs.
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'w') as fid: