From 19d9927d76d6655a3efc090611e59aa2ea0f25a5 Mon Sep 17 00:00:00 2001 From: Gustav Larsson Date: Mon, 5 Oct 2015 21:55:00 -0500 Subject: [PATCH] Add pycaffe test for solver.snapshot() --- python/caffe/test/test_solver.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py index 9cfc10d..f618fde 100644 --- a/python/caffe/test/test_solver.py +++ b/python/caffe/test/test_solver.py @@ -16,7 +16,8 @@ class TestSolver(unittest.TestCase): f.write("""net: '""" + net_f + """' test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9 weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75 - display: 100 max_iter: 100 snapshot_after_train: false""") + display: 100 max_iter: 100 snapshot_after_train: false + snapshot_prefix: "model" """) f.close() self.solver = caffe.SGDSolver(f.name) # also make sure get_solver runs @@ -51,3 +52,11 @@ class TestSolver(unittest.TestCase): total += p.data.sum() + p.diff.sum() for bl in six.itervalues(net.blobs): total += bl.data.sum() + bl.diff.sum() + + def test_snapshot(self): + self.solver.snapshot() + # Check that these files exist and then remove them + files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate'] + for fn in files: + assert os.path.isfile(fn) + os.remove(fn) -- 2.7.4