Add test for attribute "phase" in python layer
authorZhouYzzz <zhouyz9608@gmail.com>
Wed, 4 May 2016 10:00:12 +0000 (18:00 +0800)
committerZhouYzzz <zhouyz9608@gmail.com>
Wed, 4 May 2016 10:00:12 +0000 (18:00 +0800)
python/caffe/test/test_python_layer.py

index e46b711..899514e 100644 (file)
@@ -44,6 +44,18 @@ class ParameterLayer(caffe.Layer):
     def backward(self, top, propagate_down, bottom):
         self.blobs[0].diff[0] = 1
 
+class PhaseLayer(caffe.Layer):
+    """A layer for checking attribute `phase`"""
+
+    def setup(self, bottom, top):
+        pass
+
+    def reshape(self, bootom, top):
+        top[0].reshape()
+
+    def forward(self, bottom, top):
+        top[0].data[()] = self.phase
+
 def python_net_file():
     with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
         f.write("""name: 'pythonnet' force_backward: true
@@ -76,6 +88,14 @@ def parameter_net_file():
           """)
         return f.name
 
+def phase_net_file():
+    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
+        f.write("""name: 'pythonnet' force_backward: true
+        layer { type: 'Python' name: 'layer' top: 'phase'
+          python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
+          """)
+        return f.name
+
 
 @unittest.skipIf('Python' not in caffe.layer_type_list(),
     'Caffe built without Python layer support')
@@ -140,3 +160,9 @@ class TestPythonLayer(unittest.TestCase):
         self.assertEqual(layer.blobs[0].data[0], 1)
 
         os.remove(net_file)
+
+    def test_phase(self):
+        net_file = phase_net_file()
+        for phase in caffe.TRAIN, caffe.TEST:
+            net = caffe.Net(net_file, phase)
+            self.assertEqual(net.forward()['phase'], phase)