1 #include <Python.h> // NOLINT(build/include_alpha)
3 // Produce deprecation warnings (needs to come before arrayobject.h inclusion).
4 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
6 #include <boost/make_shared.hpp>
7 #include <boost/python.hpp>
8 #include <boost/python/raw_function.hpp>
9 #include <boost/python/suite/indexing/vector_indexing_suite.hpp>
10 #include <numpy/arrayobject.h>
12 // these need to be included after boost on OS X
13 #include <string> // NOLINT(build/include_order)
14 #include <vector> // NOLINT(build/include_order)
15 #include <fstream> // NOLINT
17 #include "caffe/caffe.hpp"
18 #include "caffe/layers/memory_data_layer.hpp"
19 #include "caffe/layers/python_layer.hpp"
20 #include "caffe/sgd_solvers.hpp"
22 // Temporary solution for numpy < 1.7 versions: old macro, no promises.
23 // You're strongly advised to upgrade to >= 1.7.
// Only taken when the numpy headers do not already provide the newer name.
24 #ifndef NPY_ARRAY_C_CONTIGUOUS
// Map the newer flag name onto the old macro spelling.
25 #define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS
// Fallback: assign the base object directly through PyArray_BASE.
26 #define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x))
29 /* Fix to avoid registration warnings in pycaffe (#3960) */
// Usage: BP_REGISTER_SHARED_PTR_TO_PYTHON(SomeType);
// Looks up the Boost.Python converter registration for shared_ptr<PTR> and
// calls bp::register_ptr_to_python on the branches shown below; one of them
// is "a registration exists but its m_to_python slot is NULL".
// NOTE(review): presumably this keeps the converter from being registered
// twice, which is what triggers the warnings mentioned above - confirm.
30 #define BP_REGISTER_SHARED_PTR_TO_PYTHON(PTR) do { \
31 const boost::python::type_info info = \
32 boost::python::type_id<shared_ptr<PTR > >(); \
33 const boost::python::converter::registration* reg = \
34 boost::python::converter::registry::query(info); \
36 bp::register_ptr_to_python<shared_ptr<PTR > >(); \
37 } else if ((*reg).m_to_python == NULL) { \
38 bp::register_ptr_to_python<shared_ptr<PTR > >(); \
// Short alias for the Boost.Python namespace; used throughout this file.
42 namespace bp = boost::python;
46 // For Python, for now, we'll just always use float as the type.
// numpy type code passed to PyArray_SimpleNewFromData when raw Dtype* data
// is wrapped as an ndarray (see NdarrayConverterGenerator::apply<Dtype*>).
48 const int NPY_DTYPE = NPY_FLOAT32;
// Select Caffe::CPU as the Caffe mode. Bound to Python as "set_mode_cpu".
51 void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
// Select Caffe::GPU as the Caffe mode. Bound to Python as "set_mode_gpu".
52 void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
// Configure and start Google logging.
//   level: value stored in FLAGS_minloglevel (the minimum level to log).
// Also sets FLAGS_logtostderr and installs the glog failure signal handler.
54 void InitLog(int level) {
55 FLAGS_logtostderr = 1;
56 FLAGS_minloglevel = level;
// glog is initialized with an empty program name.
57 ::google::InitGoogleLogging("");
58 ::google::InstallFailureSignalHandler();
// Start logging with google::INFO as the minimum level.
// NOTE(review): presumably the body of InitLogInfo, which the module
// definition binds as a second "init_log" overload - confirm.
61 InitLog(google::INFO);
// Logging helper that takes the message as a string.
// Bound to Python as "log" in the module definition.
63 void Log(const string& s) {
// Forward `seed` to Caffe::set_random_seed. Bound to Python under the same
// name.
67 void set_random_seed(unsigned int seed) { Caffe::set_random_seed(seed); }
69 // For convenience, check that input files can be opened, and raise an
70 // exception that boost will send to Python if not (caffe could still crash
71 // later if the input files are disturbed before they are actually used, but
72 // this saves frustration in most cases).
//   filename: path of the file to probe.
// Failure is reported as std::runtime_error with the message
// "Could not open file <filename>".
73 static void CheckFile(const string& filename) {
// Try to open the file for reading.
74 std::ifstream f(filename.c_str());
77 throw std::runtime_error("Could not open file " + filename);
// Validate the layout, rank, dtype and trailing dimensions of a numpy array.
//   arr:      array to inspect.
//   name:     label used as the prefix of every error message.
//   channels: required size of dimension 1.
//   height:   required size of dimension 2.
//   width:    required size of dimension 3.
// Throws std::runtime_error describing the first check that fails.
// Dimension 0 is not checked here.
82 void CheckContiguousArray(PyArrayObject* arr, string name,
83 int channels, int height, int width) {
// Memory layout must be C-contiguous.
84 if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
85 throw std::runtime_error(name + " must be C contiguous");
// Exactly four dimensions are required.
87 if (PyArray_NDIM(arr) != 4) {
88 throw std::runtime_error(name + " must be 4-d");
// Element type must be float32.
90 if (PyArray_TYPE(arr) != NPY_FLOAT32) {
91 throw std::runtime_error(name + " must be float32");
// Dimensions 1..3 must match the expected channels / height / width.
93 if (PyArray_DIMS(arr)[1] != channels) {
94 throw std::runtime_error(name + " has wrong number of channels");
96 if (PyArray_DIMS(arr)[2] != height) {
97 throw std::runtime_error(name + " has wrong height");
99 if (PyArray_DIMS(arr)[3] != width) {
100 throw std::runtime_error(name + " has wrong width");
// Construct a Net from a network definition file; used as a Python
// constructor for Net (see the make_constructor binding).
//   network_file: path to the network definition; verified with CheckFile.
//   phase:        integer that is cast to Phase.
//   level:        forwarded unchanged to the Net constructor.
//   stages:       Python sequence of strings, or None for no stages.
//   weights:      path to a weights file as a Python string, or None.
105 shared_ptr<Net<Dtype> > Net_Init(string network_file, int phase,
106 const int level, const bp::object& stages,
107 const bp::object& weights) {
108 CheckFile(network_file);
110 // Convert stages from list to vector
111 vector<string> stages_vector;
112 if (!stages.is_none()) {
113 for (int i = 0; i < len(stages); i++) {
114 stages_vector.push_back(bp::extract<string>(stages[i]));
// The stage list is handed to the Net constructor by pointer.
119 shared_ptr<Net<Dtype> > net(new Net<Dtype>(network_file,
120 static_cast<Phase>(phase), level, &stages_vector));
// Optional weights: check that the file opens, then copy the trained layers.
123 if (!weights.is_none()) {
124 std::string weights_file_str = bp::extract<std::string>(weights);
125 CheckFile(weights_file_str);
126 net->CopyTrainedLayersFrom(weights_file_str);
132 // Legacy Net construct-and-load convenience constructor
//   param_file:            path to the network definition.
//   pretrained_param_file: path to the trained parameters to load.
//   phase:                 integer that is cast to Phase.
// Logs deprecation warnings that spell out the replacement call using the
// named "weights" argument, then checks both files, builds the net and
// copies the trained layers into it.
133 shared_ptr<Net<Dtype> > Net_Init_Load(
134 string param_file, string pretrained_param_file, int phase) {
135 LOG(WARNING) << "DEPRECATION WARNING - deprecated use of Python interface";
136 LOG(WARNING) << "Use this instead (with the named \"weights\""
138 LOG(WARNING) << "Net('" << param_file << "', " << phase
139 << ", weights='" << pretrained_param_file << "')";
// Both inputs must be openable before any construction starts.
140 CheckFile(param_file);
141 CheckFile(pretrained_param_file);
143 shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
144 static_cast<Phase>(phase)));
145 net->CopyTrainedLayersFrom(pretrained_param_file);
// Serialize `net` into a NetParameter and write it to `filename` as a binary
// proto file. Bound to Python as Net.save.
149 void Net_Save(const Net<Dtype>& net, string filename) {
150 NetParameter net_param;
// NOTE(review): the `false` presumably disables writing diffs - confirm
// against Net::ToProto.
151 net.ToProto(&net_param, false);
152 WriteProtoToBinaryFile(net_param, filename.c_str());
// Forward to Net::ToHDF5 with `filename`. Bound to Python as Net.save_hdf5.
155 void Net_SaveHDF5(const Net<Dtype>& net, string filename) {
156 net.ToHDF5(filename);
// Forward to Net::CopyTrainedLayersFromHDF5 with `filename`.
// Bound to Python as Net.load_hdf5.
159 void Net_LoadHDF5(Net<Dtype>* net, string filename) {
160 net->CopyTrainedLayersFromHDF5(filename.c_str());
// Point the net's first layer (which must be a MemoryDataLayer) at the data
// and label buffers of two numpy arrays.
//   net:        network whose layer 0 receives the arrays.
//   data_obj:   Python object treated as the data ndarray.
//   labels_obj: Python object treated as the labels ndarray.
// Throws std::runtime_error when the first layer has the wrong type or the
// arrays fail validation.
// The Python binding ("_set_input_arrays") attaches custodian-and-ward
// policies for both array arguments.
// NOTE(review): presumably because the layer keeps using this memory after
// the call returns - confirm against MemoryDataLayer::Reset.
163 void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
164 bp::object labels_obj) {
165 // check that this network has an input MemoryDataLayer
166 shared_ptr<MemoryDataLayer<Dtype> > md_layer =
167 boost::dynamic_pointer_cast<MemoryDataLayer<Dtype> >(net->layers()[0]);
169 throw std::runtime_error("set_input_arrays may only be called if the"
170 " first layer is a MemoryDataLayer");
173 // check that we were passed appropriately-sized contiguous memory
// NOTE(review): the casts below assume both objects really are ndarrays; no
// PyArray_Check is visible before them.
174 PyArrayObject* data_arr =
175 reinterpret_cast<PyArrayObject*>(data_obj.ptr());
176 PyArrayObject* labels_arr =
177 reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
// Data must match the layer's channels/height/width.
178 CheckContiguousArray(data_arr, "data array", md_layer->channels(),
179 md_layer->height(), md_layer->width());
// Labels must have size 1 in each of dimensions 1..3.
180 CheckContiguousArray(labels_arr, "labels array", 1, 1, 1);
// Both arrays must agree on their first dimension.
181 if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
182 throw std::runtime_error("data and labels must have the same first"
// The first dimension must be a whole number of batches.
185 if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
186 throw std::runtime_error("first dimensions of input arrays must be a"
187 " multiple of batch size");
// Pass the raw data pointers and the first-dimension size to the layer.
190 md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
191 static_cast<Dtype*>(PyArray_DATA(labels_arr)),
192 PyArray_DIMS(data_arr)[0]);
// Create a solver from a solver definition text file.
//   filename: path to the solver definition, read with
//             ReadSolverParamsFromTextFileOrDie (the "OrDie" name suggests
//             it aborts on failure - confirm).
// Returns the solver produced by SolverRegistry<Dtype>::CreateSolver as a raw
// pointer; the "get_solver" binding wraps it with manage_new_object.
195 Solver<Dtype>* GetSolverFromFile(const string& filename) {
196 SolverParameter param;
// Pass the address of `param` so the reader can fill it in.
197 ReadSolverParamsFromTextFileOrDie(filename, &param);
198 return SolverRegistry<Dtype>::CreateSolver(param);
// Result-converter generator used by NdarrayCallPolicies. Only the
// specialization of `apply` for Dtype* is defined in the visible code.
201 struct NdarrayConverterGenerator {
202 template <typename T> struct apply;
// Converter for Dtype* results: wraps the pointer in a temporary
// 0-dimensional ndarray of type NPY_DTYPE. NdarrayCallPolicies::postcall
// later rebuilds the array with the blob's shape.
206 struct NdarrayConverterGenerator::apply<Dtype*> {
208 PyObject* operator() (Dtype* data) const {
209 // Just store the data pointer, and add the shape information in postcall.
210 return PyArray_SimpleNewFromData(0, NULL, NPY_DTYPE, data);
// Python type object reported for the converted result.
212 const PyTypeObject* get_pytype() {
213 return &PyArray_Type;
// Call policies for the Blob "data"/"diff" properties: results are converted
// by NdarrayConverterGenerator, and postcall rebuilds the result as an
// ndarray carrying the blob's shape.
218 struct NdarrayCallPolicies : public bp::default_call_policies {
219 typedef NdarrayConverterGenerator result_converter;
//   pyargs: argument tuple of the call; element 0 is the blob.
//   result: temporary array produced by the result converter.
220 PyObject* postcall(PyObject* pyargs, PyObject* result) {
221 bp::object pyblob = bp::extract<bp::tuple>(pyargs)()[0];
222 shared_ptr<Blob<Dtype> > blob =
223 bp::extract<shared_ptr<Blob<Dtype> > >(pyblob);
224 // Free the temporary pointer-holding array, and construct a new one with
225 // the shape information from the blob.
226 void* data = PyArray_DATA(reinterpret_cast<PyArrayObject*>(result));
// Copy the blob shape into numpy's index type.
228 const int num_axes = blob->num_axes();
229 vector<npy_intp> dims(blob->shape().begin(), blob->shape().end());
230 PyObject *arr_obj = PyArray_SimpleNewFromData(num_axes, dims.data(),
232 // SetBaseObject steals a ref, so we need to INCREF.
233 Py_INCREF(pyblob.ptr());
234 PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(arr_obj),
// Raw-function implementation of Blob.reshape(*dims).
//   args:   args[0] is the Blob; every following element is extracted as an
//           int and becomes one entry of the new shape.
//   kwargs: must be empty, otherwise std::runtime_error is thrown.
240 bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
241 if (bp::len(kwargs) > 0) {
242 throw std::runtime_error("Blob.reshape takes no kwargs");
244 Blob<Dtype>* self = bp::extract<Blob<Dtype>*>(args[0]);
// One shape entry per positional argument after self.
245 vector<int> shape(bp::len(args) - 1);
246 for (int i = 1; i < bp::len(args); ++i) {
247 shape[i - 1] = bp::extract<int>(args[i]);
249 self->Reshape(shape);
250 // We need to explicitly return None to use bp::raw_function.
// Raw-function implementation of BlobVec.add_blob(*dims): appends a newly
// allocated Blob with the given shape to the vector.
//   args:   args[0] is the BlobVec; every following element is extracted as
//           an int and becomes one entry of the new blob's shape.
//   kwargs: must be empty, otherwise std::runtime_error is thrown.
254 bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
255 if (bp::len(kwargs) > 0) {
256 throw std::runtime_error("BlobVec.add_blob takes no kwargs");
258 typedef vector<shared_ptr<Blob<Dtype> > > BlobVec;
259 BlobVec* self = bp::extract<BlobVec*>(args[0]);
// One shape entry per positional argument after self.
260 vector<int> shape(bp::len(args) - 1);
261 for (int i = 1; i < bp::len(args); ++i) {
262 shape[i - 1] = bp::extract<int>(args[i]);
// The new blob is owned by a shared_ptr stored in the vector.
264 self->push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
265 // We need to explicitly return None to use bp::raw_function.
// Solver<Dtype>::Callback implementation backed by two Python callables.
269 template<typename Dtype>
270 class SolverCallback: public Solver<Dtype>::Callback {
// Python callables stored by the constructor.
272 bp::object on_start_, on_gradients_ready_;
275 SolverCallback(bp::object on_start, bp::object on_gradients_ready)
276 : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
// Invokes the stored on_gradients_ready_ callable with no arguments.
277 virtual void on_gradients_ready() {
278 on_gradients_ready_();
280 virtual void on_start() {
// Register a pair of Python callables on `solver` by wrapping them in a
// heap-allocated SolverCallback. Bound to Python as Solver.add_callback.
// NOTE(review): who deletes the SolverCallback is not visible here - check
// Solver::add_callback.
284 template<typename Dtype>
285 void Solver_add_callback(Solver<Dtype> * solver, bp::object on_start,
286 bp::object on_gradients_ready) {
287 solver->add_callback(new SolverCallback<Dtype>(on_start, on_gradients_ready));
290 // Seems boost cannot call the base method directly
// Adds `nccl` to the solver's callbacks; bound as a second
// Solver.add_callback overload. The declaration of `nccl` is not visible in
// this excerpt.
291 void Solver_add_nccl(SGDSolver<Dtype>* solver
297 solver->add_callback(nccl);
// Net<Dtype>::Callback implementation that stores a Python object in run_.
// NOTE(review): run(layer) presumably invokes that object with the layer
// index; its body is not visible here - confirm.
301 template<typename Dtype>
302 class NetCallback: public Net<Dtype>::Callback {
304 explicit NetCallback(bp::object run) : run_(run) {}
307 virtual void run(int layer) {
// Wrap `run` in a heap-allocated NetCallback and pass it to
// Net::add_before_forward. Bound to Python as Net.before_forward.
312 void Net_before_forward(Net<Dtype>* net, bp::object run) {
313 net->add_before_forward(new NetCallback<Dtype>(run));
// Wrap `run` in a heap-allocated NetCallback and pass it to
// Net::add_after_forward. Bound to Python as Net.after_forward.
315 void Net_after_forward(Net<Dtype>* net, bp::object run) {
316 net->add_after_forward(new NetCallback<Dtype>(run));
// Wrap `run` in a heap-allocated NetCallback and pass it to
// Net::add_before_backward. Bound to Python as Net.before_backward.
318 void Net_before_backward(Net<Dtype>* net, bp::object run) {
319 net->add_before_backward(new NetCallback<Dtype>(run));
// Wrap `run` in a heap-allocated NetCallback and pass it to
// Net::add_after_backward. Bound to Python as Net.after_backward.
321 void Net_after_backward(Net<Dtype>* net, bp::object run) {
322 net->add_after_backward(new NetCallback<Dtype>(run));
// Passes `nccl` to Net::add_after_backward; bound as a second
// Net.after_backward overload. The declaration of `nccl` is not visible in
// this excerpt.
325 void Net_add_nccl(Net<Dtype>* net
331 net->add_after_backward(nccl);
// NCCL constructor with an empty body taking a solver and a uid string.
// NOTE(review): presumably part of a placeholder NCCL class for builds
// without NCCL support; the class header is not visible here - confirm.
335 template<typename Dtype>
338 NCCL(shared_ptr<Solver<Dtype> > solver, const string& uid) {}
// Generates SolveOverloads, which the "solve" binding passes to .def() so
// that Solve can be called from Python with between 0 and 1 arguments.
342 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
344 BOOST_PYTHON_MODULE(_caffe) {
345 // below, we prepend an underscore to methods that will be replaced
348 bp::scope().attr("__version__") = AS_STRING(CAFFE_VERSION);
350 // Caffe utility functions
351 bp::def("init_log", &InitLog);
352 bp::def("init_log", &InitLogInfo);
353 bp::def("log", &Log);
354 bp::def("set_mode_cpu", &set_mode_cpu);
355 bp::def("set_mode_gpu", &set_mode_gpu);
356 bp::def("set_random_seed", &set_random_seed);
357 bp::def("set_device", &Caffe::SetDevice);
358 bp::def("solver_count", &Caffe::solver_count);
359 bp::def("set_solver_count", &Caffe::set_solver_count);
360 bp::def("solver_rank", &Caffe::solver_rank);
361 bp::def("set_solver_rank", &Caffe::set_solver_rank);
363 bp::def("layer_type_list", &LayerRegistry<Dtype>::LayerTypeList);
365 bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
368 .def("__init__", bp::make_constructor(&Net_Init,
369 bp::default_call_policies(), (bp::arg("network_file"), "phase",
370 bp::arg("level")=0, bp::arg("stages")=bp::object(),
371 bp::arg("weights")=bp::object())))
372 // Legacy constructor
373 .def("__init__", bp::make_constructor(&Net_Init_Load))
374 .def("_forward", &Net<Dtype>::ForwardFromTo)
375 .def("_backward", &Net<Dtype>::BackwardFromTo)
376 .def("reshape", &Net<Dtype>::Reshape)
377 .def("clear_param_diffs", &Net<Dtype>::ClearParamDiffs)
378 // The cast is to select a particular overload.
379 .def("copy_from", static_cast<void (Net<Dtype>::*)(const string)>(
380 &Net<Dtype>::CopyTrainedLayersFrom))
381 .def("share_with", &Net<Dtype>::ShareTrainedLayersWith)
382 .add_property("_blob_loss_weights", bp::make_function(
383 &Net<Dtype>::blob_loss_weights, bp::return_internal_reference<>()))
384 .def("_bottom_ids", bp::make_function(&Net<Dtype>::bottom_ids,
385 bp::return_value_policy<bp::copy_const_reference>()))
386 .def("_top_ids", bp::make_function(&Net<Dtype>::top_ids,
387 bp::return_value_policy<bp::copy_const_reference>()))
388 .add_property("_blobs", bp::make_function(&Net<Dtype>::blobs,
389 bp::return_internal_reference<>()))
390 .add_property("layers", bp::make_function(&Net<Dtype>::layers,
391 bp::return_internal_reference<>()))
392 .add_property("_blob_names", bp::make_function(&Net<Dtype>::blob_names,
393 bp::return_value_policy<bp::copy_const_reference>()))
394 .add_property("_layer_names", bp::make_function(&Net<Dtype>::layer_names,
395 bp::return_value_policy<bp::copy_const_reference>()))
396 .add_property("_inputs", bp::make_function(&Net<Dtype>::input_blob_indices,
397 bp::return_value_policy<bp::copy_const_reference>()))
398 .add_property("_outputs",
399 bp::make_function(&Net<Dtype>::output_blob_indices,
400 bp::return_value_policy<bp::copy_const_reference>()))
401 .def("_set_input_arrays", &Net_SetInputArrays,
402 bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
403 .def("save", &Net_Save)
404 .def("save_hdf5", &Net_SaveHDF5)
405 .def("load_hdf5", &Net_LoadHDF5)
406 .def("before_forward", &Net_before_forward)
407 .def("after_forward", &Net_after_forward)
408 .def("before_backward", &Net_before_backward)
409 .def("after_backward", &Net_after_backward)
410 .def("after_backward", &Net_add_nccl);
411 BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>);
413 bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
415 .add_property("shape",
417 static_cast<const vector<int>& (Blob<Dtype>::*)() const>(
418 &Blob<Dtype>::shape),
419 bp::return_value_policy<bp::copy_const_reference>()))
420 .add_property("num", &Blob<Dtype>::num)
421 .add_property("channels", &Blob<Dtype>::channels)
422 .add_property("height", &Blob<Dtype>::height)
423 .add_property("width", &Blob<Dtype>::width)
424 .add_property("count", static_cast<int (Blob<Dtype>::*)() const>(
425 &Blob<Dtype>::count))
426 .def("reshape", bp::raw_function(&Blob_Reshape))
427 .add_property("data", bp::make_function(&Blob<Dtype>::mutable_cpu_data,
428 NdarrayCallPolicies()))
429 .add_property("diff", bp::make_function(&Blob<Dtype>::mutable_cpu_diff,
430 NdarrayCallPolicies()));
431 BP_REGISTER_SHARED_PTR_TO_PYTHON(Blob<Dtype>);
433 bp::class_<Layer<Dtype>, shared_ptr<PythonLayer<Dtype> >,
434 boost::noncopyable>("Layer", bp::init<const LayerParameter&>())
435 .add_property("blobs", bp::make_function(&Layer<Dtype>::blobs,
436 bp::return_internal_reference<>()))
437 .def("setup", &Layer<Dtype>::LayerSetUp)
438 .def("reshape", &Layer<Dtype>::Reshape)
439 .add_property("type", bp::make_function(&Layer<Dtype>::type));
440 BP_REGISTER_SHARED_PTR_TO_PYTHON(Layer<Dtype>);
442 bp::class_<SolverParameter>("SolverParameter", bp::no_init)
443 .add_property("max_iter", &SolverParameter::max_iter)
444 .add_property("display", &SolverParameter::display)
445 .add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce);
446 bp::class_<LayerParameter>("LayerParameter", bp::no_init);
448 bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
449 "Solver", bp::no_init)
450 .add_property("net", &Solver<Dtype>::net)
451 .add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
452 bp::return_internal_reference<>()))
453 .add_property("iter", &Solver<Dtype>::iter)
454 .def("add_callback", &Solver_add_callback<Dtype>)
455 .def("add_callback", &Solver_add_nccl)
456 .def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
457 &Solver<Dtype>::Solve), SolveOverloads())
458 .def("step", &Solver<Dtype>::Step)
459 .def("restore", &Solver<Dtype>::Restore)
460 .def("snapshot", &Solver<Dtype>::Snapshot)
461 .add_property("param", bp::make_function(&Solver<Dtype>::param,
462 bp::return_value_policy<bp::copy_const_reference>()));
463 BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);
465 bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
466 shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
467 "SGDSolver", bp::init<string>());
468 bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
469 shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
470 "NesterovSolver", bp::init<string>());
471 bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
472 shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
473 "AdaGradSolver", bp::init<string>());
474 bp::class_<RMSPropSolver<Dtype>, bp::bases<Solver<Dtype> >,
475 shared_ptr<RMSPropSolver<Dtype> >, boost::noncopyable>(
476 "RMSPropSolver", bp::init<string>());
477 bp::class_<AdaDeltaSolver<Dtype>, bp::bases<Solver<Dtype> >,
478 shared_ptr<AdaDeltaSolver<Dtype> >, boost::noncopyable>(
479 "AdaDeltaSolver", bp::init<string>());
480 bp::class_<AdamSolver<Dtype>, bp::bases<Solver<Dtype> >,
481 shared_ptr<AdamSolver<Dtype> >, boost::noncopyable>(
482 "AdamSolver", bp::init<string>());
484 bp::def("get_solver", &GetSolverFromFile,
485 bp::return_value_policy<bp::manage_new_object>());
487 // vector wrappers for all the vector types we use
488 bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
489 .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>())
490 .def("add_blob", bp::raw_function(&BlobVec_add_blob));
491 bp::class_<vector<Blob<Dtype>*> >("RawBlobVec")
492 .def(bp::vector_indexing_suite<vector<Blob<Dtype>*>, true>());
493 bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
494 .def(bp::vector_indexing_suite<vector<shared_ptr<Layer<Dtype> > >, true>());
495 bp::class_<vector<string> >("StringVec")
496 .def(bp::vector_indexing_suite<vector<string> >());
497 bp::class_<vector<int> >("IntVec")
498 .def(bp::vector_indexing_suite<vector<int> >());
499 bp::class_<vector<Dtype> >("DtypeVec")
500 .def(bp::vector_indexing_suite<vector<Dtype> >());
501 bp::class_<vector<shared_ptr<Net<Dtype> > > >("NetVec")
502 .def(bp::vector_indexing_suite<vector<shared_ptr<Net<Dtype> > >, true>());
503 bp::class_<vector<bool> >("BoolVec")
504 .def(bp::vector_indexing_suite<vector<bool> >());
506 bp::class_<NCCL<Dtype>, shared_ptr<NCCL<Dtype> >,
507 boost::noncopyable>("NCCL",
508 bp::init<shared_ptr<Solver<Dtype> >, const string&>())
510 .def("new_uid", &NCCL<Dtype>::new_uid).staticmethod("new_uid")
511 .def("bcast", &NCCL<Dtype>::Broadcast)
513 /* NOLINT_NEXT_LINE(whitespace/semicolon) */
515 BP_REGISTER_SHARED_PTR_TO_PYTHON(NCCL<Dtype>);
517 bp::class_<Timer, shared_ptr<Timer>, boost::noncopyable>(
518 "Timer", bp::init<>())
519 .def("start", &Timer::Start)
520 .def("stop", &Timer::Stop)
521 .add_property("ms", &Timer::MilliSeconds);
522 BP_REGISTER_SHARED_PTR_TO_PYTHON(Timer);
524 // boost python expects a void (missing) return value, while import_array
525 // returns NULL for python3. import_array1() forces a void return value.