strategize cuDNN activations: ReLU, Sigmoid, TanH
author     Evan Shelhamer <shelhamer@imaginarynumber.net>
           Sat, 6 Sep 2014 05:43:58 +0000 (22:43 -0700)
committer  Evan Shelhamer <shelhamer@imaginarynumber.net>
           Sun, 7 Sep 2014 17:25:23 +0000 (19:25 +0200)
include/caffe/neuron_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/cudnn_relu_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_relu_layer.cu [new file with mode: 0644]
src/caffe/layers/cudnn_sigmoid_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_sigmoid_layer.cu [new file with mode: 0644]
src/caffe/layers/cudnn_tanh_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_tanh_layer.cu [new file with mode: 0644]
src/caffe/test/test_neuron_layer.cpp

diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp
index 8c882ee..36acf96 100644
@@ -356,6 +356,31 @@ class ReLULayer : public NeuronLayer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+#ifdef USE_CUDNN
+/**
+ * @brief CuDNN acceleration of ReLULayer.
+ */
+template <typename Dtype>
+class CuDNNReLULayer : public ReLULayer<Dtype> {
+ public:
+  explicit CuDNNReLULayer(const LayerParameter& param)
+      : ReLULayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual ~CuDNNReLULayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  cudnnHandle_t             handle_;
+  cudnnTensor4dDescriptor_t bottom_desc_;
+  cudnnTensor4dDescriptor_t top_desc_;
+};
+#endif
+
 /**
  * @brief Sigmoid function non-linearity @f$
  *         y = (1 + \exp(-x))^{-1}
@@ -413,6 +438,31 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+#ifdef USE_CUDNN
+/**
+ * @brief CuDNN acceleration of SigmoidLayer.
+ */
+template <typename Dtype>
+class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
+ public:
+  explicit CuDNNSigmoidLayer(const LayerParameter& param)
+      : SigmoidLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual ~CuDNNSigmoidLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  cudnnHandle_t             handle_;
+  cudnnTensor4dDescriptor_t bottom_desc_;
+  cudnnTensor4dDescriptor_t top_desc_;
+};
+#endif
+
 /**
  * @brief TanH hyperbolic tangent non-linearity @f$
  *         y = \frac{\exp(2x) - 1}{\exp(2x) + 1}
@@ -472,6 +522,31 @@ class TanHLayer : public NeuronLayer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+#ifdef USE_CUDNN
+/**
+ * @brief CuDNN acceleration of TanHLayer.
+ */
+template <typename Dtype>
+class CuDNNTanHLayer : public TanHLayer<Dtype> {
+ public:
+  explicit CuDNNTanHLayer(const LayerParameter& param)
+      : TanHLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual ~CuDNNTanHLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  cudnnHandle_t             handle_;
+  cudnnTensor4dDescriptor_t bottom_desc_;
+  cudnnTensor4dDescriptor_t top_desc_;
+};
+#endif
+
 /**
  * @brief Tests whether the input exceeds a threshold: outputs 1 for inputs
  *        above threshold; 0 otherwise.
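Note: each CuDNN class subclasses its plain Caffe counterpart and overrides
only LayerSetUp and the *_gpu methods, so CPU-mode execution falls through to
the inherited kernels. A simplified sketch of the dispatch in
Layer<Dtype>::Forward, assumed from caffe/layer.hpp and not part of this diff:

    // Simplified sketch (assumption): Caffe's Layer::Forward switches on the
    // global mode, so a CuDNNReLULayer still runs the inherited ReLULayer
    // CPU path when Caffe::mode() == Caffe::CPU.
    switch (Caffe::mode()) {
    case Caffe::CPU:
      Forward_cpu(bottom, top);
      break;
    case Caffe::GPU:
      Forward_gpu(bottom, top);
      break;
    default:
      LOG(FATAL) << "Unknown caffe mode.";
    }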
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 516fe87..c519485 100644
@@ -76,6 +76,10 @@ ReLULayer<Dtype>* GetReLULayer(const string& name,
   }
   if (engine == ReLUParameter_Engine_CAFFE) {
     return new ReLULayer<Dtype>(param);
+#ifdef USE_CUDNN
+  } else if (engine == ReLUParameter_Engine_CUDNN) {
+    return new CuDNNReLULayer<Dtype>(param);
+#endif
   } else {
     LOG(FATAL) << "Layer " << name << " has unknown engine.";
   }
@@ -99,6 +103,10 @@ SigmoidLayer<Dtype>* GetSigmoidLayer(const string& name,
   }
   if (engine == SigmoidParameter_Engine_CAFFE) {
     return new SigmoidLayer<Dtype>(param);
+#ifdef USE_CUDNN
+  } else if (engine == SigmoidParameter_Engine_CUDNN) {
+    return new CuDNNSigmoidLayer<Dtype>(param);
+#endif
   } else {
     LOG(FATAL) << "Layer " << name << " has unknown engine.";
   }
@@ -122,6 +130,10 @@ TanHLayer<Dtype>* GetTanHLayer(const string& name,
   }
   if (engine == TanHParameter_Engine_CAFFE) {
     return new TanHLayer<Dtype>(param);
+#ifdef USE_CUDNN
+  } else if (engine == TanHParameter_Engine_CUDNN) {
+    return new CuDNNTanHLayer<Dtype>(param);
+#endif
   } else {
     LOG(FATAL) << "Layer " << name << " has unknown engine.";
   }
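Note: a hedged usage sketch of the dispatch above. The set_engine setter and
the ReLUParameter::CUDNN alias are assumed from the protobuf code generated
for the new Engine field, and GetReLULayer's second argument is taken to be
the LayerParameter, per the hunk header; when the engine is left at DEFAULT,
the factory picks one itself in the context elided above.

    // Hypothetical caller: request the cuDNN engine for a ReLU layer.
    LayerParameter param;
    param.set_name("relu1");
    param.mutable_relu_param()->set_engine(ReLUParameter::CUDNN);
    ReLULayer<float>* relu = GetReLULayer<float>(param.name(), param);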
diff --git a/src/caffe/layers/cudnn_relu_layer.cpp b/src/caffe/layers/cudnn_relu_layer.cpp
new file mode 100644
index 0000000..f8bf77f
--- /dev/null
@@ -0,0 +1,34 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  ReLULayer<Dtype>::LayerSetUp(bottom, top);
+  // initialize cuDNN
+  CUDNN_CHECK(cudnnCreate(&handle_));
+  const int N = bottom[0]->num();
+  const int K = bottom[0]->channels();
+  const int H = bottom[0]->height();
+  const int W = bottom[0]->width();
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
+}
+
+template <typename Dtype>
+CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
+  cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
+  cudnnDestroyTensor4dDescriptor(this->top_desc_);
+  cudnnDestroy(this->handle_);
+}
+
+INSTANTIATE_CLASS(CuDNNReLULayer);
+
+}  // namespace caffe
+#endif
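Note: LayerSetUp leans on two helpers that this diff does not define. Sketches
of plausible definitions from caffe/util/cudnn.hpp, based on the cuDNN R1 API;
both are assumptions, as is the dataType<Dtype>::type traits helper mapping
float/double to CUDNN_DATA_FLOAT/CUDNN_DATA_DOUBLE:

    // Assumed: wrap a cuDNN call and fail loudly on any non-success status.
    #define CUDNN_CHECK(condition) \
      do { \
        cudnnStatus_t status = condition; \
        CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " " \
          << cudnnGetErrorString(status); \
      } while (0)

    namespace cudnn {
    // Assumed: create and fill an NCHW tensor descriptor in one step.
    template <typename Dtype>
    inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
        int n, int c, int h, int w) {
      CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
      CUDNN_CHECK(cudnnSetTensor4dDescriptor(*desc, CUDNN_TENSOR_NCHW,
          dataType<Dtype>::type, n, c, h, w));
    }
    }  // namespace cudnn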
diff --git a/src/caffe/layers/cudnn_relu_layer.cu b/src/caffe/layers/cudnn_relu_layer.cu
new file mode 100644
index 0000000..8c8ca58
--- /dev/null
@@ -0,0 +1,55 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  // Fall back to standard Caffe for leaky ReLU.
+  if (ReLULayer<Dtype>::layer_param_.relu_param().negative_slope() != 0) {
+    return ReLULayer<Dtype>::Forward_gpu(bottom, top);
+  }
+
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  cudnnStatus_t stat = cudnnActivationForward(this->handle_,
+      CUDNN_ACTIVATION_RELU,
+      this->bottom_desc_, bottom_data, this->top_desc_, top_data);
+  CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+      << "Error in cudnnActivationForward.";
+}
+
+template <typename Dtype>
+void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down[0]) {
+    return;
+  }
+
+  // Fall back to standard Caffe for leaky ReLU.
+  if (ReLULayer<Dtype>::layer_param_.relu_param().negative_slope() != 0) {
+    return ReLULayer<Dtype>::Backward_gpu(top, propagate_down, bottom);
+  }
+
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  cudnnStatus_t stat = cudnnActivationBackward(this->handle_,
+      CUDNN_ACTIVATION_RELU,
+      this->top_desc_, top_data, this->top_desc_, top_diff,
+      this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff);
+  CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+      << "Error in cudnnActivationBackward.";
+}
+
+INSTANTIATE_CLASS(CuDNNReLULayer);
+
+}  // namespace caffe
+#endif
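Note: the fallback guard exists because Caffe's ReLU is the leaky form while
CUDNN_ACTIVATION_RELU implements only the zero-slope case. With \nu the
negative_slope:

    y = \max(x, 0) + \nu \min(x, 0), \qquad
    \frac{\partial y}{\partial x} =
    \begin{cases} 1 & x > 0 \\ \nu & x \le 0 \end{cases}

cuDNN covers only \nu = 0, so any nonzero slope routes both passes back to the
standard ReLULayer GPU kernels.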
diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cpp b/src/caffe/layers/cudnn_sigmoid_layer.cpp
new file mode 100644
index 0000000..488c754
--- /dev/null
@@ -0,0 +1,34 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  SigmoidLayer<Dtype>::LayerSetUp(bottom, top);
+  // initialize cuDNN
+  CUDNN_CHECK(cudnnCreate(&handle_));
+  const int N = bottom[0]->num();
+  const int K = bottom[0]->channels();
+  const int H = bottom[0]->height();
+  const int W = bottom[0]->width();
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
+}
+
+template <typename Dtype>
+CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {
+  cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
+  cudnnDestroyTensor4dDescriptor(this->top_desc_);
+  cudnnDestroy(this->handle_);
+}
+
+INSTANTIATE_CLASS(CuDNNSigmoidLayer);
+
+}  // namespace caffe
+#endif
diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cu b/src/caffe/layers/cudnn_sigmoid_layer.cu
new file mode 100644
index 0000000..c548a02
--- /dev/null
@@ -0,0 +1,45 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  cudnnStatus_t stat = cudnnActivationForward(this->handle_,
+      CUDNN_ACTIVATION_SIGMOID,
+      this->bottom_desc_, bottom_data, this->top_desc_, top_data);
+  CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+      << "Error in cudnnActivationForward.";
+}
+
+template <typename Dtype>
+void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down[0]) {
+    return;
+  }
+
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  cudnnStatus_t stat = cudnnActivationBackward(this->handle_,
+      CUDNN_ACTIVATION_SIGMOID,
+      this->top_desc_, top_data, this->top_desc_, top_diff,
+      this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff);
+  CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+      << "Error in cudnnActivationBackward.";
+}
+
+INSTANTIATE_CLASS(CuDNNSigmoidLayer);
+
+}  // namespace caffe
+#endif
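Note: the backward call is handed top_data as well as top_diff because the
sigmoid derivative is recoverable from the output alone:

    \sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
    \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)

so bottom_diff = top_diff * y * (1 - y), with no need to re-evaluate the
exponential from bottom_data.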
diff --git a/src/caffe/layers/cudnn_tanh_layer.cpp b/src/caffe/layers/cudnn_tanh_layer.cpp
new file mode 100644
index 0000000..32b6611
--- /dev/null
@@ -0,0 +1,34 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  TanHLayer<Dtype>::LayerSetUp(bottom, top);
+  // initialize cuDNN
+  CUDNN_CHECK(cudnnCreate(&handle_));
+  const int N = bottom[0]->num();
+  const int K = bottom[0]->channels();
+  const int H = bottom[0]->height();
+  const int W = bottom[0]->width();
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
+}
+
+template <typename Dtype>
+CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {
+  cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
+  cudnnDestroyTensor4dDescriptor(this->top_desc_);
+  cudnnDestroy(this->handle_);
+}
+
+INSTANTIATE_CLASS(CuDNNTanHLayer);
+
+}  // namespace caffe
+#endif
diff --git a/src/caffe/layers/cudnn_tanh_layer.cu b/src/caffe/layers/cudnn_tanh_layer.cu
new file mode 100644
index 0000000..090b38b
--- /dev/null
@@ -0,0 +1,45 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  cudnnStatus_t stat = cudnnActivationForward(this->handle_,
+      CUDNN_ACTIVATION_TANH,
+      this->bottom_desc_, bottom_data, this->top_desc_, top_data);
+  CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+      << "Error in cudnnActivationForward.";
+}
+
+template <typename Dtype>
+void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down[0]) {
+    return;
+  }
+
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  cudnnStatus_t stat = cudnnActivationBackward(this->handle_,
+      CUDNN_ACTIVATION_TANH,
+      this->top_desc_, top_data, this->top_desc_, top_diff,
+      this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff);
+  CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+      << "Error in cudnnActivationBackward.";
+}
+
+INSTANTIATE_CLASS(CuDNNTanHLayer);
+
+}  // namespace caffe
+#endif
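Note: as with sigmoid, the tanh derivative depends only on the output, which
is why top_data is passed to the backward call:

    \tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1}, \qquad
    \frac{d}{dx}\tanh(x) = 1 - \tanh^{2}(x)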
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 2333c3a..4c19d3f 100644
@@ -272,5 +272,142 @@ TYPED_TEST(NeuronLayerTest, TestBNLLGradient) {
       &(this->blob_top_vec_));
 }
 
+#ifdef USE_CUDNN
+template <typename Dtype>
+class CuDNNNeuronLayerTest : public ::testing::Test {
+ protected:
+  CuDNNNeuronLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    Caffe::set_random_seed(1701);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~CuDNNNeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(CuDNNNeuronLayerTest, TestDtypes);
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestReLUCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CuDNNReLULayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
+  }
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestReLUGradientCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CuDNNReLULayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestReLUWithNegativeSlopeCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "relu_param { negative_slope: 0.01 }", &layer_param));
+  CuDNNReLULayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Check values: y = x for positive inputs, y = 0.01 * x otherwise.
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] >= 0) {
+      EXPECT_FLOAT_EQ(top_data[i], bottom_data[i]);
+    } else {
+      EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] * 0.01);
+    }
+  }
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestReLUGradientWithNegativeSlopeCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "relu_param { negative_slope: 0.01 }", &layer_param));
+  CuDNNReLULayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestSigmoidCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CuDNNSigmoidLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], 1. / (1 + exp(-bottom_data[i])));
+    // check that we squashed the value between 0 and 1
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_LE(top_data[i], 1.);
+  }
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestSigmoidGradientCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CuDNNSigmoidLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestTanHCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CuDNNTanHLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Test exact values
+  for (int i = 0; i < this->blob_bottom_->num(); ++i) {
+    for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+      for (int k = 0; k < this->blob_bottom_->height(); ++k) {
+        for (int l = 0; l < this->blob_bottom_->width(); ++l) {
+          EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
+          EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(CuDNNNeuronLayerTest, TestTanHGradientCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  CuDNNTanHLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+#endif
 
 }  // namespace caffe
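Note on the gradient checks above, reading the arguments off Caffe's
GradientChecker constructor (stepsize, threshold, seed, kink, kink_range):

    // GradientChecker<TypeParam> checker(1e-2, 1e-3, 1701, 0., 0.01);
    //   1e-2  finite-difference step size
    //   1e-3  relative error threshold
    //   1701  random seed
    //   0.    kink location: ReLU is non-differentiable at x = 0
    //   0.01  kink_range: numeric checks are skipped this close to the kink
    // The TanH checker omits the kink arguments because tanh is smooth.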