f.close();
}
+void CheckContiguousArray(PyArrayObject* arr, string name,
+ int channels, int height, int width) {
+ if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
+ throw std::runtime_error(name + " must be C contiguous");
+ }
+ if (PyArray_NDIM(arr) != 4) {
+ throw std::runtime_error(name + " must be 4-d");
+ }
+ if (PyArray_TYPE(arr) != NPY_FLOAT32) {
+ throw std::runtime_error(name + " must be float32");
+ }
+ if (PyArray_DIMS(arr)[1] != channels) {
+ throw std::runtime_error(name + " has wrong number of channels");
+ }
+ if (PyArray_DIMS(arr)[2] != height) {
+ throw std::runtime_error(name + " has wrong height");
+ }
+ if (PyArray_DIMS(arr)[3] != width) {
+ throw std::runtime_error(name + " has wrong width");
+ }
+}
+
+// Net construct-and-load convenience constructor
+shared_ptr<Net<Dtype> > Net_Init(
+ string param_file, string pretrained_param_file) {
+ CheckFile(param_file);
+ CheckFile(pretrained_param_file);
+
+ shared_ptr<Net<Dtype> > net (new Net<Dtype>(param_file));
+ net->CopyTrainedLayersFrom(pretrained_param_file);
+ return net;
+}
+
+void Net_Save(const Net<Dtype>& net, string filename) {
+ NetParameter net_param;
+ net.ToProto(&net_param, false);
+ WriteProtoToBinaryFile(net_param, filename.c_str());
+}
+
+void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
+ bp::object labels_obj) {
+ // check that this network has an input MemoryDataLayer
+ shared_ptr<MemoryDataLayer<Dtype> > md_layer =
+ boost::dynamic_pointer_cast<MemoryDataLayer<Dtype> >(net->layers()[0]);
+ if (!md_layer) {
+ throw std::runtime_error("set_input_arrays may only be called if the"
+ " first layer is a MemoryDataLayer");
+ }
+
+ // check that we were passed appropriately-sized contiguous memory
+ PyArrayObject* data_arr =
+ reinterpret_cast<PyArrayObject*>(data_obj.ptr());
+ PyArrayObject* labels_arr =
+ reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
+ CheckContiguousArray(data_arr, "data array", md_layer->channels(),
+ md_layer->height(), md_layer->width());
+ CheckContiguousArray(labels_arr, "labels array", 1, 1, 1);
+ if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
+ throw std::runtime_error("data and labels must have the same first"
+ " dimension");
+ }
+ if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
+ throw std::runtime_error("first dimensions of input arrays must be a"
+ " multiple of batch size");
+ }
+
+ md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
+ static_cast<Dtype*>(PyArray_DATA(labels_arr)),
+ PyArray_DIMS(data_arr)[0]);
+}
+
BOOST_PYTHON_MODULE(_caffe) {
// below, we prepend an underscore to methods that will be replaced
// in Python
bp::def("set_phase_test", &set_phase_test);
bp::def("set_device", &Caffe::SetDevice);
+ bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >(
+ "Net", bp::init<string>())
+ .def("__init__", bp::make_constructor(&Net_Init))
+ .def("_forward", &Net<Dtype>::ForwardFromTo)
+ .def("_backward", &Net<Dtype>::BackwardFromTo)
+ .def("reshape", &Net<Dtype>::Reshape)
+ // The cast is to select a particular overload.
+ .def("copy_from", static_cast<void (Net<Dtype>::*)(const string)>(
+ &Net<Dtype>::CopyTrainedLayersFrom))
+ .def("share_with", &Net<Dtype>::ShareTrainedLayersWith)
+ .add_property("_blobs", bp::make_function(&Net<Dtype>::blobs,
+ bp::return_internal_reference<>()))
+ .add_property("layers", bp::make_function(&Net<Dtype>::layers,
+ bp::return_internal_reference<>()))
+ .add_property("_blob_names", bp::make_function(&Net<Dtype>::blob_names,
+ bp::return_value_policy<bp::copy_const_reference>()))
+ .add_property("_layer_names", bp::make_function(&Net<Dtype>::layer_names,
+ bp::return_value_policy<bp::copy_const_reference>()))
+ .add_property("_inputs", bp::make_function(&Net<Dtype>::input_blob_indices,
+ bp::return_value_policy<bp::copy_const_reference>()))
+ .add_property("_outputs",
+ bp::make_function(&Net<Dtype>::output_blob_indices,
+ bp::return_value_policy<bp::copy_const_reference>()))
+ .def("_set_input_arrays", &Net_SetInputArrays,
+ bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
+ .def("save", &Net_Save);
+
+ // vector wrappers for all the vector types we use
+ bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
+ .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
+ bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
+ .def(bp::vector_indexing_suite<vector<shared_ptr<Layer<Dtype> > >, true>());
+ bp::class_<vector<string> >("StringVec")
+ .def(bp::vector_indexing_suite<vector<string> >());
+ bp::class_<vector<int> >("IntVec")
+ .def(bp::vector_indexing_suite<vector<int> >());
+
import_array();
}