Hide boost rng behind facade for osx compatibility
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 22 Mar 2014 06:47:01 +0000 (23:47 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 22 Mar 2014 19:08:26 +0000 (12:08 -0700)
Split boost random number generation from the common Caffe singleton and
add a helper function for rng. This resolves a build conflict in OSX
between boost rng and nvcc compilation of cuda code.

Refer to #165 for a full discussion.

Thanks to @satol for suggesting a random number generation facade rather
than a total split of cpp and cu code, which is far more involved.

include/caffe/common.hpp
include/caffe/util/rng.hpp [new file with mode: 0644]
src/caffe/common.cpp
src/caffe/util/math_functions.cpp

index 2647b0f..ca5a348 100644 (file)
@@ -1,9 +1,9 @@
 // Copyright 2013 Yangqing Jia
+// Copyright 2014 Evan Shelhamer
 
 #ifndef CAFFE_COMMON_HPP_
 #define CAFFE_COMMON_HPP_
 
-#include <boost/random/mersenne_twister.hpp>
 #include <boost/shared_ptr.hpp>
 #include <cublas_v2.h>
 #include <cuda.h>
 #include <driver_types.h>  // cuda driver types
 #include <glog/logging.h>
 
-// various checks for different function calls.
-#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
-#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
-#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
-
-#define CUDA_KERNEL_LOOP(i, n) \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
-       i < (n); \
-       i += blockDim.x * gridDim.x)
-
-// After a kernel is executed, this will check the error and if there is one,
-// exit loudly.
-#define CUDA_POST_KERNEL_CHECK \
-  if (cudaSuccess != cudaPeekAtLastError()) \
-    LOG(FATAL) << "Cuda kernel failed. Error: " \
-        << cudaGetErrorString(cudaPeekAtLastError())
-
 // Disable the copy and assignment operator for a class.
 #define DISABLE_COPY_AND_ASSIGN(classname) \
 private:\
@@ -43,24 +26,29 @@ private:\
 // is executed we will see a fatal log.
 #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
 
-namespace caffe {
+// CUDA: various checks for different function calls.
+#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
+#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 
-// We will use the boost shared_ptr instead of the new C++11 one mainly
-// because cuda does not work (at least now) well with C++11 features.
-using boost::shared_ptr;
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+       i < (n); \
+       i += blockDim.x * gridDim.x)
 
+// CUDA: check for error after kernel execution and exit loudly if there is one.
+#define CUDA_POST_KERNEL_CHECK \
+  if (cudaSuccess != cudaPeekAtLastError()) \
+    LOG(FATAL) << "Cuda kernel failed. Error: " \
+        << cudaGetErrorString(cudaPeekAtLastError())
 
-// We will use 1024 threads per block, which requires cuda sm_2x or above.
-#if __CUDA_ARCH__ >= 200
-    const int CAFFE_CUDA_NUM_THREADS = 1024;
-#else
-    const int CAFFE_CUDA_NUM_THREADS = 512;
-#endif
 
+namespace caffe {
 
-inline int CAFFE_GET_BLOCKS(const int N) {
-  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
-}
+// We will use the boost shared_ptr instead of the new C++11 one mainly
+// because cuda does not work (at least now) well with C++11 features.
+using boost::shared_ptr;
 
 
 // A singleton class to hold common caffe stuff, such as the handler that
@@ -77,20 +65,32 @@ class Caffe {
   enum Brew { CPU, GPU };
   enum Phase { TRAIN, TEST };
 
-  // The getters for the variables.
-  // Returns the cublas handle.
+
+  // This random number generator facade hides boost and CUDA rng
+  // implementation from one another (for cross-platform compatibility).
+  class RNG {
+   public:
+    RNG();
+    explicit RNG(unsigned int seed);
+    ~RNG();
+    RNG(const RNG&);
+    RNG& operator=(const RNG&);
+    const void* generator() const;
+    void* generator();
+   private:
+    class Generator;
+    Generator* generator_;
+  };
+
+  // Getters for boost rng, curand, and cublas handles
+  inline static RNG &rng_stream() {
+    return Get().random_generator_;
+  }
   inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
-  // Returns the curand generator.
   inline static curandGenerator_t curand_generator() {
     return Get().curand_generator_;
   }
 
-  // boost RNG
-  typedef boost::mt19937 random_generator_t;
-  inline static random_generator_t &rng_stream() {
-    return Get().random_generator_;
-  }
-
   // Returns the mode: running on CPU or GPU.
   inline static Brew mode() { return Get().mode_; }
   // Returns the phase: TRAIN or TEST.
@@ -114,7 +114,7 @@ class Caffe {
  protected:
   cublasHandle_t cublas_handle_;
   curandGenerator_t curand_generator_;
-  random_generator_t random_generator_;
+  RNG random_generator_;
 
   Brew mode_;
   Phase phase_;
@@ -128,6 +128,21 @@ class Caffe {
 };
 
 
+// CUDA: thread number configuration.
+// Use 1024 threads per block, which requires cuda sm_2x or above,
+// or fall back to attempt compatibility (best of luck to you).
+#if __CUDA_ARCH__ >= 200
+    const int CAFFE_CUDA_NUM_THREADS = 1024;
+#else
+    const int CAFFE_CUDA_NUM_THREADS = 512;
+#endif
+
+// CUDA: number of blocks for threads.
+inline int CAFFE_GET_BLOCKS(const int N) {
+  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+
 }  // namespace caffe
 
 #endif  // CAFFE_COMMON_HPP_
diff --git a/include/caffe/util/rng.hpp b/include/caffe/util/rng.hpp
new file mode 100644 (file)
index 0000000..c7530c7
--- /dev/null
@@ -0,0 +1,19 @@
+// Copyright 2014 Evan Shelhamer
+
+#ifndef CAFFE_RNG_CPP_HPP_
+#define CAFFE_RNG_CPP_HPP_
+
+#include <boost/random/mersenne_twister.hpp>
+#include "caffe/common.hpp"
+
+namespace caffe {
+
+  typedef boost::mt19937 rng_t;
+  inline rng_t& caffe_rng() {
+    Caffe::RNG &generator = Caffe::rng_stream();
+    return *(caffe::rng_t*) generator.generator();
+  }
+
+}  // namespace caffe
+
+#endif  // CAFFE_RNG_HPP_
index ad52371..a25dfda 100644 (file)
@@ -1,15 +1,18 @@
 // Copyright 2013 Yangqing Jia
+// Copyright 2014 Evan Shelhamer
 
 #include <cstdio>
 #include <ctime>
 
 #include "caffe/common.hpp"
+#include "caffe/util/rng.hpp"
 
 namespace caffe {
 
 shared_ptr<Caffe> Caffe::singleton_;
 
 
+// curand seeding
 int64_t cluster_seedgen(void) {
   int64_t s, seed, pid;
   pid = getpid();
@@ -58,7 +61,7 @@ void Caffe::set_random_seed(const unsigned int seed) {
     LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
   }
   // RNG seed
-  Get().random_generator_ = random_generator_t(seed);
+  Get().random_generator_ = RNG(seed);
 }
 
 void Caffe::SetDevice(const int device_id) {
@@ -112,4 +115,37 @@ void Caffe::DeviceQuery() {
   return;
 }
 
+
+class Caffe::RNG::Generator {
+ public:
+  caffe::rng_t rng;
+};
+
+Caffe::RNG::RNG()
+: generator_(new Generator) { }
+
+Caffe::RNG::RNG(unsigned int seed)
+: generator_(new Generator) {
+  generator_->rng = caffe::rng_t(seed);
+}
+
+Caffe::RNG::~RNG() { delete generator_; }
+
+Caffe::RNG::RNG(const RNG& other) : generator_(new Generator) {
+  *generator_ = *other.generator_;
+}
+
+Caffe::RNG& Caffe::RNG::operator=(const RNG& other) {
+  *generator_ = *other.generator_;
+  return *this;
+}
+
+void* Caffe::RNG::generator() {
+  return &generator_->rng;
+}
+
+const void* Caffe::RNG::generator() const {
+  return &generator_->rng;
+}
+
 }  // namespace caffe
index 3da4b21..3d02c5f 100644 (file)
@@ -1,5 +1,6 @@
 // Copyright 2013 Yangqing Jia
 // Copyright 2014 kloudkl@github
+// Copyright 2014 Evan Shelhamer
 
 #include <boost/math/special_functions/next.hpp>
 #include <boost/random.hpp>
@@ -9,6 +10,7 @@
 
 #include "caffe/common.hpp"
 #include "caffe/util/math_functions.hpp"
+#include "caffe/util/rng.hpp"
 
 namespace caffe {
 
@@ -287,10 +289,9 @@ void caffe_vRngUniform(const int n, Dtype* r,
 
   boost::uniform_real<Dtype> random_distribution(
       a, caffe_nextafter<Dtype>(b));
-  Caffe::random_generator_t &generator = Caffe::rng_stream();
-  boost::variate_generator<Caffe::random_generator_t,
+  boost::variate_generator<caffe::rng_t,
       boost::uniform_real<Dtype> > variate_generator(
-      generator, random_distribution);
+      caffe_rng(), random_distribution);
 
   for (int i = 0; i < n; ++i) {
     r[i] = variate_generator();
@@ -311,10 +312,9 @@ void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
   CHECK(r);
   CHECK_GT(sigma, 0);
   boost::normal_distribution<Dtype> random_distribution(a, sigma);
-  Caffe::random_generator_t &generator = Caffe::rng_stream();
-  boost::variate_generator<Caffe::random_generator_t,
+  boost::variate_generator<caffe::rng_t,
       boost::normal_distribution<Dtype> > variate_generator(
-      generator, random_distribution);
+      caffe_rng(), random_distribution);
 
   for (int i = 0; i < n; ++i) {
     r[i] = variate_generator();
@@ -336,10 +336,9 @@ void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
   CHECK_GE(p, 0);
   CHECK_LE(p, 1);
   boost::bernoulli_distribution<double> random_distribution(p);
-  Caffe::random_generator_t &generator = Caffe::rng_stream();
-  boost::variate_generator<Caffe::random_generator_t,
+  boost::variate_generator<caffe::rng_t,
       boost::bernoulli_distribution<double> > variate_generator(
-      generator, random_distribution);
+      caffe_rng(), random_distribution);
 
   for (int i = 0; i < n; ++i) {
     r[i] = variate_generator();