[pycaffe] basic, partial testing of Net and SGDSolver
authorJonathan L Long <jonlong@cs.berkeley.edu>
Tue, 25 Nov 2014 02:53:14 +0000 (18:53 -0800)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Wed, 7 Jan 2015 04:39:14 +0000 (20:39 -0800)
python/caffe/test/test_net.py [new file with mode: 0644]
python/caffe/test/test_solver.py [new file with mode: 0644]

diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
new file mode 100644 (file)
index 0000000..f0e9dee
--- /dev/null
@@ -0,0 +1,77 @@
+import unittest
+import tempfile
+import os
+import numpy as np
+
+import caffe
+
+def simple_net_file(num_output):
+    """Make a simple net prototxt, based on test_net.cpp, returning the name
+    of the (temporary) file."""
+
+    f = tempfile.NamedTemporaryFile(delete=False)
+    f.write("""name: 'testnet' force_backward: true
+    layers { type: DUMMY_DATA name: 'data' top: 'data' top: 'label'
+      dummy_data_param { num: 5 channels: 2 height: 3 width: 4
+        num: 5 channels: 1 height: 1 width: 1
+        data_filler { type: 'gaussian' std: 1 }
+        data_filler { type: 'constant' } } }
+    layers { type: CONVOLUTION name: 'conv' bottom: 'data' top: 'conv'
+      convolution_param { num_output: 11 kernel_size: 2 pad: 3
+        weight_filler { type: 'gaussian' std: 1 }
+        bias_filler { type: 'constant' value: 2 } }
+        weight_decay: 1 weight_decay: 0 }
+    layers { type: INNER_PRODUCT name: 'ip' bottom: 'conv' top: 'ip'
+      inner_product_param { num_output: """ + str(num_output) + """
+        weight_filler { type: 'gaussian' std: 2.5 }
+        bias_filler { type: 'constant' value: -3 } } }
+    layers { type: SOFTMAX_LOSS name: 'loss' bottom: 'ip' bottom: 'label'
+      top: 'loss' }""")
+    f.close()
+    return f.name
+
+class TestNet(unittest.TestCase):
+    def setUp(self):
+        self.num_output = 13
+        net_file = simple_net_file(self.num_output)
+        self.net = caffe.Net(net_file)
+        # fill in valid labels
+        self.net.blobs['label'].data[...] = \
+                np.random.randint(self.num_output,
+                    size=self.net.blobs['label'].data.shape)
+        os.remove(net_file)
+
+    def test_memory(self):
+        """Check that holding onto blob data beyond the life of a Net is OK"""
+
+        params = sum(map(list, self.net.params.itervalues()), [])
+        blobs = self.net.blobs.values()
+        del self.net
+
+        # now sum everything (forcing all memory to be read)
+        total = 0
+        for p in params:
+            total += p.data.sum() + p.diff.sum()
+        for bl in blobs:
+            total += bl.data.sum() + bl.diff.sum()
+
+    def test_forward_backward(self):
+        self.net.forward()
+        self.net.backward()
+
+    def test_inputs_outputs(self):
+        self.assertEqual(self.net.inputs, [])
+        self.assertEqual(self.net.outputs, ['loss'])
+
+    def test_save_and_read(self):
+        f = tempfile.NamedTemporaryFile(delete=False)
+        f.close()
+        self.net.save(f.name)
+        net_file = simple_net_file(self.num_output)
+        net2 = caffe.Net(net_file, f.name)
+        os.remove(net_file)
+        os.remove(f.name)
+        for name in self.net.params:
+            for i in range(len(self.net.params[name])):
+                self.assertEqual(abs(self.net.params[name][i].data
+                    - net2.params[name][i].data).sum(), 0)
diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py
new file mode 100644 (file)
index 0000000..b78c91f
--- /dev/null
@@ -0,0 +1,49 @@
+import unittest
+import tempfile
+import os
+import numpy as np
+
+import caffe
+from test_net import simple_net_file
+
+class TestSolver(unittest.TestCase):
+    def setUp(self):
+        self.num_output = 13
+        net_f = simple_net_file(self.num_output)
+        f = tempfile.NamedTemporaryFile(delete=False)
+        f.write("""net: '""" + net_f + """'
+        test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
+        weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
+        display: 100 max_iter: 100 snapshot_after_train: false""")
+        f.close()
+        self.solver = caffe.SGDSolver(f.name)
+        self.solver.net.set_mode_cpu()
+        # fill in valid labels
+        self.solver.net.blobs['label'].data[...] = \
+                np.random.randint(self.num_output,
+                    size=self.solver.net.blobs['label'].data.shape)
+        self.solver.test_nets[0].blobs['label'].data[...] = \
+                np.random.randint(self.num_output,
+                    size=self.solver.test_nets[0].blobs['label'].data.shape)
+        os.remove(f.name)
+        os.remove(net_f)
+
+    def test_solve(self):
+        self.assertEqual(self.solver.iter, 0)
+        self.solver.solve()
+        self.assertEqual(self.solver.iter, 100)
+
+    def test_net_memory(self):
+        """Check that nets survive after the solver is destroyed."""
+
+        nets = [self.solver.net] + list(self.solver.test_nets)
+        self.assertEqual(len(nets), 2)
+        del self.solver
+
+        total = 0
+        for net in nets:
+            for ps in net.params.itervalues():
+                for p in ps:
+                    total += p.data.sum() + p.diff.sum()
+            for bl in net.blobs.itervalues():
+                total += bl.data.sum() + bl.diff.sum()