[pytest] check that Python receives (correct) exceptions from Python layers
authorJonathan L Long <jonlong@cs.berkeley.edu>
Fri, 15 May 2015 05:20:58 +0000 (22:20 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Fri, 15 May 2015 05:20:58 +0000 (22:20 -0700)
python/caffe/test/test_python_layer.py

index 6fba491..46c6e88 100644 (file)
@@ -21,6 +21,13 @@ class SimpleLayer(caffe.Layer):
         bottom[0].diff[...] = 10 * top[0].diff
 
 
+class ExceptionLayer(caffe.Layer):
+    """A layer for checking exceptions from Python"""
+
+    def setup(self, bottom, top):
+        raise RuntimeError
+
+
 def python_net_file():
     with tempfile.NamedTemporaryFile(delete=False) as f:
         f.write("""name: 'pythonnet' force_backward: true
@@ -34,6 +41,16 @@ def python_net_file():
         return f.name
 
 
+def exception_net_file():
+    with tempfile.NamedTemporaryFile(delete=False) as f:
+        f.write("""name: 'pythonnet' force_backward: true
+        input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
+        layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
+          python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
+          """)
+        return f.name
+
+
 class TestPythonLayer(unittest.TestCase):
     def setUp(self):
         net_file = python_net_file()
@@ -61,3 +78,8 @@ class TestPythonLayer(unittest.TestCase):
         for blob in self.net.blobs.itervalues():
             for d in blob.data.shape:
                 self.assertEqual(s, d)
+
+    def test_exception(self):
+        net_file = exception_net_file()
+        self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
+        os.remove(net_file)