bottom[0].diff[...] = 10 * top[0].diff
+ class ExceptionLayer(caffe.Layer):
+ """A layer for checking exceptions from Python"""
+
+ def setup(self, bottom, top):
+ raise RuntimeError
+
+
def python_net_file():
- with tempfile.NamedTemporaryFile(delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
layer { type: 'Python' name: 'one' bottom: 'data' top: 'one'
s = 4
self.net.blobs['data'].reshape(s, s, s, s)
self.net.forward()
- for blob in self.net.blobs.itervalues():
+ for blob in six.itervalues(self.net.blobs):
for d in blob.data.shape:
self.assertEqual(s, d)
+
+ def test_exception(self):
+ net_file = exception_net_file()
+ self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
+ os.remove(net_file)