sigmoid layer backward pass optimization: don't recompute forward pass
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 12 Apr 2014 11:16:23 +0000 (04:16 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sat, 12 Apr 2014 11:16:23 +0000 (04:16 -0700)
src/caffe/layers/sigmoid_layer.cpp
src/caffe/layers/sigmoid_layer.cu

index 601ba47..88a7920 100644 (file)
@@ -31,12 +31,12 @@ void SigmoidLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_data = top[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
     for (int i = 0; i < count; ++i) {
-      Dtype sigmoid_x = sigmoid(bottom_data[i]);
+      const Dtype sigmoid_x = top_data[i];
       bottom_diff[i] = top_diff[i] * sigmoid_x * (1. - sigmoid_x);
     }
   }
index d10b4d9..aa8568a 100644 (file)
@@ -12,14 +12,9 @@ using std::max;
 namespace caffe {
 
 template <typename Dtype>
-__device__ inline Dtype sigmoid_gpu(Dtype x) {
-  return 1. / (1. + exp(-x));
-}
-
-template <typename Dtype>
 __global__ void SigmoidForward(const int n, const Dtype* in, Dtype* out) {
   CUDA_KERNEL_LOOP(index, n) {
-    out[index] = sigmoid_gpu(in[index]);
+    out[index] = 1. / (1. + exp(-in[index]));
   }
 }
 
@@ -43,9 +38,9 @@ Dtype SigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 __global__ void SigmoidBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const Dtype* out_data, Dtype* out_diff) {
   CUDA_KERNEL_LOOP(index, n) {
-    Dtype sigmoid_x = sigmoid_gpu(in_data[index]);
+    const Dtype sigmoid_x = out_data[index];
     out_diff[index] = in_diff[index] * sigmoid_x * (1 - sigmoid_x);
   }
 }
@@ -55,13 +50,13 @@ void SigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_data = top[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
     // NOLINT_NEXT_LINE(whitespace/operators)
     SigmoidBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_data, bottom_diff);
+        count, top_diff, top_data, bottom_diff);
     CUDA_POST_KERNEL_CHECK;
   }
 }