leaky relu + unit test
authorqipeng <pengrobertqi@163.com>
Sat, 19 Jul 2014 20:24:01 +0000 (13:24 -0700)
committerqipeng <pengrobertqi@163.com>
Mon, 21 Jul 2014 22:32:22 +0000 (15:32 -0700)
src/caffe/layers/relu_layer.cpp
src/caffe/layers/relu_layer.cu
src/caffe/proto/caffe.proto
src/caffe/test/test_neuron_layer.cpp

index 1e11955..57ccf97 100644 (file)
@@ -14,8 +14,10 @@ Dtype ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
   for (int i = 0; i < count; ++i) {
-    top_data[i] = std::max(bottom_data[i], Dtype(0));
+    top_data[i] = std::max(bottom_data[i], Dtype(0))
+        + negative_slope * std::min(bottom_data[i], Dtype(0));
   }
   return Dtype(0);
 }
@@ -29,8 +31,10 @@ void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] > 0);
+      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0)
+          + top_diff[i] * negative_slope * (bottom_data[i] < 0);
     }
   }
 }
index e8b0fbe..f5a24a9 100644 (file)
@@ -9,9 +9,10 @@
 namespace caffe {
 
 template <typename Dtype>
-__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
+__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out,
+    Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out[index] = in[index] > 0 ? in[index] : 0;
+    out[index] = in[index] > 0 ? in[index] : in[index] * negative_slope;
   }
 }
 
@@ -21,9 +22,10 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
   // NOLINT_NEXT_LINE(whitespace/operators)
   ReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-      count, bottom_data, top_data);
+      count, bottom_data, top_data, negative_slope);
   CUDA_POST_KERNEL_CHECK;
   // << " count: " << count << " bottom_data: "
   //     << (unsigned long)bottom_data
@@ -35,9 +37,10 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 __global__ void ReLUBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const Dtype* in_data, Dtype* out_diff, Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out_diff[index] = in_diff[index] * (in_data[index] > 0);
+    out_diff[index] = in_diff[index] * (in_data[index] >= 0)
+        + in_diff[index] * (in_data[index] < 0) * negative_slope;
   }
 }
 
@@ -50,9 +53,10 @@ void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     // NOLINT_NEXT_LINE(whitespace/operators)
     ReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_data, bottom_diff);
+        count, top_diff, bottom_data, bottom_diff, negative_slope);
     CUDA_POST_KERNEL_CHECK;
   }
 }
index 10e52d2..9c3ddfe 100644 (file)
@@ -117,7 +117,7 @@ message SolverState {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available ID: 28 (last added: accuracy_param)
+// LayerParameter next available ID: 31 (last added: relu_param)
 message LayerParameter {
   repeated string bottom = 2; // the name of the bottom blobs
   repeated string top = 3; // the name of the top blobs
@@ -207,6 +207,7 @@ message LayerParameter {
   optional MemoryDataParameter memory_data_param = 22;
   optional PoolingParameter pooling_param = 19;
   optional PowerParameter power_param = 21;
+  optional ReLUParameter relu_param = 30;
   optional WindowDataParameter window_data_param = 20;
   optional ThresholdParameter threshold_param = 25;
   optional HingeLossParameter hinge_loss_param = 29;
@@ -429,6 +430,16 @@ message PowerParameter {
   optional float shift = 3 [default = 0.0];
 }
 
+// Message that stores parameters used by ReLULayer
+message ReLUParameter {
+  // Allow non-zero slope for negative inputs to speed up optimization
+  // Described in:
+  // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
+  // improve neural network acoustic models. In ICML Workshop on Deep Learning
+  // for Audio, Speech, and Language Processing.
+  optional float negative_slope = 1 [default = 0];
+}
+
 // Message that stores parameters used by WindowDataLayer
 message WindowDataParameter {
   // Specify the data source.
index 246832d..db9b934 100644 (file)
@@ -62,6 +62,32 @@ TYPED_TEST(NeuronLayerTest, TestReLUGradient) {
       &(this->blob_top_vec_));
 }
 
+TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.ParseFromString("relu_param{negative_slope:0.01}");
+  ReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ReLULayer<Dtype> layer(layer_param);
+  layer_param.ParseFromString("relu_param{negative_slope:0.01}");
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
 TYPED_TEST(NeuronLayerTest, TestSigmoid) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;