pycaffe: allow 1d labels to be passed to set_input_arrays
authorJonathan L Long <jonlong@cs.berkeley.edu>
Thu, 17 Apr 2014 10:06:20 +0000 (03:06 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Fri, 2 May 2014 20:27:07 +0000 (13:27 -0700)
python/caffe/_caffe.cpp
python/caffe/pycaffe.py

index f5d0d22..08899d4 100644 (file)
@@ -384,21 +384,22 @@ class CaffeSGDSolver {
 
 // The boost python module definition.
 BOOST_PYTHON_MODULE(_caffe) {
+  // below, we prepend an underscore to methods that will be replaced
+  //  in Python
   boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
       "Net", boost::python::init<string, string>())
       .def(boost::python::init<string>())
-      .def("Forward",          &CaffeNet::Forward)
-      .def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
-      .def("Backward",         &CaffeNet::Backward)
-      .def("set_mode_cpu",     &CaffeNet::set_mode_cpu)
-      .def("set_mode_gpu",     &CaffeNet::set_mode_gpu)
-      .def("set_phase_train",  &CaffeNet::set_phase_train)
-      .def("set_phase_test",   &CaffeNet::set_phase_test)
-      .def("set_device",       &CaffeNet::set_device)
-      // rename blobs here since the pycaffe.py wrapper will replace it
-      .add_property("_blobs",  &CaffeNet::blobs)
-      .add_property("layers",  &CaffeNet::layers)
-      .def("set_input_arrays", &CaffeNet::set_input_arrays);
+      .def("Forward",           &CaffeNet::Forward)
+      .def("ForwardPrefilled",  &CaffeNet::ForwardPrefilled)
+      .def("Backward",          &CaffeNet::Backward)
+      .def("set_mode_cpu",      &CaffeNet::set_mode_cpu)
+      .def("set_mode_gpu",      &CaffeNet::set_mode_gpu)
+      .def("set_phase_train",   &CaffeNet::set_phase_train)
+      .def("set_phase_test",    &CaffeNet::set_phase_test)
+      .def("set_device",        &CaffeNet::set_device)
+      .add_property("_blobs",   &CaffeNet::blobs)
+      .add_property("layers",   &CaffeNet::layers)
+      .def("_set_input_arrays", &CaffeNet::set_input_arrays);
 
   boost::python::class_<CaffeBlob, CaffeBlobWrap>(
       "Blob", boost::python::no_init)
index 863d315..05187e9 100644 (file)
@@ -3,8 +3,10 @@ Wrap the internal caffe C++ module (_caffe.so) with a clean, Pythonic
 interface.
 """
 
-from ._caffe import Net, SGDSolver
 from collections import OrderedDict
+import numpy as np
+
+from ._caffe import Net, SGDSolver
 
 # we directly update methods from Net here (rather than using composition or
 # inheritance) so that nets created by caffe (e.g., by SGDSolver) will
@@ -31,3 +33,11 @@ def _Net_params(self):
                                             if len(lr.blobs) > 0])
 
 Net.params = _Net_params
+
+def _Net_set_input_arrays(self, data, labels):
+    if labels.ndim == 1:
+        labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis,
+                                             np.newaxis])
+    return self._set_input_arrays(data, labels)
+
+Net.set_input_arrays = _Net_set_input_arrays