def backward(self, top, propagate_down, bottom):
self.blobs[0].diff[0] = 1
+class PhaseLayer(caffe.Layer):
+ """A layer for checking attribute `phase`"""
+
+ def setup(self, bottom, top):
+ pass
+
+ def reshape(self, bootom, top):
+ top[0].reshape()
+
+ def forward(self, bottom, top):
+ top[0].data[()] = self.phase
+
def python_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
""")
return f.name
+def phase_net_file():
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
+ f.write("""name: 'pythonnet' force_backward: true
+ layer { type: 'Python' name: 'layer' top: 'phase'
+ python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
+ """)
+ return f.name
+
@unittest.skipIf('Python' not in caffe.layer_type_list(),
'Caffe built without Python layer support')
self.assertEqual(layer.blobs[0].data[0], 1)
os.remove(net_file)
+
+ def test_phase(self):
+ net_file = phase_net_file()
+ for phase in caffe.TRAIN, caffe.TEST:
+ net = caffe.Net(net_file, phase)
+ self.assertEqual(net.forward()['phase'], phase)