From ba3ea2b6b1733a330057aff7e94e64834b2df439 Mon Sep 17 00:00:00 2001
From: qipeng
Date: Sat, 19 Jul 2014 17:28:52 -0700
Subject: [PATCH] reduced multiplications & fixed unit test

---
 src/caffe/layers/relu_layer.cpp      | 4 ++--
 src/caffe/layers/relu_layer.cu       | 4 ++--
 src/caffe/test/test_neuron_layer.cpp | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/caffe/layers/relu_layer.cpp b/src/caffe/layers/relu_layer.cpp
index 57ccf97..3a3e8a2 100644
--- a/src/caffe/layers/relu_layer.cpp
+++ b/src/caffe/layers/relu_layer.cpp
@@ -33,8 +33,8 @@ void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const int count = (*bottom)[0]->count();
     Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0)
-          + top_diff[i] * negative_slope * (bottom_data[i] < 0);
+      bottom_diff[i] = top_diff[i] * ((bottom_data[i] >= 0)
+          + negative_slope * (bottom_data[i] < 0));
     }
   }
 }
diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu
index f5a24a9..503147c 100644
--- a/src/caffe/layers/relu_layer.cu
+++ b/src/caffe/layers/relu_layer.cu
@@ -39,8 +39,8 @@ template <typename Dtype>
 __global__ void ReLUBackward(const int n, const Dtype* in_diff,
     const Dtype* in_data, Dtype* out_diff, Dtype negative_slope) {
   CUDA_KERNEL_LOOP(index, n) {
-    out_diff[index] = in_diff[index] * (in_data[index] >= 0)
-        + in_diff[index] * (in_data[index] < 0) * negative_slope;
+    out_diff[index] = in_diff[index] * ((in_data[index] >= 0)
+        + (in_data[index] < 0) * negative_slope);
   }
 }
 
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index db9b934..7a7af31 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -81,8 +81,8 @@ TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) {
 TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  ReLULayer<Dtype> layer(layer_param);
   layer_param.ParseFromString("relu_param{negative_slope:0.01}");
+  ReLULayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
   checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
-- 
2.7.4
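
Note (not part of the original patch): the change factors top_diff[i] (respectively in_diff[index]) out of the two indicator terms, so the leaky-ReLU gradient is computed with one multiplication by the incoming gradient instead of two; the result is unchanged because exactly one of the two indicators is nonzero for any input. The test fix constructs the ReLULayer only after layer_param is parsed, so the negative_slope of 0.01 actually reaches the layer under test. Below is a minimal standalone C++ sketch (plain arrays, no Caffe types; the names mirror the patch but the values are made-up examples) that checks the unfactored and factored forms agree:

    // Minimal sketch, independent of Caffe: compares the unfactored and
    // factored forms of the leaky-ReLU backward pass on arbitrary values.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      const float negative_slope = 0.01f;
      const std::vector<float> bottom_data = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
      const std::vector<float> top_diff    = { 1.0f,  2.0f, 3.0f, 4.0f, 5.0f};

      for (size_t i = 0; i < bottom_data.size(); ++i) {
        // Old form: two multiplications by top_diff[i].
        const float unfactored = top_diff[i] * (bottom_data[i] >= 0)
            + top_diff[i] * negative_slope * (bottom_data[i] < 0);
        // Patched form: top_diff[i] factored out, one multiplication.
        const float factored = top_diff[i] * ((bottom_data[i] >= 0)
            + negative_slope * (bottom_data[i] < 0));
        // Exactly one indicator is nonzero, so the two forms agree exactly.
        assert(unfactored == factored);
        std::printf("x = %+.2f  dL/dx = %+.4f\n", bottom_data[i], factored);
      }
      return 0;
    }

The same factoring applies verbatim inside the CUDA kernel's CUDA_KERNEL_LOOP, which is what the relu_layer.cu hunk does.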