f.write("""net: '""" + net_f + """'
test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
- display: 100 max_iter: 100 snapshot_after_train: false""")
+ display: 100 max_iter: 100 snapshot_after_train: false
+ snapshot_prefix: "model" """)
f.close()
self.solver = caffe.SGDSolver(f.name)
# also make sure get_solver runs
total += p.data.sum() + p.diff.sum()
for bl in six.itervalues(net.blobs):
total += bl.data.sum() + bl.diff.sum()
+
+ def test_snapshot(self):
+ self.solver.snapshot()
+ # Check that these files exist and then remove them
+ files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
+ for fn in files:
+ assert os.path.isfile(fn)
+ os.remove(fn)