// For Python, for now, we'll just always use float as the type.
typedef float Dtype;
+const int NPY_DTYPE = NPY_FLOAT32;
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
PyArray_DIMS(data_arr)[0]);
}
+struct NdarrayConverterGenerator {
+ template <typename T> struct apply;
+};
+
+template <>
+struct NdarrayConverterGenerator::apply<Dtype*> {
+ struct type {
+ PyObject* operator() (Dtype* data) const {
+ // Just store the data pointer, and add the shape information in postcall.
+ return PyArray_SimpleNewFromData(0, NULL, NPY_DTYPE, data);
+ }
+ const PyTypeObject* get_pytype() {
+ return &PyArray_Type;
+ }
+ };
+};
+
+struct NdarrayCallPolicies : public bp::default_call_policies {
+ typedef NdarrayConverterGenerator result_converter;
+ PyObject* postcall(PyObject* pyargs, PyObject* result) {
+ bp::object pyblob = bp::extract<bp::tuple>(pyargs)()[0];
+ shared_ptr<Blob<Dtype> > blob =
+ bp::extract<shared_ptr<Blob<Dtype> > >(pyblob);
+ // Free the temporary pointer-holding array, and construct a new one with
+ // the shape information from the blob.
+ void* data = PyArray_DATA(reinterpret_cast<PyArrayObject*>(result));
+ Py_DECREF(result);
+ npy_intp dims[] = {blob->num(), blob->channels(),
+ blob->height(), blob->width()};
+ PyObject* arr_obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, data);
+ // SetBaseObject steals a ref, so we need to INCREF.
+ Py_INCREF(pyblob.ptr());
+ PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(arr_obj),
+ pyblob.ptr());
+ return arr_obj;
+ }
+};
+
BOOST_PYTHON_MODULE(_caffe) {
// below, we prepend an underscore to methods that will be replaced
// in Python
bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
.def("save", &Net_Save);
+ bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
+ "Blob", bp::no_init)
+ .add_property("num", &Blob<Dtype>::num)
+ .add_property("channels", &Blob<Dtype>::channels)
+ .add_property("height", &Blob<Dtype>::height)
+ .add_property("width", &Blob<Dtype>::width)
+ .add_property("count", &Blob<Dtype>::count)
+ .def("reshape", &Blob<Dtype>::Reshape)
+ .add_property("data", bp::make_function(&Blob<Dtype>::mutable_cpu_data,
+ NdarrayCallPolicies()))
+ .add_property("diff", bp::make_function(&Blob<Dtype>::mutable_cpu_diff,
+ NdarrayCallPolicies()));
+
// vector wrappers for all the vector types we use
bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
.def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());