Make TanH cleaner, more efficient, and possible to use in-place
authorJeff Donahue <jeff.donahue@gmail.com>
Fri, 6 Jun 2014 00:20:30 +0000 (17:20 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 6 Jun 2014 01:07:37 +0000 (18:07 -0700)
src/caffe/layers/tanh_layer.cpp
src/caffe/layers/tanh_layer.cu

index 46d11d0..66f530f 100644 (file)
@@ -18,8 +18,8 @@ Dtype TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   Dtype exp2x;
   const int count = bottom[0]->count();
   for (int i = 0; i < count; ++i) {
-    exp2x = exp(2*bottom_data[i]);
-    top_data[i] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+    exp2x = exp(2 * bottom_data[i]);
+    top_data[i] = (exp2x - Dtype(1)) / (exp2x + Dtype(1));
   }
   return Dtype(0);
 }
@@ -29,16 +29,15 @@ void TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_data = top[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
     Dtype exp2x;
     Dtype tanhx;
     for (int i = 0; i < count; ++i) {
-      exp2x = exp(2*bottom_data[i]);
-      tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
-      bottom_diff[i] = top_diff[i] * (1 - tanhx*tanhx);
+      tanhx = top_data[i];
+      bottom_diff[i] = top_diff[i] * (1 - tanhx * tanhx);
     }
   }
 }
index 13bb001..aa822d8 100644 (file)
@@ -13,8 +13,8 @@ namespace caffe {
 template <typename Dtype>
 __global__ void TanHForward(const int n, const Dtype* in, Dtype* out) {
   CUDA_KERNEL_LOOP(index, n) {
-    Dtype exp2x = exp(2*in[index]);
-    out[index] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+    Dtype exp2x = exp(2 * in[index]);
+    out[index] = (exp2x - Dtype(1)) / (exp2x + Dtype(1));
   }
 }
 
@@ -28,21 +28,15 @@ Dtype TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   TanHForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
       count, bottom_data, top_data);
   CUDA_POST_KERNEL_CHECK;
-  // << " count: " << count << " bottom_data: "
-  //     << (unsigned long)bottom_data
-  //     << " top_data: " << (unsigned long)top_data
-  //     << " blocks: " << CAFFE_GET_BLOCKS(count)
-  //     << " threads: " << CAFFE_CUDA_NUM_THREADS;
   return Dtype(0);
 }
 
 template <typename Dtype>
 __global__ void TanHBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const Dtype* out_data, Dtype* out_diff) {
   CUDA_KERNEL_LOOP(index, n) {
-    Dtype exp2x = exp(2*in_data[index]);
-    Dtype tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
-    out_diff[index] = in_diff[index] * (1 - tanhx*tanhx);
+    Dtype tanhx = out_data[index];
+    out_diff[index] = in_diff[index] * (1 - tanhx * tanhx);
   }
 }
 
@@ -51,13 +45,13 @@ void TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_data = top[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
     // NOLINT_NEXT_LINE(whitespace/operators)
     TanHBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_data, bottom_diff);
+        count, top_diff, top_data, bottom_diff);
     CUDA_POST_KERNEL_CHECK;
   }
 }