// The boost python module definition.
BOOST_PYTHON_MODULE(_caffe) {
+ // below, we prepend an underscore to methods that will be replaced
+ // in Python
boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
"Net", boost::python::init<string, string>())
.def(boost::python::init<string>())
- .def("Forward", &CaffeNet::Forward)
- .def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
- .def("Backward", &CaffeNet::Backward)
- .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
- .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
- .def("set_phase_train", &CaffeNet::set_phase_train)
- .def("set_phase_test", &CaffeNet::set_phase_test)
- .def("set_device", &CaffeNet::set_device)
- // rename blobs here since the pycaffe.py wrapper will replace it
- .add_property("_blobs", &CaffeNet::blobs)
- .add_property("layers", &CaffeNet::layers)
- .def("set_input_arrays", &CaffeNet::set_input_arrays);
+ .def("Forward", &CaffeNet::Forward)
+ .def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
+ .def("Backward", &CaffeNet::Backward)
+ .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
+ .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
+ .def("set_phase_train", &CaffeNet::set_phase_train)
+ .def("set_phase_test", &CaffeNet::set_phase_test)
+ .def("set_device", &CaffeNet::set_device)
+ .add_property("_blobs", &CaffeNet::blobs)
+ .add_property("layers", &CaffeNet::layers)
+ .def("_set_input_arrays", &CaffeNet::set_input_arrays);
boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"Blob", boost::python::no_init)
interface.
"""
-from ._caffe import Net, SGDSolver
from collections import OrderedDict
+import numpy as np
+
+from ._caffe import Net, SGDSolver
# we directly update methods from Net here (rather than using composition or
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
if len(lr.blobs) > 0])
Net.params = _Net_params
+
+def _Net_set_input_arrays(self, data, labels):
+ if labels.ndim == 1:
+ labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis,
+ np.newaxis])
+ return self._set_input_arrays(data, labels)
+
+Net.set_input_arrays = _Net_set_input_arrays