working update
author     Yangqing Jia <jiayq84@gmail.com>
           Mon, 16 Sep 2013 23:52:26 +0000 (16:52 -0700)
committer  Yangqing Jia <jiayq84@gmail.com>
           Mon, 16 Sep 2013 23:52:26 +0000 (16:52 -0700)
README.md
src/Makefile
src/caffeine/common.cpp
src/caffeine/common.hpp
src/caffeine/dropout_layer.cu
src/caffeine/filler.hpp
src/caffeine/relu_layer.cu
src/caffeine/vision_layers.hpp

index 7b930cc..72ff00a 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
 caffeine
 ========
 
-caffeine.
+caffeine: Convolutional Algorithms For Feature Extraction.
+
+Copyright Yangqing Jia
index 9ab43e5..111ca16 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -31,7 +31,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 
 INCLUDE_DIRS := . /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
 LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
-LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
+LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand
 WARNINGS := -Wall
 
 CXXFLAGS += -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
index 7ac0ead..d5a7b22 100644
--- a/src/caffeine/common.cpp
+++ b/src/caffeine/common.cpp
@@ -7,11 +7,16 @@ shared_ptr<Caffeine> Caffeine::singleton_;
 Caffeine::Caffeine()
     : mode_(Caffeine::CPU), phase_(Caffeine::TRAIN) {
   CUBLAS_CHECK(cublasCreate(&cublas_handle_));
+  CURAND_CHECK(curandCreateGenerator(&curand_generator_,
+      CURAND_RNG_PSEUDO_XORWOW));
   VSL_CHECK(vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701));
 }
 
 Caffeine::~Caffeine() {
   if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
+  if (curand_generator_) {
+    CURAND_CHECK(curandDestroyGenerator(curand_generator_));
+  }
   if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
 };
 
@@ -30,6 +35,10 @@ cublasHandle_t Caffeine::cublas_handle() {
   return Get().cublas_handle_;
 };
 
+curandGenerator_t Caffeine::curand_generator() {
+  return Get().curand_generator_;
+};
+
 Caffeine::Brew Caffeine::mode() {
   return Get().mode_;
 }
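The constructor now owns a cuRAND generator alongside the cuBLAS handle and the MKL VSL stream, but leaves it on cuRAND's default seed. Below is a minimal sketch of driving the new Caffeine::curand_generator() accessor, assuming a device buffer of n unsigned ints; fill_random_bits is an illustrative helper, not part of this commit.

    #include "caffeine/common.hpp"

    namespace caffeine {

    // Illustrative helper: fill a device buffer with n uniform 32-bit values,
    // mirroring the call DropoutLayer<Dtype>::Forward_gpu makes below.
    void fill_random_bits(unsigned int* gpu_data, int n) {
      // Optional explicit seed; 1701 merely echoes the VSL seed chosen above.
      CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(
          Caffeine::curand_generator(), 1701ULL));
      CURAND_CHECK(curandGenerate(Caffeine::curand_generator(), gpu_data, n));
    }

    }  // namespace caffeine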
index 060d1f7..080cb9a 100644
--- a/src/caffeine/common.hpp
+++ b/src/caffeine/common.hpp
@@ -3,6 +3,8 @@
 
 #include <boost/shared_ptr.hpp>
 #include <cublas_v2.h>
+#include <cuda.h>
+#include <curand.h>
 #include <glog/logging.h>
 #include <mkl_vsl.h>
 
@@ -10,6 +12,7 @@
 
 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
 namespace caffeine {
@@ -21,6 +24,10 @@ using boost::shared_ptr;
 // For backward compatibility we will just use 512 threads per block
 const int CAFFEINE_CUDA_NUM_THREADS = 512;
 
+inline int CAFFEINE_GET_BLOCKS(const int N) {
+  return (N + CAFFEINE_CUDA_NUM_THREADS - 1) / CAFFEINE_CUDA_NUM_THREADS;
+}
+
 // A singleton class to hold common caffeine stuff, such as the handle that
 // caffeine is going to use for cublas.
 class Caffeine {
@@ -32,6 +39,7 @@ class Caffeine {
 
   // The getters for the variables. 
   static cublasHandle_t cublas_handle();
+  static curandGenerator_t curand_generator();
   static VSLStreamStatePtr vsl_stream();
   static Brew mode();
   static Phase phase();
@@ -42,6 +50,7 @@ class Caffeine {
   Caffeine();
   static shared_ptr<Caffeine> singleton_;
   cublasHandle_t cublas_handle_;
+  curandGenerator_t curand_generator_;
   VSLStreamStatePtr vsl_stream_;
   Brew mode_;
   Phase phase_;
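CAFFEINE_GET_BLOCKS is a ceiling division: for N elements at 512 threads per block it returns (N + 511) / 512 blocks, so every element maps to exactly one thread and the surplus threads of the last block fall through the bounds check. A sketch of the launch idiom it standardizes; the SetConstant kernel is invented for illustration.

    #include "caffeine/common.hpp"

    namespace caffeine {

    // Invented example kernel: writes `value` into each of the n outputs.
    template <typename Dtype>
    __global__ void SetConstant(const int n, const Dtype value, Dtype* out) {
      int index = threadIdx.x + blockIdx.x * blockDim.x;
      if (index < n) {  // trailing threads of the last block do nothing
        out[index] = value;
      }
    }

    template <typename Dtype>
    void set_constant_gpu(const int count, const Dtype value, Dtype* gpu_data) {
      // e.g. count = 1000 launches (1000 + 511) / 512 = 2 blocks of 512 threads
      SetConstant<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
          count, value, gpu_data);
    }

    }  // namespace caffeine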
index 23999fb..29398dd 100644
--- a/src/caffeine/dropout_layer.cu
+++ b/src/caffeine/dropout_layer.cu
@@ -1,6 +1,11 @@
+#include <algorithm>
+#include <climits>
+#include <cstring>
+
+#include "caffeine/common.hpp"
 #include "caffeine/layer.hpp"
+#include "caffeine/syncedmem.hpp"
 #include "caffeine/vision_layers.hpp"
-#include <algorithm>
+
 
 using std::max;
 
@@ -11,24 +16,29 @@ void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   NeuronLayer<Dtype>::SetUp(bottom, top);
   // Set up the cache for random number generation
-  rand_mat_.reset(new Blob<float>(bottom.num(), bottom.channels(),
-      bottom.height(), bottom.width());
-  filler_.reset(new UniformFiller<float>(FillerParameter()));
+  rand_vec_.reset(new SyncedMemory(bottom[0]->count() * sizeof(int)));
 };
 
 template <typename Dtype>
 void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
-  // First, create the random matrix
-  filler_->Fill(rand_mat_.get()); 
   const Dtype* bottom_data = bottom[0]->cpu_data();
-  const Dtype* rand_vals = rand_mat_->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  float threshold = layer_param_->dropout_ratio();
-  float scale = layer_param_->dropo
+  float threshold = this->layer_param_.dropout_ratio();
+  DCHECK(threshold > 0.);
+  DCHECK(threshold < 1.);
+  float scale = 1. / (1. - threshold);
   const int count = bottom[0]->count();
-  for (int i = 0; i < count; ++i) {
-    top_data[i] = rand_mat_ > ;
+  if (Caffeine::phase() == Caffeine::TRAIN) {
+    // Create Bernoulli(1 - threshold) masks using the MKL VSL stream
+    int* mask = (int*)(rand_vec_->mutable_cpu_data());
+    viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffeine::vsl_stream(),
+        count, mask, 1. - threshold);
+    for (int i = 0; i < count; ++i) {
+      top_data[i] = bottom_data[i] * mask[i] * scale;
+    }
+  } else {
+    memcpy(top_data, bottom_data, bottom[0]->count() * sizeof(Dtype));
   }
 }
 
@@ -36,23 +46,25 @@ template <typename Dtype>
 Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
+  CHECK(Caffeine::phase() == Caffeine::TRAIN);
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int* mask = (const int*)(rand_vec_->cpu_data());
+    float scale = 1. / (1. - this->layer_param_.dropout_ratio());
     const int count = (*bottom)[0]->count();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
+      bottom_diff[i] = top_diff[i] * mask[i] * scale;
     }
   }
   return Dtype(0);
 }
 
 template <typename Dtype>
-__global__ void DropoutForward(const int n, const Dtype* in, Dtype* out) {
+__global__ void DropoutForward(const int n, const Dtype* in,
+    const unsigned int* mask, const unsigned int threshold, const float scale,
+    Dtype* out) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out[index] = max(in[index], Dtype(0.));
+    out[index] = in[index] * (mask[index] > threshold) * scale;
   }
 }
 
@@ -61,19 +73,32 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  float threshold = this->layer_param_.dropout_ratio();
+  DCHECK(threshold > 0.);
+  DCHECK(threshold < 1.);
+  float scale = 1. / (1. - threshold);
   const int count = bottom[0]->count();
-  const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-      CAFFEINE_CUDA_NUM_THREADS;
-  DropoutForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
-      top_data);
+  if (Caffeine::phase() == Caffeine::TRAIN) {
+    // Create random numbers
+    CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
+        (unsigned int*)(rand_vec_->mutable_gpu_data()), count));
+    unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
+    // Zero out units whose draw falls at or below the threshold and rescale
+    // the survivors
+    DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, bottom_data, (const unsigned int*)(rand_vec_->gpu_data()),
+        uint_thres, scale, top_data);
+  } else {
+    CUDA_CHECK(cudaMemcpy(top_data, bottom_data, count * sizeof(Dtype),
+        cudaMemcpyDeviceToDevice));
+  }
 }
 
 template <typename Dtype>
 __global__ void DropoutBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const unsigned int* mask, const unsigned int threshold, const float scale,
+    Dtype* out_diff) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out_diff[index] = in_diff[index] * (in_data[index] >= 0);
+    out_diff[index] = in_diff[index] * scale * (mask[index] > threshold);
   }
 }
 
@@ -81,15 +106,15 @@ template <typename Dtype>
 Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
+  CHECK(Caffeine::phase() == Caffeine::TRAIN);
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const unsigned int* mask = (const unsigned int*)(rand_vec_->gpu_data());
     const int count = (*bottom)[0]->count();
+    float threshold = this->layer_param_.dropout_ratio();
+    unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
+    float scale = 1. / (1. - threshold);
-    const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-        CAFFEINE_CUDA_NUM_THREADS;
-    DropoutBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
-        bottom_data, bottom_diff);
+    DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, mask, uint_thres, scale, bottom_diff);
   }
   return Dtype(0);
 }
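The GPU path never converts its random draws to floats: curandGenerate yields integers uniform on [0, UINT_MAX], so mask[i] > uint_thres with uint_thres = UINT_MAX * threshold holds with probability about 1 - threshold, matching the CPU path's Bernoulli(1 - threshold) draw. A host-side restatement of the comparison (keep_unit is an illustrative name, not part of the commit):

    #include <climits>

    // r is assumed uniform on [0, UINT_MAX], e.g. one output of curandGenerate.
    inline bool keep_unit(unsigned int r, float dropout_ratio) {
      unsigned int uint_thres = (unsigned int)(UINT_MAX * dropout_ratio);
      return r > uint_thres;  // true with probability roughly 1 - dropout_ratio
    }

Scaling the surviving activations by 1 / (1 - threshold) at training time keeps the expected output unchanged, which is what lets the test-time branch reduce to a plain memcpy.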
index 04ba649..880e615 100644
--- a/src/caffeine/filler.hpp
+++ b/src/caffeine/filler.hpp
@@ -1,3 +1,7 @@
+// Fillers are random number generators that fill a blob using the specified
+// algorithm. The expectation is that they will only be used at initialization
+// time and will not involve any GPUs.
+
 #ifndef CAFFEINE_FILLER_HPP
 #define CAFFEINE_FILLER_HPP
 
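A usage sketch matching the filler calls removed from dropout_layer.cu above; the caffeine/blob.hpp header, the blob shape, and the assumption that a default-constructed FillerParameter carries usable bounds are all illustrative, not taken from this commit.

    #include "caffeine/blob.hpp"    // assumed location of Blob
    #include "caffeine/filler.hpp"

    using namespace caffeine;

    void init_uniform() {
      // (num, channels, height, width), mirroring the Blob constructor used
      // elsewhere in this commit.
      Blob<float> weights(1, 3, 4, 5);
      UniformFiller<float> filler(FillerParameter());
      filler.Fill(&weights);  // CPU-side only, per the comment above
    }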
index 158131a..fb95b04 100644
--- a/src/caffeine/relu_layer.cu
+++ b/src/caffeine/relu_layer.cu
@@ -47,10 +47,8 @@ void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
-  const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-      CAFFEINE_CUDA_NUM_THREADS;
-  ReLUForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
-      top_data);
+  ReLUForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data);
 }
 
 template <typename Dtype>
@@ -71,10 +69,8 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
-    const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-        CAFFEINE_CUDA_NUM_THREADS;
-    ReLUBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
-        bottom_data, bottom_diff);
+    ReLUBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff);
   }
   return Dtype(0);
 }
index 08561bc..f1cea34 100644
--- a/src/caffeine/vision_layers.hpp
+++ b/src/caffeine/vision_layers.hpp
@@ -48,9 +48,7 @@ class DropoutLayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
- private:
-  shared_ptr<Blob<float> > rand_mat_;
-  shared_ptr<UniformFiller<float> > filler_;
+ protected:
+  shared_ptr<SyncedMemory> rand_vec_;
 };
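Swapping the Blob<float> mask and its filler for raw SyncedMemory lets the CPU path view the buffer as int (what viRngBernoulli writes) and the GPU path as unsigned int (what curandGenerate writes) without a float round trip. A sketch of the access pattern the layer relies on, assuming SyncedMemory copies between host and device lazily on first access after a write:

    #include "caffeine/syncedmem.hpp"

    using caffeine::SyncedMemory;

    void mask_example(const int count) {
      SyncedMemory mem(count * sizeof(int));
      // Host-side write, as in Forward_cpu ...
      int* cpu_mask = (int*)(mem.mutable_cpu_data());
      for (int i = 0; i < count; ++i) {
        cpu_mask[i] = (i % 2);  // placeholder mask values
      }
      // ... then a device-side read, as in Backward_gpu; the assumed contract
      // is that gpu_data() syncs the host write to the device on demand.
      const unsigned int* gpu_mask = (const unsigned int*)(mem.gpu_data());
      (void)gpu_mask;  // handed to a kernel in the real layer
    }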