Add pycaffe test for solver.snapshot()
authorGustav Larsson <gustav.m.larsson@gmail.com>
Tue, 6 Oct 2015 02:55:00 +0000 (21:55 -0500)
committerGustav Larsson <gustav.m.larsson@gmail.com>
Tue, 6 Oct 2015 03:41:01 +0000 (22:41 -0500)
python/caffe/test/test_solver.py

index 9cfc10d..f618fde 100644 (file)
@@ -16,7 +16,8 @@ class TestSolver(unittest.TestCase):
         f.write("""net: '""" + net_f + """'
         test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
         weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
-        display: 100 max_iter: 100 snapshot_after_train: false""")
+        display: 100 max_iter: 100 snapshot_after_train: false
+        snapshot_prefix: "model" """)
         f.close()
         self.solver = caffe.SGDSolver(f.name)
         # also make sure get_solver runs
@@ -51,3 +52,11 @@ class TestSolver(unittest.TestCase):
                     total += p.data.sum() + p.diff.sum()
             for bl in six.itervalues(net.blobs):
                 total += bl.data.sum() + bl.diff.sum()
+
+    def test_snapshot(self):
+        self.solver.snapshot()
+        # Check that these files exist and then remove them
+        files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
+        for fn in files:
+            assert os.path.isfile(fn)
+            os.remove(fn)