[pycaffe] re-expose Net
author Jonathan L Long <jonlong@cs.berkeley.edu>
Tue, 6 Jan 2015 03:31:27 +0000 (19:31 -0800)
committer Jonathan L Long <jonlong@cs.berkeley.edu>
Tue, 17 Feb 2015 06:46:13 +0000 (22:46 -0800)
python/caffe/_caffe.cpp
python/caffe/pycaffe.py

index 7caf785..5fcd853 100644
@@ -47,6 +47,77 @@ static void CheckFile(const string& filename) {
     f.close();
 }
 
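+// Verify that a numpy array is a 4-d, C-contiguous float32 array with
+// dimensions (N, channels, height, width).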
+void CheckContiguousArray(PyArrayObject* arr, string name,
+    int channels, int height, int width) {
+  if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
+    throw std::runtime_error(name + " must be C contiguous");
+  }
+  if (PyArray_NDIM(arr) != 4) {
+    throw std::runtime_error(name + " must be 4-d");
+  }
+  if (PyArray_TYPE(arr) != NPY_FLOAT32) {
+    throw std::runtime_error(name + " must be float32");
+  }
+  if (PyArray_DIMS(arr)[1] != channels) {
+    throw std::runtime_error(name + " has wrong number of channels");
+  }
+  if (PyArray_DIMS(arr)[2] != height) {
+    throw std::runtime_error(name + " has wrong height");
+  }
+  if (PyArray_DIMS(arr)[3] != width) {
+    throw std::runtime_error(name + " has wrong width");
+  }
+}
+
+// Net construct-and-load convenience constructor
+shared_ptr<Net<Dtype> > Net_Init(
+    string param_file, string pretrained_param_file) {
+  CheckFile(param_file);
+  CheckFile(pretrained_param_file);
+
+  shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file));
+  net->CopyTrainedLayersFrom(pretrained_param_file);
+  return net;
+}
+
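+// Write the net's learned parameters to a binary proto file.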
+void Net_Save(const Net<Dtype>& net, string filename) {
+  NetParameter net_param;
+  net.ToProto(&net_param, false);
+  WriteProtoToBinaryFile(net_param, filename.c_str());
+}
+
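+// Hand numpy data and label arrays to a leading MemoryDataLayer. The layer
+// reads the arrays in place, so they must stay alive as long as the net.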
+void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
+    bp::object labels_obj) {
+  // check that this network has an input MemoryDataLayer
+  shared_ptr<MemoryDataLayer<Dtype> > md_layer =
+    boost::dynamic_pointer_cast<MemoryDataLayer<Dtype> >(net->layers()[0]);
+  if (!md_layer) {
+    throw std::runtime_error("set_input_arrays may only be called if the"
+        " first layer is a MemoryDataLayer");
+  }
+
+  // check that we were passed appropriately-sized contiguous memory
+  PyArrayObject* data_arr =
+      reinterpret_cast<PyArrayObject*>(data_obj.ptr());
+  PyArrayObject* labels_arr =
+      reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
+  CheckContiguousArray(data_arr, "data array", md_layer->channels(),
+      md_layer->height(), md_layer->width());
+  CheckContiguousArray(labels_arr, "labels array", 1, 1, 1);
+  if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
+    throw std::runtime_error("data and labels must have the same first"
+        " dimension");
+  }
+  if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
+    throw std::runtime_error("first dimensions of input arrays must be a"
+        " multiple of batch size");
+  }
+
+  md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
+      static_cast<Dtype*>(PyArray_DATA(labels_arr)),
+      PyArray_DIMS(data_arr)[0]);
+}
+
 BOOST_PYTHON_MODULE(_caffe) {
   // below, we prepend an underscore to methods that will be replaced
   // in Python
@@ -57,6 +128,43 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::def("set_phase_test", &set_phase_test);
   bp::def("set_device", &Caffe::SetDevice);
 
+  bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >(
+    "Net", bp::init<string>())
+    .def("__init__", bp::make_constructor(&Net_Init))
+    .def("_forward", &Net<Dtype>::ForwardFromTo)
+    .def("_backward", &Net<Dtype>::BackwardFromTo)
+    .def("reshape", &Net<Dtype>::Reshape)
+    // The cast is to select a particular overload.
+    .def("copy_from", static_cast<void (Net<Dtype>::*)(const string)>(
+        &Net<Dtype>::CopyTrainedLayersFrom))
+    .def("share_with", &Net<Dtype>::ShareTrainedLayersWith)
+    .add_property("_blobs", bp::make_function(&Net<Dtype>::blobs,
+        bp::return_internal_reference<>()))
+    .add_property("layers", bp::make_function(&Net<Dtype>::layers,
+        bp::return_internal_reference<>()))
+    .add_property("_blob_names", bp::make_function(&Net<Dtype>::blob_names,
+        bp::return_value_policy<bp::copy_const_reference>()))
+    .add_property("_layer_names", bp::make_function(&Net<Dtype>::layer_names,
+        bp::return_value_policy<bp::copy_const_reference>()))
+    .add_property("_inputs", bp::make_function(&Net<Dtype>::input_blob_indices,
+        bp::return_value_policy<bp::copy_const_reference>()))
+    .add_property("_outputs",
+        bp::make_function(&Net<Dtype>::output_blob_indices,
+        bp::return_value_policy<bp::copy_const_reference>()))
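+    // with_custodian_and_ward keeps the data and labels arrays (args 2, 3)
+    // alive as long as the net (arg 1), which holds raw pointers into them.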
+    .def("_set_input_arrays", &Net_SetInputArrays,
+        bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
+    .def("save", &Net_Save);
+
+  // vector wrappers for all the vector types we use
+  bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
+    .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
+  bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
+    .def(bp::vector_indexing_suite<vector<shared_ptr<Layer<Dtype> > >, true>());
+  bp::class_<vector<string> >("StringVec")
+    .def(bp::vector_indexing_suite<vector<string> >());
+  bp::class_<vector<int> >("IntVec")
+    .def(bp::vector_indexing_suite<vector<int> >());
+
   import_array();
 }
 
index 31dc1f9..f6e8fbf 100644
@@ -376,6 +376,13 @@ def _Net_batch(self, blobs):
                                                  padding])
         yield padded_batch
 
+@property
+def _Net_inputs(self):
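+    """Return the names of this net's input blobs, in order."""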
+    return [list(self.blobs.keys())[i] for i in self._inputs]
+
+@property
+def _Net_outputs(self):
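+    """Return the names of this net's output blobs, in order."""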
+    return [list(self.blobs.keys())[i] for i in self._outputs]
 
 # Attach methods to Net.
 Net.blobs = _Net_blobs
@@ -392,3 +399,5 @@ Net.preprocess = _Net_preprocess
 Net.deprocess = _Net_deprocess
 Net.set_input_arrays = _Net_set_input_arrays
 Net._batch = _Net_batch
+Net.inputs = _Net_inputs
+Net.outputs = _Net_outputs
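
A minimal sketch of driving the re-exposed interface from Python, assuming a
net whose first layer is a MemoryDataLayer; the file names and array shapes
below are placeholders, not anything shipped with this commit.

import numpy as np
import caffe

# Construct-and-load via the Net_Init convenience constructor.
net = caffe.Net('deploy.prototxt', 'weights.caffemodel')

# The new inputs/outputs properties map blob indices back to blob names.
print(net.inputs, net.outputs)

# Feed the MemoryDataLayer in place: data is (N, C, H, W) float32, labels is
# (N, 1, 1, 1) float32, and N must be a multiple of the layer's batch size.
data = np.zeros((4, 3, 227, 227), dtype=np.float32)
labels = np.zeros((4, 1, 1, 1), dtype=np.float32)
net.set_input_arrays(data, labels)

# Write the learned parameters back out via Net_Save.
net.save('resaved.caffemodel')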