Add Python interface to layer blobs
authorJonathan L Long <jonlong@cs.berkeley.edu>
Sat, 7 Dec 2013 00:39:01 +0000 (16:39 -0800)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Tue, 14 Jan 2014 23:46:51 +0000 (15:46 -0800)
python/caffe/pycaffe.cpp

index d09ccf5..346e37e 100644 (file)
@@ -6,6 +6,7 @@
 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
 
 #include <boost/python.hpp>
+#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
 #include <numpy/arrayobject.h>
 #include "caffe/caffe.hpp"
 
@@ -20,6 +21,78 @@ using boost::python::extract;
 using boost::python::len;
 using boost::python::list;
 using boost::python::object;
+using boost::python::handle;
+using boost::python::vector_indexing_suite;
+
+
+// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
+//  to Python
+class CaffeBlob {
+ public:
+
+  CaffeBlob(const shared_ptr<Blob<float> > &blob)
+      : blob_(blob) {}
+
+  CaffeBlob()
+  {}
+
+  int num() const { return blob_->num(); }
+  int channels() const { return blob_->channels(); }
+  int height() const { return blob_->height(); }
+  int width() const { return blob_->width(); }
+  int count() const { return blob_->count(); }
+
+  bool operator == (const CaffeBlob &other)
+  {
+      return this->blob_ == other.blob_;
+  }
+
+ protected:
+  shared_ptr<Blob<float> > blob_;
+};
+
+
+// we need another wrapper (used as boost::python's HeldType) that receives a
+//  self PyObject * which we can use as ndarray.base, so that data/diff memory
+//  is not freed while still being used in Python
+class CaffeBlobWrap : public CaffeBlob {
+ public:
+  CaffeBlobWrap(PyObject *p, shared_ptr<Blob<float> > &blob)
+      : CaffeBlob(blob), self_(p) {}
+
+  CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
+      : CaffeBlob(blob), self_(p) {}
+
+  object get_data()
+  {
+      npy_intp dims[] = {num(), channels(), height(), width()};
+
+      PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
+                                                blob_->mutable_cpu_data());
+      PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
+      Py_INCREF(self_);
+      handle<> h(obj);
+
+      return object(h);
+  }
+
+  object get_diff()
+  {
+      npy_intp dims[] = {num(), channels(), height(), width()};
+
+      PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
+                                                blob_->mutable_cpu_diff());
+      PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
+      Py_INCREF(self_);
+      handle<> h(obj);
+
+      return object(h);
+  }
+
+ private:
+  PyObject *self_;
+};
+
 
 
 // A simple wrapper over CaffeNet that runs the forward process.
@@ -143,14 +216,22 @@ struct CaffeNet
   void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
   void set_device(int device_id) { Caffe::SetDevice(device_id); }
 
+  vector<CaffeBlob> blobs() {
+      return vector<CaffeBlob>(net_->blobs().begin(), net_->blobs().end());
+  }
+
+  }
+
   // The pointer to the internal caffe::Net instant.
        shared_ptr<Net<float> > net_;
 };
 
 
+
 // The boost python module definition.
 BOOST_PYTHON_MODULE(pycaffe)
 {
+
   boost::python::class_<CaffeNet>(
       "CaffeNet", boost::python::init<string, string>())
       .def("Forward",         &CaffeNet::Forward)
@@ -160,5 +241,23 @@ BOOST_PYTHON_MODULE(pycaffe)
       .def("set_phase_train", &CaffeNet::set_phase_train)
       .def("set_phase_test",  &CaffeNet::set_phase_test)
       .def("set_device",      &CaffeNet::set_device)
+      .def("blobs",           &CaffeNet::blobs)
+  ;
+
+  boost::python::class_<CaffeBlob, CaffeBlobWrap>(
+      "CaffeBlob", boost::python::no_init)
+      .add_property("num",      &CaffeBlob::num)
+      .add_property("channels", &CaffeBlob::channels)
+      .add_property("height",   &CaffeBlob::height)
+      .add_property("width",    &CaffeBlob::width)
+      .add_property("count",    &CaffeBlob::count)
+      .add_property("data",     &CaffeBlobWrap::get_data)
+      .add_property("diff",     &CaffeBlobWrap::get_diff)
   ;
+
+  boost::python::class_<vector<CaffeBlob> >("BlobVec")
+      .def(vector_indexing_suite<vector<CaffeBlob>, true>());
+
+  import_array();
+
 }