From 0193012e5cccb3e1a8d74f84f988bce8966640f7 Mon Sep 17 00:00:00 2001
From: qipeng
Date: Sat, 19 Jul 2014 13:24:01 -0700
Subject: [PATCH] leaky relu + unit test

---
 src/caffe/layers/relu_layer.cpp      |  8 ++++++--
 src/caffe/layers/relu_layer.cu       | 16 ++++++++++------
 src/caffe/proto/caffe.proto          | 13 ++++++++++++-
 src/caffe/test/test_neuron_layer.cpp | 26 ++++++++++++++++++++++++++
 4 files changed, 54 insertions(+), 9 deletions(-)

diff --git a/src/caffe/layers/relu_layer.cpp b/src/caffe/layers/relu_layer.cpp
index 1e11955..57ccf97 100644
--- a/src/caffe/layers/relu_layer.cpp
+++ b/src/caffe/layers/relu_layer.cpp
@@ -14,8 +14,10 @@ Dtype ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
   for (int i = 0; i < count; ++i) {
-    top_data[i] = std::max(bottom_data[i], Dtype(0));
+    top_data[i] = std::max(bottom_data[i], Dtype(0))
+        + negative_slope * std::min(bottom_data[i], Dtype(0));
   }
   return Dtype(0);
 }
@@ -29,8 +31,10 @@ void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] > 0);
+      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0)
+          + top_diff[i] * negative_slope * (bottom_data[i] < 0);
     }
   }
 }
diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu
index e8b0fbe..f5a24a9 100644
--- a/src/caffe/layers/relu_layer.cu
+++ b/src/caffe/layers/relu_layer.cu
@@ -9,9 +9,10 @@
 namespace caffe {
 
 template <typename Dtype>
-__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
+__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out,
+    Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out[index] = in[index] > 0 ? in[index] : 0;
+    out[index] = in[index] > 0 ? in[index] : in[index] * negative_slope;
   }
 }
 
@@ -21,9 +22,10 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
   // NOLINT_NEXT_LINE(whitespace/operators)
   ReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-      count, bottom_data, top_data);
+      count, bottom_data, top_data, negative_slope);
   CUDA_POST_KERNEL_CHECK;
   // << " count: " << count << " bottom_data: "
   //     << (unsigned long)bottom_data
@@ -35,9 +37,10 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 __global__ void ReLUBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const Dtype* in_data, Dtype* out_diff, Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out_diff[index] = in_diff[index] * (in_data[index] > 0);
+    out_diff[index] = in_diff[index] * (in_data[index] >= 0)
+        + in_diff[index] * (in_data[index] < 0) * negative_slope;
   }
 }
 
@@ -50,9 +53,10 @@ void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     // NOLINT_NEXT_LINE(whitespace/operators)
     ReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_data, bottom_diff);
+        count, top_diff, bottom_data, bottom_diff, negative_slope);
     CUDA_POST_KERNEL_CHECK;
   }
 }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 10e52d2..9c3ddfe 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -117,7 +117,7 @@ message SolverState {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available ID: 28 (last added: accuracy_param)
+// LayerParameter next available ID: 31 (last added: relu_param)
 message LayerParameter {
   repeated string bottom = 2; // the name of the bottom blobs
   repeated string top = 3; // the name of the top blobs
@@ -207,6 +207,7 @@ message LayerParameter {
   optional MemoryDataParameter memory_data_param = 22;
   optional PoolingParameter pooling_param = 19;
   optional PowerParameter power_param = 21;
+  optional ReLUParameter relu_param = 30;
   optional WindowDataParameter window_data_param = 20;
   optional ThresholdParameter threshold_param = 25;
   optional HingeLossParameter hinge_loss_param = 29;
@@ -429,6 +430,16 @@ message PowerParameter {
   optional float shift = 3 [default = 0.0];
 }
 
+// Message that stores parameters used by ReLULayer
+message ReLUParameter {
+  // Allow non-zero slope for negative inputs to speed up optimization
+  // Described in:
+  // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
+  // improve neural network acoustic models. In ICML Workshop on Deep Learning
+  // for Audio, Speech, and Language Processing.
+  optional float negative_slope = 1 [default = 0];
+}
+
 // Message that stores parameters used by WindowDataLayer
 message WindowDataParameter {
   // Specify the data source.
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 246832d..db9b934 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -62,6 +62,32 @@ TYPED_TEST(NeuronLayerTest, TestReLUGradient) {
       &(this->blob_top_vec_));
 }
 
+TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_relu_param()->set_negative_slope(0.01);
+  ReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values: negatives are scaled by the slope, not clipped
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] > 0 ?
+        bottom_data[i] : bottom_data[i] * Dtype(0.01));
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_relu_param()->set_negative_slope(0.01);
+  ReLULayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
 TYPED_TEST(NeuronLayerTest, TestSigmoid) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-- 
2.7.4
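
Usage note (not part of the patch): with relu_param set, ReLULayer computes
top = max(bottom, 0) + negative_slope * min(bottom, 0) and applies the same
slope to the gradient for negative inputs, so negative_slope = 0 reproduces
the standard ReLU. A minimal sketch of how a net definition might enable the
leaky slope after this change; the layer and blob names here are made-up
examples, and the layers/type-enum form assumes the LayerParameter schema of
this vintage of caffe.proto:

    layers {
      name: "relu1"
      type: RELU
      bottom: "conv1"
      top: "conv1"
      relu_param {
        negative_slope: 0.01
      }
    }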