relu gradient: >=0 -> >0
author    Yangqing Jia <jiayq84@gmail.com>
          Thu, 10 Oct 2013 03:44:58 +0000 (20:44 -0700)
committer Yangqing Jia <jiayq84@gmail.com>
          Thu, 10 Oct 2013 03:44:58 +0000 (20:44 -0700)
src/caffe/layers/relu_layer.cu

index c386dd0..8613b3b 100644
@@ -29,7 +29,7 @@ Dtype ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
+      bottom_diff[i] = top_diff[i] * (bottom_data[i] > 0);
     }
   }
   return Dtype(0);
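
For context (not part of the commit): the two comparisons differ only when bottom_data[i] is exactly 0. With (bottom_data[i] >= 0) the incoming gradient leaks through at 0; with (bottom_data[i] > 0) it is zeroed, matching the usual subgradient convention ReLU'(0) = 0. A minimal standalone sketch of the difference, with illustrative names that are not Caffe code:

    #include <cstdio>

    // Old behavior: gradient passes through at exactly zero.
    float relu_backward_old(float top_diff, float bottom_data) {
      return top_diff * (bottom_data >= 0);
    }

    // New behavior: gradient is blocked at exactly zero.
    float relu_backward_new(float top_diff, float bottom_data) {
      return top_diff * (bottom_data > 0);
    }

    int main() {
      const float diff = 1.0f;
      // At x == 0 the two versions disagree; everywhere else they agree.
      printf("x = 0: old %.1f, new %.1f\n",
             relu_backward_old(diff, 0.0f), relu_backward_new(diff, 0.0f));
      printf("x = 2: old %.1f, new %.1f\n",
             relu_backward_old(diff, 2.0f), relu_backward_new(diff, 2.0f));
      return 0;
    }

This prints old 1.0 / new 0.0 at x = 0, and identical values for any positive input, so the change only affects inputs sitting exactly on the kink.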
@@ -63,7 +63,7 @@ __global__ void ReLUBackward(const int n, const Dtype* in_diff,
     const Dtype* in_data, Dtype* out_diff) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out_diff[index] = in_diff[index] * (in_data[index] >= 0);
+    out_diff[index] = in_diff[index] * (in_data[index] > 0);
   }
 }
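
For context (not part of the commit): ReLUBackward assigns one thread per element, and the guard if (index < n) allows the grid to be rounded up past n. A generic launch for such a kernel might look like the sketch below; the block size of 256 and the rounded-up grid are assumptions for illustration, not the exact launch code used by Caffe.

    // Hypothetical launch: round the grid up so every element gets a thread.
    const int threads = 256;
    const int blocks = (count + threads - 1) / threads;  // ceil(count / threads)
    ReLUBackward<Dtype><<<blocks, threads>>>(count, top_diff, bottom_data, bottom_diff);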