[pycaffe] give phase to Net
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 26 Jan 2015 04:57:23 +0000 (20:57 -0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 17 Feb 2015 19:35:51 +0000 (11:35 -0800)
- expose `caffe.{TRAIN,TEST}` constants
- instantiate `caffe.Net`s with phase
- drop singleton phase interface `caffe.set_phase_{train,test}`

python/caffe/__init__.py
python/caffe/_caffe.cpp
python/caffe/test/test_net.py
python/caffe/test/test_python_layer.py

index 49f8678..3150287 100644 (file)
@@ -1,6 +1,7 @@
 from .pycaffe import Net, SGDSolver
 from ._caffe import set_mode_cpu, set_mode_gpu, set_device, \
-    set_phase_train, set_phase_test, Layer, get_solver
+    Layer, get_solver
+from .proto.caffe_pb2 import TRAIN, TEST
 from .classifier import Classifier
 from .detector import Detector
 import io
index 8662727..a5d0e64 100644 (file)
@@ -31,10 +31,9 @@ namespace caffe {
 typedef float Dtype;
 const int NPY_DTYPE = NPY_FLOAT32;
 
+// Selecting mode.
 void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
 void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
-void set_phase_train() { Caffe::set_phase(Caffe::TRAIN); }
-void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
 
 // For convenience, check that input files can be opened, and raise an
 // exception that boost will send to Python if not (caffe could still crash
@@ -71,13 +70,24 @@ void CheckContiguousArray(PyArrayObject* arr, string name,
   }
 }
 
-// Net construct-and-load convenience constructor
+// Net constructor for passing phase as int
 shared_ptr<Net<Dtype> > Net_Init(
-    string param_file, string pretrained_param_file) {
+    string param_file, int phase) {
+  CheckFile(param_file);
+
+  shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
+      static_cast<Phase>(phase)));
+  return net;
+}
+
+// Net construct-and-load convenience constructor
+shared_ptr<Net<Dtype> > Net_Init_Load(
+    string param_file, string pretrained_param_file, int phase) {
   CheckFile(param_file);
   CheckFile(pretrained_param_file);
 
-  shared_ptr<Net<Dtype> > net (new Net<Dtype>(param_file));
+  shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
+      static_cast<Phase>(phase)));
   net->CopyTrainedLayersFrom(pretrained_param_file);
   return net;
 }
@@ -172,13 +182,12 @@ BOOST_PYTHON_MODULE(_caffe) {
   // Caffe utility functions
   bp::def("set_mode_cpu", &set_mode_cpu);
   bp::def("set_mode_gpu", &set_mode_gpu);
-  bp::def("set_phase_train", &set_phase_train);
-  bp::def("set_phase_test", &set_phase_test);
   bp::def("set_device", &Caffe::SetDevice);
 
-  bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >(
-    "Net", bp::init<string>())
+  bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
+    bp::no_init)
     .def("__init__", bp::make_constructor(&Net_Init))
+    .def("__init__", bp::make_constructor(&Net_Init_Load))
     .def("_forward", &Net<Dtype>::ForwardFromTo)
     .def("_backward", &Net<Dtype>::BackwardFromTo)
     .def("reshape", &Net<Dtype>::Reshape)
index 9381c72..62b407d 100644 (file)
@@ -35,7 +35,7 @@ class TestNet(unittest.TestCase):
     def setUp(self):
         self.num_output = 13
         net_file = simple_net_file(self.num_output)
-        self.net = caffe.Net(net_file)
+        self.net = caffe.Net(net_file, caffe.TRAIN)
         # fill in valid labels
         self.net.blobs['label'].data[...] = \
                 np.random.randint(self.num_output,
@@ -69,7 +69,7 @@ class TestNet(unittest.TestCase):
         f.close()
         self.net.save(f.name)
         net_file = simple_net_file(self.num_output)
-        net2 = caffe.Net(net_file, f.name)
+        net2 = caffe.Net(net_file, f.name, caffe.TRAIN)
         os.remove(net_file)
         os.remove(f.name)
         for name in self.net.params:
index 03f5834..383c283 100644 (file)
@@ -36,7 +36,7 @@ def python_net_file():
 class TestPythonLayer(unittest.TestCase):
     def setUp(self):
         net_file = python_net_file()
-        self.net = caffe.Net(net_file)
+        self.net = caffe.Net(net_file, caffe.TRAIN)
         os.remove(net_file)
 
     def test_forward(self):