#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <boost/python.hpp>
+#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include <numpy/arrayobject.h>
#include "caffe/caffe.hpp"
using boost::python::len;
using boost::python::list;
using boost::python::object;
+using boost::python::handle;
+using boost::python::vector_indexing_suite;
+
+
+// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
+// to Python
+class CaffeBlob {
+ public:
+
+ CaffeBlob(const shared_ptr<Blob<float> > &blob)
+ : blob_(blob) {}
+
+ CaffeBlob()
+ {}
+
+ int num() const { return blob_->num(); }
+ int channels() const { return blob_->channels(); }
+ int height() const { return blob_->height(); }
+ int width() const { return blob_->width(); }
+ int count() const { return blob_->count(); }
+
+ bool operator == (const CaffeBlob &other)
+ {
+ return this->blob_ == other.blob_;
+ }
+
+ protected:
+ shared_ptr<Blob<float> > blob_;
+};
+
+
+// we need another wrapper (used as boost::python's HeldType) that receives a
+// self PyObject * which we can use as ndarray.base, so that data/diff memory
+// is not freed while still being used in Python
+class CaffeBlobWrap : public CaffeBlob {
+ public:
+ CaffeBlobWrap(PyObject *p, shared_ptr<Blob<float> > &blob)
+ : CaffeBlob(blob), self_(p) {}
+
+ CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
+ : CaffeBlob(blob), self_(p) {}
+
+ object get_data()
+ {
+ npy_intp dims[] = {num(), channels(), height(), width()};
+
+ PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
+ blob_->mutable_cpu_data());
+ PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
+ Py_INCREF(self_);
+ handle<> h(obj);
+
+ return object(h);
+ }
+
+ object get_diff()
+ {
+ npy_intp dims[] = {num(), channels(), height(), width()};
+
+ PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32,
+ blob_->mutable_cpu_diff());
+ PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_);
+ Py_INCREF(self_);
+ handle<> h(obj);
+
+ return object(h);
+ }
+
+ private:
+ PyObject *self_;
+};
+
// A simple wrapper over CaffeNet that runs the forward process.
void set_phase_test() { Caffe::set_phase(Caffe::TEST); }
void set_device(int device_id) { Caffe::SetDevice(device_id); }
+ vector<CaffeBlob> blobs() {
+ return vector<CaffeBlob>(net_->blobs().begin(), net_->blobs().end());
+ }
+
+ }
+
// The pointer to the internal caffe::Net instant.
shared_ptr<Net<float> > net_;
};
+
// The boost python module definition.
BOOST_PYTHON_MODULE(pycaffe)
{
+
boost::python::class_<CaffeNet>(
"CaffeNet", boost::python::init<string, string>())
.def("Forward", &CaffeNet::Forward)
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
+ .def("blobs", &CaffeNet::blobs)
+ ;
+
+ boost::python::class_<CaffeBlob, CaffeBlobWrap>(
+ "CaffeBlob", boost::python::no_init)
+ .add_property("num", &CaffeBlob::num)
+ .add_property("channels", &CaffeBlob::channels)
+ .add_property("height", &CaffeBlob::height)
+ .add_property("width", &CaffeBlob::width)
+ .add_property("count", &CaffeBlob::count)
+ .add_property("data", &CaffeBlobWrap::get_data)
+ .add_property("diff", &CaffeBlobWrap::get_diff)
;
+
+ boost::python::class_<vector<CaffeBlob> >("BlobVec")
+ .def(vector_indexing_suite<vector<CaffeBlob>, true>());
+
+ import_array();
+
}