working update
author    Yangqing Jia <jiayq84@gmail.com>
          Tue, 17 Sep 2013 00:11:49 +0000 (17:11 -0700)
committer Yangqing Jia <jiayq84@gmail.com>
          Tue, 17 Sep 2013 00:11:49 +0000 (17:11 -0700)
src/Makefile
src/caffeine/dropout_layer.cu
src/caffeine/vision_layers.hpp

diff --git a/src/Makefile b/src/Makefile
index 111ca16..9fff4b9 100644
@@ -34,7 +34,7 @@ LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand
 WARNINGS := -Wall
 
-CXXFLAGS += -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
 LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
 
diff --git a/src/caffeine/dropout_layer.cu b/src/caffeine/dropout_layer.cu
index 29398dd..bfed41d 100644
@@ -17,6 +17,11 @@ void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   NeuronLayer<Dtype>::SetUp(bottom, top);
   // Set up the cache for random number generation
   rand_vec_.reset(new SyncedMemory(bottom[0]->count() * sizeof(int)));
+  threshold_ = this->layer_param_.dropout_ratio();
+  DCHECK(threshold_ > 0.);
+  DCHECK(threshold_ < 1.);
+  scale_ = 1. / (1. - threshold_);
+  uint_thres_ = (unsigned int)(UINT_MAX * threshold_);
 };
 
 template <typename Dtype>
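
The setup change above hoists the inverted-dropout constants out of the per-call forward path: with dropout ratio p, scale_ = 1/(1-p) makes the expected training-time activation match the test-phase identity, since E[mask * scale * x] = (1-p) * 1/(1-p) * x = x. A minimal standalone sketch of that expectation (plain C++, independent of the Caffeine classes; the RNG here is a stand-in for illustration, not the MKL/cuRAND streams the layer uses):

    #include <cstdio>
    #include <random>

    int main() {
      const double threshold = 0.5;                    // dropout_ratio p
      const double scale = 1. / (1. - threshold);      // matches scale_ above
      std::mt19937 gen(0);
      std::bernoulli_distribution keep(1. - threshold);  // P(keep) = 1 - p
      double sum = 0.;
      const int n = 1000000;
      for (int i = 0; i < n; ++i) {
        sum += (keep(gen) ? 1. : 0.) * scale;          // mask[i] * scale_
      }
      // Mean is ~1.0: inverted dropout leaves the expected activation
      // unchanged, which is why the test phase can be a plain memcpy.
      std::printf("mean(mask * scale) = %f\n", sum / n);
      return 0;
    }
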
@@ -24,18 +29,14 @@ void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  float threshold = this->layer_param_.dropout_ratio();
-  DCHECK(threshold > 0.);
-  DCHECK(threshold < 1.);
-  float scale = 1. / threshold;
+  int* mask = (int*)rand_vec_->mutable_cpu_data();
   const int count = bottom[0]->count();
   if (Caffeine::phase() == Caffeine::TRAIN) {
     // Create random numbers
     viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffeine::vsl_stream(),
-        count, (int*)(rand_vec_->mutable_cpu_data()),
-        1. - threshold);
+        count, mask, 1. - threshold_);
     for (int i = 0; i < count; ++i) {
-      top_data[i] = bottom_data[i] * rand_vec_[i] * scale;
+      top_data[i] = bottom_data[i] * mask[i] * scale_;
     }
   } else {
     memcpy(top_data, bottom_data, bottom[0]->count() * sizeof(Dtype));
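
For reference, the training branch above in portable form: a Bernoulli mask is drawn once per element, cached for the backward pass, and applied together with the precomputed scale. This sketch substitutes std::bernoulli_distribution for MKL's viRngBernoulli (an assumption for illustration; the real code uses the VSL ICDF method on the shared Caffeine::vsl_stream()):

    #include <cstddef>
    #include <random>
    #include <vector>

    // Hypothetical free-function version of the Forward_cpu training branch,
    // with the MKL call swapped for the standard library.
    void dropout_forward_cpu(const std::vector<float>& bottom,
                             std::vector<float>* top, std::vector<int>* mask,
                             float threshold, float scale, std::mt19937* gen) {
      std::bernoulli_distribution bern(1. - threshold);  // P(keep) = 1 - p
      for (std::size_t i = 0; i < bottom.size(); ++i) {
        (*mask)[i] = bern(*gen) ? 1 : 0;  // cached, reused in Backward_cpu
        (*top)[i] = bottom[i] * (*mask)[i] * scale;
      }
    }
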
@@ -53,7 +54,7 @@ Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const int* mask = (int*)(rand_vec_->cpu_data());
     const int count = (*bottom)[0]->count();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * mask[i];
+      bottom_diff[i] = top_diff[i] * mask[i] * scale_;
     }
   }
   return Dtype(0);
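
The one-line backward fix follows from the chain rule: the forward pass computes y_i = x_i * m_i * s, so dL/dx_i = dL/dy_i * m_i * s. Before this change the gradient omitted s and was underscaled by a factor of (1 - p). A matching sketch, in the same hypothetical free-function style as above:

    #include <cstddef>
    #include <vector>

    void dropout_backward_cpu(const std::vector<float>& top_diff,
                              const std::vector<int>& mask,
                              float scale, std::vector<float>* bottom_diff) {
      for (std::size_t i = 0; i < top_diff.size(); ++i) {
        // Same mask and scale as the forward pass: dL/dx = dL/dy * m * s.
        (*bottom_diff)[i] = top_diff[i] * mask[i] * scale;
      }
    }
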
@@ -61,10 +62,11 @@ Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
 
 template <typename Dtype>
 __global__ void DropoutForward(const int n, const Dtype* in,
-    const unsigned int* mask, const unsigned int threshold, Dtype* out) {
+    const unsigned int* mask, const unsigned int threshold, const float scale,
+    Dtype* out) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out[index] = in[index] * (mask[index] > threshold);
+    out[index] = in[index] * (mask[index] > threshold) * scale;
   }
 }
 
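
On the GPU the mask is not 0/1: curandGenerate fills rand_vec_ with uniform 32-bit integers, and the kernel keeps a unit when mask[index] > threshold, where the threshold is the uint_thres_ = (unsigned int)(UINT_MAX * threshold_) computed once in SetUp. The fraction of the 32-bit range above that cutoff is approximately 1 - p, so the keep probability matches the CPU path. A quick host-side check of the arithmetic:

    #include <climits>
    #include <cstdio>

    int main() {
      const float threshold = 0.3f;  // dropout_ratio p
      const unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
      // P(uniform uint > uint_thres) = (UINT_MAX - uint_thres) / UINT_MAX
      std::printf("P(keep) ~= %f (expect %f)\n",
                  (double)(UINT_MAX - uint_thres) / (double)UINT_MAX,
                  1. - threshold);
      return 0;
    }
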
@@ -73,32 +75,28 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
-  float threshold = this->layer_param_.dropout_ratio();
-  DCHECK(threshold > 0.);
-  DCHECK(threshold < 1.);
-  float scale = 1. / threshold;
   const int count = bottom[0]->count();
   if (Caffeine::phase() == Caffeine::TRAIN) {
     // Create random numbers
     CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
         (unsigned int*)(rand_vec_->mutable_gpu_data()), count));
-    unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
     // set thresholds
     DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
-        count, bottom_data, (unsigned int*)(rand_vec_->gpu_data(), uint_thres,
+        count, bottom_data, (unsigned int*)rand_vec_->gpu_data(), uint_thres_, scale_,
         top_data);
   } else {
     CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
-        count * sizeof(Dtype)));
+        count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
   }
 }
 
 template <typename Dtype>
 __global__ void DropoutBackward(const int n, const Dtype* in_diff,
-    const unsigned int* mask, const unsigned int threshold, Dtype* out_diff) {
+    const unsigned int* mask, const unsigned int threshold, const float scale,
+    Dtype* out_diff) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out_diff[index] = in_diff[index] * (mask[index] > threshold);
+    out_diff[index] = in_diff[index] * (mask[index] > threshold) * scale;
   }
 }
 
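
Two notes on the GPU hunk above. The test-phase cudaMemcpy previously lacked its cudaMemcpyKind argument and now passes cudaMemcpyDeviceToDevice, since both blobs live on the device. The launch-configuration macros are defined elsewhere in the tree; a plausible sketch of their shape (assumed definitions, and the 512-thread figure is a guess):

    // Assumed definitions for illustration only; the real macros live in a
    // Caffeine header.
    #define CAFFEINE_CUDA_NUM_THREADS 512
    #define CAFFEINE_GET_BLOCKS(n) \
        (((n) + CAFFEINE_CUDA_NUM_THREADS - 1) / CAFFEINE_CUDA_NUM_THREADS)
    // ceil(n / threads) blocks cover every index in [0, n); the trailing
    // `if (index < n)` guard in each kernel masks off the final partial block.
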
@@ -110,10 +108,10 @@ Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   if (propagate_down) {
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
-    const unsigned int* mask = (int*)(rand_vec_->gpu_data());
+    const unsigned int* mask = (unsigned int*)rand_vec_->gpu_data();
     const int count = (*bottom)[0]->count();
     DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
-        count, top_diff, (unsigned int*)(rand_vec_->gpu_data(), uint_thres,
+        count, top_diff, (unsigned int*)rand_vec_->gpu_data(), uint_thres_, scale_,
         bottom_diff);
   }
   return Dtype(0);
diff --git a/src/caffeine/vision_layers.hpp b/src/caffeine/vision_layers.hpp
index f1cea34..432da7e 100644
@@ -49,6 +49,9 @@ class DropoutLayer : public NeuronLayer<Dtype> {
   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   shared_ptr<SyncedMemory> rand_vec_;
+  float threshold_;
+  float scale_;
+  unsigned int uint_thres_;
 };