from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver, NCCL, Timer
-from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, solver_count, set_solver_count, solver_rank, set_solver_rank, set_multiprocess
+from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, solver_count, set_solver_count, solver_rank, set_solver_rank, set_multiprocess, has_nccl
from ._caffe import __version__
from .proto.caffe_pb2 import TRAIN, TEST
from .classifier import Classifier
};
#endif
+bool HasNCCL() {
+#ifdef USE_NCCL
+ return true;
+#else
+ return false;
+#endif
+}
+
+#ifdef USE_NCCL
+bp::object NCCL_New_Uid() {
+ std::string uid = NCCL<Dtype>::new_uid();
+#if PY_MAJOR_VERSION >= 3
+ // Convert std::string to bytes so that Python does not
+ // try to decode the string using the current locale.
+
+ // Since boost 1.53 boost.python will convert str and bytes
+ // to std::string but will convert std::string to str. Here we
+ // force a bytes object to be returned. When this object
+ // is passed back to the NCCL constructor boost.python will
+ // correctly convert the bytes to std::string automatically
+ PyObject* py_uid = PyBytes_FromString(uid.c_str());
+ return bp::object(bp::handle<>(py_uid));
+#else
+ // automatic conversion is correct for python 2.
+ return uid;
+#endif
+}
+#endif
+
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
BOOST_PYTHON_MODULE(_caffe) {
bp::def("init_log", &InitLogLevel);
bp::def("init_log", &InitLogLevelPipe);
bp::def("log", &Log);
+ bp::def("has_nccl", &HasNCCL);
bp::def("set_mode_cpu", &set_mode_cpu);
bp::def("set_mode_gpu", &set_mode_gpu);
bp::def("set_random_seed", &set_random_seed);
boost::noncopyable>("NCCL",
bp::init<shared_ptr<Solver<Dtype> >, const string&>())
#ifdef USE_NCCL
- .def("new_uid", &NCCL<Dtype>::new_uid).staticmethod("new_uid")
+ .def("new_uid", NCCL_New_Uid).staticmethod("new_uid")
.def("bcast", &NCCL<Dtype>::Broadcast)
#endif
/* NOLINT_NEXT_LINE(whitespace/semicolon) */
--- /dev/null
+import sys
+import unittest
+
+import caffe
+
+
+class TestNCCL(unittest.TestCase):
+
+ def test_newuid(self):
+ """
+ Test that NCCL uids are of the proper type
+ according to python version
+ """
+ if caffe.has_nccl():
+ uid = caffe.NCCL.new_uid()
+ if sys.version_info.major >= 3:
+ self.assertTrue(isinstance(uid, bytes))
+ else:
+ self.assertTrue(isinstance(uid, str))