From 0c9cc62379e4061b58b0dfa257d79c2ecaeb2be8 Mon Sep 17 00:00:00 2001 From: Guillaume Dumont Date: Sat, 11 Mar 2017 20:12:40 -0500 Subject: [PATCH] Added support for python 3 and NCCL --- python/caffe/__init__.py | 2 +- python/caffe/_caffe.cpp | 32 +++++++++++++++++++++++++++++++- python/caffe/test/test_nccl.py | 19 +++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 python/caffe/test/test_nccl.py diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py index 80f5171..776945e 100644 --- a/python/caffe/__init__.py +++ b/python/caffe/__init__.py @@ -1,5 +1,5 @@ from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver, NCCL, Timer -from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, solver_count, set_solver_count, solver_rank, set_solver_rank, set_multiprocess +from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, solver_count, set_solver_count, solver_rank, set_solver_rank, set_multiprocess, has_nccl from ._caffe import __version__ from .proto.caffe_pb2 import TRAIN, TEST from .classifier import Classifier diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 01b34b8..7fc06c0 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -347,6 +347,35 @@ class NCCL { }; #endif +bool HasNCCL() { +#ifdef USE_NCCL + return true; +#else + return false; +#endif +} + +#ifdef USE_NCCL +bp::object NCCL_New_Uid() { + std::string uid = NCCL::new_uid(); +#if PY_MAJOR_VERSION >= 3 + // Convert std::string to bytes so that Python does not + // try to decode the string using the current locale. + + // Since boost 1.53 boost.python will convert str and bytes + // to std::string but will convert std::string to str. Here we + // force a bytes object to be returned. When this object + // is passed back to the NCCL constructor boost.python will + // correctly convert the bytes to std::string automatically + PyObject* py_uid = PyBytes_FromString(uid.c_str()); + return bp::object(bp::handle<>(py_uid)); +#else + // automatic conversion is correct for python 2. + return uid; +#endif +} +#endif + BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1); BOOST_PYTHON_MODULE(_caffe) { @@ -360,6 +389,7 @@ BOOST_PYTHON_MODULE(_caffe) { bp::def("init_log", &InitLogLevel); bp::def("init_log", &InitLogLevelPipe); bp::def("log", &Log); + bp::def("has_nccl", &HasNCCL); bp::def("set_mode_cpu", &set_mode_cpu); bp::def("set_mode_gpu", &set_mode_gpu); bp::def("set_random_seed", &set_random_seed); @@ -518,7 +548,7 @@ BOOST_PYTHON_MODULE(_caffe) { boost::noncopyable>("NCCL", bp::init >, const string&>()) #ifdef USE_NCCL - .def("new_uid", &NCCL::new_uid).staticmethod("new_uid") + .def("new_uid", NCCL_New_Uid).staticmethod("new_uid") .def("bcast", &NCCL::Broadcast) #endif /* NOLINT_NEXT_LINE(whitespace/semicolon) */ diff --git a/python/caffe/test/test_nccl.py b/python/caffe/test/test_nccl.py new file mode 100644 index 0000000..127a933 --- /dev/null +++ b/python/caffe/test/test_nccl.py @@ -0,0 +1,19 @@ +import sys +import unittest + +import caffe + + +class TestNCCL(unittest.TestCase): + + def test_newuid(self): + """ + Test that NCCL uids are of the proper type + according to python version + """ + if caffe.has_nccl(): + uid = caffe.NCCL.new_uid() + if sys.version_info.major >= 3: + self.assertTrue(isinstance(uid, bytes)) + else: + self.assertTrue(isinstance(uid, str)) -- 2.7.4