clean up residual mkl comments and code
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 21 Mar 2014 21:58:11 +0000 (14:58 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 21 Mar 2014 22:26:09 +0000 (15:26 -0700)
The FIXMEs about RNG were addressed by caffe_nextafter for
uniform distributions and the normal distribution concern is surely a
typo in the boost documentation, since the normal pdf is correctly
stated elsewhere in the documentation.

include/caffe/common.hpp
include/caffe/filler.hpp
src/caffe/common.cpp
src/caffe/layers/dropout_layer.cpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/layers/inner_product_layer.cu
src/caffe/test/test_common.cpp
src/caffe/test/test_util_blas.cpp
src/caffe/util/math_functions.cpp

index 9621b26..2ffc93f 100644 (file)
@@ -8,16 +8,13 @@
 #include <cublas_v2.h>
 #include <cuda.h>
 #include <curand.h>
-// cuda driver types
-#include <driver_types.h>
+#include <driver_types.h>  // cuda driver types
 #include <glog/logging.h>
-//#include <mkl_vsl.h>
 
 // various checks for different function calls.
 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
-#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
 #define CUDA_KERNEL_LOOP(i, n) \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
@@ -46,7 +43,6 @@ private:\
 // is executed we will see a fatal log.
 #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
 
-
 namespace caffe {
 
 // We will use the boost shared_ptr instead of the new C++11 one mainly
@@ -62,7 +58,6 @@ using boost::shared_ptr;
 #endif
 
 
-
 inline int CAFFE_GET_BLOCKS(const int N) {
   return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
 }
@@ -90,11 +85,9 @@ class Caffe {
     return Get().curand_generator_;
   }
 
-  // Returns the MKL random stream.
-  //inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
-
+  // boost RNG
   typedef boost::mt19937 random_generator_t;
-  inline static random_generator_t &vsl_stream() { return Get().random_generator_; }
+  inline static random_generator_t &rng_stream() { return Get().random_generator_; }
 
   // Returns the mode: running on CPU or GPU.
   inline static Brew mode() { return Get().mode_; }
@@ -108,7 +101,7 @@ class Caffe {
   inline static void set_mode(Brew mode) { Get().mode_ = mode; }
   // Sets the phase.
   inline static void set_phase(Phase phase) { Get().phase_ = phase; }
-  // Sets the random seed of both MKL and curand
+  // Sets the random seed of both boost and curand
   static void set_random_seed(const unsigned int seed);
   // Sets the device. Since we have cublas and curand stuff, set device also
   // requires us to reset those values.
@@ -119,7 +112,6 @@ class Caffe {
  protected:
   cublasHandle_t cublas_handle_;
   curandGenerator_t curand_generator_;
-  //VSLStreamStatePtr vsl_stream_;
   random_generator_t random_generator_;
 
   Brew mode_;
index d0b5baa..7c10022 100644 (file)
@@ -7,7 +7,6 @@
 #ifndef CAFFE_FILLER_HPP
 #define CAFFE_FILLER_HPP
 
-//#include <mkl.h>
 #include <string>
 
 #include "caffe/common.hpp"
index 95a5e93..29501bb 100644 (file)
@@ -22,7 +22,6 @@ int64_t cluster_seedgen(void) {
 Caffe::Caffe()
     : mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
       curand_generator_(NULL),
-      //vsl_stream_(NULL)
       random_generator_()
 {
   // Try to create a cublas handler, and report an error if failed (but we will
@@ -37,13 +36,6 @@ Caffe::Caffe()
       != CURAND_STATUS_SUCCESS) {
     LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
   }
-
-  // Try to create a vsl stream. This should almost always work, but we will
-  // check it anyway.
-  //if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
-  //  LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
-  //      << "won't be available.";
-  //}
 }
 
 Caffe::~Caffe() {
@@ -51,7 +43,6 @@ Caffe::~Caffe() {
   if (curand_generator_) {
     CURAND_CHECK(curandDestroyGenerator(curand_generator_));
   }
-  //if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
 }
 
 void Caffe::set_random_seed(const unsigned int seed) {
@@ -67,11 +58,8 @@ void Caffe::set_random_seed(const unsigned int seed) {
   } else {
     LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
   }
-  // VSL seed
-  //VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
-  //VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
+  // RNG seed
   Get().random_generator_ = random_generator_t(seed);
-
 }
 
 void Caffe::SetDevice(const int device_id) {
index bfb854b..f07547a 100644 (file)
@@ -32,8 +32,6 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const int count = bottom[0]->count();
   if (Caffe::phase() == Caffe::TRAIN) {
     // Create random numbers
-    //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
-    //    count, mask, 1. - threshold_);
     caffe_vRngBernoulli<int>(count, mask, 1. - threshold_);
     for (int i = 0; i < count; ++i) {
       top_data[i] = bottom_data[i] * mask[i] * scale_;
index a00e2f2..6ea228f 100644 (file)
@@ -1,8 +1,5 @@
 // Copyright 2013 Yangqing Jia
 
-
-//#include <mkl.h>
-
 #include <vector>
 
 #include "caffe/blob.hpp"
index 0d397dc..37463b5 100644 (file)
@@ -1,7 +1,5 @@
 // Copyright 2013 Yangqing Jia
 
-
-//#include <mkl.h>
 #include <cublas_v2.h>
 
 #include <vector>
index f5e3fe4..3ce15bb 100644 (file)
@@ -19,11 +19,6 @@ TEST_F(CommonTest, TestCublasHandler) {
   EXPECT_TRUE(Caffe::cublas_handle());
 }
 
-TEST_F(CommonTest, TestVslStream) {
-  //EXPECT_TRUE(Caffe::vsl_stream());
-    EXPECT_TRUE(true);
-}
-
 TEST_F(CommonTest, TestBrewMode) {
   Caffe::set_mode(Caffe::CPU);
   EXPECT_EQ(Caffe::mode(), Caffe::CPU);
@@ -41,13 +36,9 @@ TEST_F(CommonTest, TestRandSeedCPU) {
   SyncedMemory data_a(10 * sizeof(int));
   SyncedMemory data_b(10 * sizeof(int));
   Caffe::set_random_seed(1701);
-  //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
-  //      10, (int*)data_a.mutable_cpu_data(), 0.5);
   caffe_vRngBernoulli(10, reinterpret_cast<int*>(data_a.mutable_cpu_data()), 0.5);
 
   Caffe::set_random_seed(1701);
-  //viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
-  //      10, (int*)data_b.mutable_cpu_data(), 0.5);
   caffe_vRngBernoulli(10, reinterpret_cast<int*>(data_b.mutable_cpu_data()), 0.5);
 
   for (int i = 0; i < 10; ++i) {
@@ -56,7 +47,6 @@ TEST_F(CommonTest, TestRandSeedCPU) {
   }
 }
 
-
 TEST_F(CommonTest, TestRandSeedGPU) {
   SyncedMemory data_a(10 * sizeof(unsigned int));
   SyncedMemory data_b(10 * sizeof(unsigned int));
@@ -72,5 +62,4 @@ TEST_F(CommonTest, TestRandSeedGPU) {
   }
 }
 
-
 }  // namespace caffe
index 4ac4955..57f4eaf 100644 (file)
@@ -3,7 +3,6 @@
 #include <cstring>
 
 #include "cuda_runtime.h"
-//#include "mkl.h"
 #include "cublas_v2.h"
 
 #include "gtest/gtest.h"
index fb2b112..d68c05c 100644 (file)
@@ -2,7 +2,6 @@
 // Copyright 2014 kloudkl@github
 
 #include <limits>
-//#include <mkl.h>
 #include <boost/math/special_functions/next.hpp>
 #include <boost/random.hpp>
 
@@ -284,14 +283,10 @@ void caffe_vRngUniform(const int n, Dtype* r,
   CHECK_GE(n, 0);
   CHECK(r);
   CHECK_LE(a, b);
-  //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
-  //    n, r, a, b));
 
-  // FIXME check if boundaries are handled in the same way ?
-  // Fixed by caffe_nextafter
   boost::uniform_real<Dtype> random_distribution(
       a, caffe_nextafter<Dtype>(b));
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
+  Caffe::random_generator_t &generator = Caffe::rng_stream();
   boost::variate_generator<Caffe::random_generator_t,
       boost::uniform_real<Dtype> > variate_generator(
       generator, random_distribution);
@@ -314,17 +309,8 @@ void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
   CHECK_GE(n, 0);
   CHECK(r);
   CHECK_GT(sigma, 0);
-  //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
-//      Caffe::vsl_stream(), n, r, a, sigma));
-
-    // FIXME check if parameters are handled in the same way ?
-    // http://www.boost.org/doc/libs/1_55_0/doc/html/boost/random/normal_distribution.html
-    // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-63196F25-5013-4038-8BCD-2613C4EF3DE4.htm
-    // The above two documents show that the probability density functions are different.
-    // But the unit tests still pass. Maybe their codes are the same or
-    // the tests are irrelevant to the random numbers.
   boost::normal_distribution<Dtype> random_distribution(a, sigma);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
+  Caffe::random_generator_t &generator = Caffe::rng_stream();
   boost::variate_generator<Caffe::random_generator_t,
       boost::normal_distribution<Dtype> > variate_generator(
       generator, random_distribution);
@@ -349,7 +335,7 @@ void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
   CHECK_GE(p, 0);
   CHECK_LE(p, 1);
   boost::bernoulli_distribution<double> random_distribution(p);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
+  Caffe::random_generator_t &generator = Caffe::rng_stream();
   boost::variate_generator<Caffe::random_generator_t,
       boost::bernoulli_distribution<double> > variate_generator(
       generator, random_distribution);