From c2dba923b82c669f2998a3174310fbbb5c64c39f Mon Sep 17 00:00:00 2001 From: ZhouYzzz Date: Wed, 4 May 2016 18:00:12 +0800 Subject: [PATCH] Add test for attribute "phase" in python layer --- python/caffe/test/test_python_layer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/caffe/test/test_python_layer.py b/python/caffe/test/test_python_layer.py index e46b711..899514e 100644 --- a/python/caffe/test/test_python_layer.py +++ b/python/caffe/test/test_python_layer.py @@ -44,6 +44,18 @@ class ParameterLayer(caffe.Layer): def backward(self, top, propagate_down, bottom): self.blobs[0].diff[0] = 1 +class PhaseLayer(caffe.Layer): + """A layer for checking attribute `phase`""" + + def setup(self, bottom, top): + pass + + def reshape(self, bootom, top): + top[0].reshape() + + def forward(self, bottom, top): + top[0].data[()] = self.phase + def python_net_file(): with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f: f.write("""name: 'pythonnet' force_backward: true @@ -76,6 +88,14 @@ def parameter_net_file(): """) return f.name +def phase_net_file(): + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f: + f.write("""name: 'pythonnet' force_backward: true + layer { type: 'Python' name: 'layer' top: 'phase' + python_param { module: 'test_python_layer' layer: 'PhaseLayer' } } + """) + return f.name + @unittest.skipIf('Python' not in caffe.layer_type_list(), 'Caffe built without Python layer support') @@ -140,3 +160,9 @@ class TestPythonLayer(unittest.TestCase): self.assertEqual(layer.blobs[0].data[0], 1) os.remove(net_file) + + def test_phase(self): + net_file = phase_net_file() + for phase in caffe.TRAIN, caffe.TEST: + net = caffe.Net(net_file, phase) + self.assertEqual(net.forward()['phase'], phase) -- 2.7.4