Added support for python 3 and NCCL
authorGuillaume Dumont <dumont.guillaume@gmail.com>
Sun, 12 Mar 2017 01:12:40 +0000 (20:12 -0500)
committerGuillaume Dumont <dumont.guillaume@gmail.com>
Thu, 13 Apr 2017 18:53:41 +0000 (14:53 -0400)
python/caffe/__init__.py
python/caffe/_caffe.cpp
python/caffe/test/test_nccl.py [new file with mode: 0644]

index 80f5171..776945e 100644 (file)
@@ -1,5 +1,5 @@
 from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver, NCCL, Timer
-from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, solver_count, set_solver_count, solver_rank, set_solver_rank, set_multiprocess
+from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, solver_count, set_solver_count, solver_rank, set_solver_rank, set_multiprocess, has_nccl
 from ._caffe import __version__
 from .proto.caffe_pb2 import TRAIN, TEST
 from .classifier import Classifier
index 01b34b8..7fc06c0 100644 (file)
@@ -347,6 +347,35 @@ class NCCL {
 };
 #endif
 
+bool HasNCCL() {
+#ifdef USE_NCCL
+  return true;
+#else
+  return false;
+#endif
+}
+
+#ifdef USE_NCCL
+bp::object NCCL_New_Uid() {
+  std::string uid = NCCL<Dtype>::new_uid();
+#if PY_MAJOR_VERSION >= 3
+  // Convert std::string to bytes so that Python does not
+  // try to decode the string using the current locale.
+
+  // Since boost 1.53 boost.python will convert str and bytes
+  // to std::string but will convert std::string to str. Here we
+  // force a bytes object to be returned. When this object
+  // is passed back to the NCCL constructor boost.python will
+  // correctly convert the bytes to std::string automatically
+  PyObject* py_uid = PyBytes_FromString(uid.c_str());
+  return bp::object(bp::handle<>(py_uid));
+#else
+  // automatic conversion is correct for python 2.
+  return uid;
+#endif
+}
+#endif
+
 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
 
 BOOST_PYTHON_MODULE(_caffe) {
@@ -360,6 +389,7 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::def("init_log", &InitLogLevel);
   bp::def("init_log", &InitLogLevelPipe);
   bp::def("log", &Log);
+  bp::def("has_nccl", &HasNCCL);
   bp::def("set_mode_cpu", &set_mode_cpu);
   bp::def("set_mode_gpu", &set_mode_gpu);
   bp::def("set_random_seed", &set_random_seed);
@@ -518,7 +548,7 @@ BOOST_PYTHON_MODULE(_caffe) {
     boost::noncopyable>("NCCL",
                         bp::init<shared_ptr<Solver<Dtype> >, const string&>())
 #ifdef USE_NCCL
-    .def("new_uid", &NCCL<Dtype>::new_uid).staticmethod("new_uid")
+    .def("new_uid", NCCL_New_Uid).staticmethod("new_uid")
     .def("bcast", &NCCL<Dtype>::Broadcast)
 #endif
     /* NOLINT_NEXT_LINE(whitespace/semicolon) */
diff --git a/python/caffe/test/test_nccl.py b/python/caffe/test/test_nccl.py
new file mode 100644 (file)
index 0000000..127a933
--- /dev/null
@@ -0,0 +1,19 @@
+import sys
+import unittest
+
+import caffe
+
+
+class TestNCCL(unittest.TestCase):
+
+    def test_newuid(self):
+        """
+        Test that NCCL uids are of the proper type
+        according to python version
+        """
+        if caffe.has_nccl():
+            uid = caffe.NCCL.new_uid()
+            if sys.version_info.major >= 3:
+                self.assertTrue(isinstance(uid, bytes))
+            else:
+                self.assertTrue(isinstance(uid, str))