bugfix
authorYangqing Jia <jiayq84@gmail.com>
Fri, 25 Oct 2013 20:55:23 +0000 (13:55 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 25 Oct 2013 20:55:23 +0000 (13:55 -0700)
src/caffe/layers/bnll_layer.cu

index c9a33ed..fd261a3 100644 (file)
@@ -30,9 +30,10 @@ Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
+    Dtype expval;
     for (int i = 0; i < count; ++i) {
-      Dtype expval = exp(min(bottom_data[index], Dtype(kBNLL_THRESHOLD)));
-      bottom_diff[index] = top_diff[index] * expval / (expval + 1.);
+      expval = exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD)));
+      bottom_diff[i] = top_diff[i] * expval / (expval + 1.);
     }
   }
   return Dtype(0);
@@ -42,7 +43,7 @@ template <typename Dtype>
 __global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD)));
+    out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD))));
   }
 }