[pycaffe] expose Blob.reshape as *args function
authorJonathan L Long <jonlong@cs.berkeley.edu>
Mon, 2 Mar 2015 23:27:45 +0000 (15:27 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 3 Mar 2015 23:55:15 +0000 (15:55 -0800)
python/caffe/_caffe.cpp

index d4eda79..bfea0de 100644 (file)
@@ -5,6 +5,7 @@
 
 #include <boost/make_shared.hpp>
 #include <boost/python.hpp>
+#include <boost/python/raw_function.hpp>
 #include <boost/python/suite/indexing/vector_indexing_suite.hpp>
 #include <numpy/arrayobject.h>
 
@@ -175,25 +176,18 @@ struct NdarrayCallPolicies : public bp::default_call_policies {
   }
 };
 
-void Blob_Reshape(Blob<Dtype>* blob, bp::object shape_obj) {
-  PyArrayObject* shape_arr =
-      reinterpret_cast<PyArrayObject*>(shape_obj.ptr());
-  if (!(PyArray_FLAGS(shape_arr) & NPY_ARRAY_C_CONTIGUOUS)) {
-    throw std::runtime_error("new shape must be C contiguous");
+bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
+  if (bp::len(kwargs) > 0) {
+    throw std::runtime_error("Blob.reshape takes no kwargs");
   }
-  if (PyArray_NDIM(shape_arr) != 1) {
-    throw std::runtime_error("new shape must be 1-d");
+  Blob<Dtype>* self = bp::extract<Blob<Dtype>*>(args[0]);
+  vector<int> shape(bp::len(args) - 1);
+  for (int i = 1; i < bp::len(args); ++i) {
+    shape[i - 1] = bp::extract<int>(args[i]);
   }
-  if (PyArray_TYPE(shape_arr) != NPY_INT32) {
-    throw std::runtime_error("new shape must be specified as int32 array");
-  }
-  npy_int32* shape_data = static_cast<npy_int32*>(PyArray_DATA(shape_arr));
-  const int num_axes = PyArray_SIZE(shape_arr);
-  vector<int> shape(num_axes);
-  for (int i = 0; i < num_axes; ++i) {
-    shape[i] = shape_data[i];
-  }
-  blob->Reshape(shape);
+  self->Reshape(shape);
+  // We need to explicitly return None to use bp::raw_function.
+  return bp::object();
 }
 
 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
@@ -242,9 +236,7 @@ BOOST_PYTHON_MODULE(_caffe) {
     .add_property("width",    &Blob<Dtype>::width)
     .add_property("count",    static_cast<int (Blob<Dtype>::*)() const>(
         &Blob<Dtype>::count))
-    .def("reshape", static_cast<void (Blob<Dtype>::*)(int, int, int, int)>(
-        &Blob<Dtype>::Reshape))
-    .def("reshape", &Blob_Reshape)
+    .def("reshape",           bp::raw_function(&Blob_Reshape))
     .add_property("data",     bp::make_function(&Blob<Dtype>::mutable_cpu_data,
           NdarrayCallPolicies()))
     .add_property("diff",     bp::make_function(&Blob<Dtype>::mutable_cpu_diff,