[pycaffe] re-expose Blob
authorJonathan L Long <jonlong@cs.berkeley.edu>
Tue, 6 Jan 2015 05:02:04 +0000 (21:02 -0800)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Tue, 17 Feb 2015 06:46:13 +0000 (22:46 -0800)
python/caffe/_caffe.cpp

index 5fcd853..c545a74 100644 (file)
@@ -28,6 +28,7 @@ namespace caffe {
 
 // For Python, for now, we'll just always use float as the type.
 typedef float Dtype;
+const int NPY_DTYPE = NPY_FLOAT32;
 
 void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
 void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
@@ -118,6 +119,44 @@ void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
       PyArray_DIMS(data_arr)[0]);
 }
 
+struct NdarrayConverterGenerator {
+  template <typename T> struct apply;
+};
+
+template <>
+struct NdarrayConverterGenerator::apply<Dtype*> {
+  struct type {
+    PyObject* operator() (Dtype* data) const {
+      // Just store the data pointer, and add the shape information in postcall.
+      return PyArray_SimpleNewFromData(0, NULL, NPY_DTYPE, data);
+    }
+    const PyTypeObject* get_pytype() {
+      return &PyArray_Type;
+    }
+  };
+};
+
+struct NdarrayCallPolicies : public bp::default_call_policies {
+  typedef NdarrayConverterGenerator result_converter;
+  PyObject* postcall(PyObject* pyargs, PyObject* result) {
+    bp::object pyblob = bp::extract<bp::tuple>(pyargs)()[0];
+    shared_ptr<Blob<Dtype> > blob =
+      bp::extract<shared_ptr<Blob<Dtype> > >(pyblob);
+    // Free the temporary pointer-holding array, and construct a new one with
+    // the shape information from the blob.
+    void* data = PyArray_DATA(reinterpret_cast<PyArrayObject*>(result));
+    Py_DECREF(result);
+    npy_intp dims[] = {blob->num(), blob->channels(),
+                       blob->height(), blob->width()};
+    PyObject* arr_obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, data);
+    // SetBaseObject steals a ref, so we need to INCREF.
+    Py_INCREF(pyblob.ptr());
+    PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(arr_obj),
+        pyblob.ptr());
+    return arr_obj;
+  }
+};
+
 BOOST_PYTHON_MODULE(_caffe) {
   // below, we prepend an underscore to methods that will be replaced
   // in Python
@@ -155,6 +194,19 @@ BOOST_PYTHON_MODULE(_caffe) {
         bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
     .def("save", &Net_Save);
 
+  bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
+    "Blob", bp::no_init)
+    .add_property("num",      &Blob<Dtype>::num)
+    .add_property("channels", &Blob<Dtype>::channels)
+    .add_property("height",   &Blob<Dtype>::height)
+    .add_property("width",    &Blob<Dtype>::width)
+    .add_property("count",    &Blob<Dtype>::count)
+    .def("reshape",           &Blob<Dtype>::Reshape)
+    .add_property("data",     bp::make_function(&Blob<Dtype>::mutable_cpu_data,
+          NdarrayCallPolicies()))
+    .add_property("diff",     bp::make_function(&Blob<Dtype>::mutable_cpu_diff,
+          NdarrayCallPolicies()));
+
   // vector wrappers for all the vector types we use
   bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
     .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());