strategize relu, sigmoid, tanh
author Evan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 2 Sep 2014 04:01:11 +0000 (21:01 -0700)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 7 Sep 2014 01:27:07 +0000 (03:27 +0200)
include/caffe/loss_layers.hpp
include/caffe/neuron_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/caffe_relu_layer.cpp [moved from src/caffe/layers/relu_layer.cpp with 84% similarity]
src/caffe/layers/caffe_relu_layer.cu [moved from src/caffe/layers/relu_layer.cu with 91% similarity]
src/caffe/layers/caffe_sigmoid_layer.cpp [moved from src/caffe/layers/sigmoid_layer.cpp with 82% similarity]
src/caffe/layers/caffe_sigmoid_layer.cu [moved from src/caffe/layers/sigmoid_layer.cu with 90% similarity]
src/caffe/layers/caffe_tanh_layer.cpp [moved from src/caffe/layers/tanh_layer.cpp with 77% similarity]
src/caffe/layers/caffe_tanh_layer.cu [moved from src/caffe/layers/tanh_layer.cu with 85% similarity]
src/caffe/test/test_neuron_layer.cpp
src/caffe/test/test_tanh_layer.cpp [deleted file]

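The commit makes ReLULayer, SigmoidLayer and TanHLayer abstract interfaces (their Forward_*/Backward_* methods become pure virtual), moves the stock implementations into CaffeReLULayer, CaffeSigmoidLayer and CaffeTanHLayer, and teaches layer_factory.cpp to pick an implementation from the layer's engine setting. A plausible prototxt fragment selecting the Caffe engine explicitly; the engine field and CAFFE value follow the ReLUParameter_Engine_CAFFE symbol used in the factory below, while the surrounding layer fields and blob names are illustrative:

layers {
  name: "relu1"
  type: RELU
  bottom: "ip1"
  top: "ip1"
  relu_param { engine: CAFFE }
}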
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index a29c445..555c06e 100644 (file)
@@ -506,7 +506,7 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
  public:
   explicit SigmoidCrossEntropyLossLayer(const LayerParameter& param)
       : LossLayer<Dtype>(param),
-          sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+          sigmoid_layer_(new CaffeSigmoidLayer<Dtype>(param)),
           sigmoid_output_(new Blob<Dtype>()) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
@@ -558,7 +558,7 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   /// The internal SigmoidLayer used to map predictions to probabilities.
-  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+  shared_ptr<CaffeSigmoidLayer<Dtype> > sigmoid_layer_;
   /// sigmoid_output stores the output of the SigmoidLayer.
   shared_ptr<Blob<Dtype> > sigmoid_output_;
   /// bottom vector holder to call the underlying SigmoidLayer::Forward
diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp
index 8c882ee..f1d8f51 100644 (file)
@@ -318,9 +318,9 @@ class ReLULayer : public NeuronLayer<Dtype> {
    *      the computed outputs are @f$ y = \max(0, x) + \nu \min(0, x) @f$.
    */
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
+      vector<Blob<Dtype>*>* top) = 0;
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
+      vector<Blob<Dtype>*>* top) = 0;
 
   /**
    * @brief Computes the error gradient w.r.t. the ReLU inputs.
@@ -351,6 +351,31 @@ class ReLULayer : public NeuronLayer<Dtype> {
    *      @f$.
    */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+};
+
+/**
+ * @brief standard Caffe implementation of ReLULayer.
+ */
+template <typename Dtype>
+class CaffeReLULayer : public ReLULayer<Dtype> {
+ public:
+  explicit CaffeReLULayer(const LayerParameter& param)
+      : ReLULayer<Dtype>(param) {}
+
+  virtual inline LayerParameter_LayerType type() const {
+    return LayerParameter_LayerType_RELU;
+  }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
@@ -386,9 +411,9 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
    *      @f$
    */
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
+      vector<Blob<Dtype>*>* top) = 0;
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
+      vector<Blob<Dtype>*>* top) = 0;
 
   /**
    * @brief Computes the error gradient w.r.t. the sigmoid inputs.
@@ -408,6 +433,31 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
    *      @f$ if propagate_down[0]
    */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+};
+
+/**
+ * @brief standard Caffe implementation of SigmoidLayer.
+ */
+template <typename Dtype>
+class CaffeSigmoidLayer : public SigmoidLayer<Dtype> {
+ public:
+  explicit CaffeSigmoidLayer(const LayerParameter& param)
+      : SigmoidLayer<Dtype>(param) {}
+
+  virtual inline LayerParameter_LayerType type() const {
+    return LayerParameter_LayerType_SIGMOID;
+  }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
@@ -443,9 +493,9 @@ class TanHLayer : public NeuronLayer<Dtype> {
    *      @f$
    */
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
+      vector<Blob<Dtype>*>* top) = 0;
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
+      vector<Blob<Dtype>*>* top) = 0;
 
   /**
    * @brief Computes the error gradient w.r.t. the sigmoid inputs.
@@ -467,6 +517,31 @@ class TanHLayer : public NeuronLayer<Dtype> {
    *      @f$ if propagate_down[0]
    */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) = 0;
+};
+
+/**
+ * @brief standard Caffe implementation of TanHLayer.
+ */
+template <typename Dtype>
+class CaffeTanHLayer : public TanHLayer<Dtype> {
+ public:
+  explicit CaffeTanHLayer(const LayerParameter& param)
+      : TanHLayer<Dtype>(param) {}
+
+  virtual inline LayerParameter_LayerType type() const {
+    return LayerParameter_LayerType_TANH;
+  }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
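Taken together, the header now splits each activation into an interface and an engine: the base class keeps the documentation and the pure-virtual Forward_*/Backward_* declarations, while the Caffe* subclass still reports the original type() and supplies the concrete overrides. A minimal, self-contained sketch of the same shape with simplified signatures; LayerParameter and Blob here are stand-ins rather than Caffe's types:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct LayerParameter {};          // stand-in for caffe::LayerParameter
typedef std::vector<float> Blob;   // stand-in for Blob<Dtype>

// Interface: Forward/Backward are pure virtual, as ReLULayer becomes above.
class ReLULayer {
 public:
  explicit ReLULayer(const LayerParameter& param) : param_(param) {}
  virtual ~ReLULayer() {}
  virtual void Forward_cpu(const Blob& bottom, Blob* top) = 0;
  virtual void Backward_cpu(const Blob& top_diff, const Blob& bottom,
                            Blob* bottom_diff) = 0;
 protected:
  LayerParameter param_;
};

// Engine-specific implementation, as CaffeReLULayer above.
class CaffeReLULayer : public ReLULayer {
 public:
  explicit CaffeReLULayer(const LayerParameter& param) : ReLULayer(param) {}
  virtual void Forward_cpu(const Blob& bottom, Blob* top) {
    top->resize(bottom.size());
    for (std::size_t i = 0; i < bottom.size(); ++i)
      (*top)[i] = std::max(bottom[i], 0.f);
  }
  virtual void Backward_cpu(const Blob& top_diff, const Blob& bottom,
                            Blob* bottom_diff) {
    bottom_diff->resize(bottom.size());
    for (std::size_t i = 0; i < bottom.size(); ++i)
      (*bottom_diff)[i] = top_diff[i] * (bottom[i] > 0);
  }
};

int main() {
  LayerParameter param;
  CaffeReLULayer layer(param);
  Blob bottom, top;
  bottom.push_back(-1.f);
  bottom.push_back(2.f);
  layer.Forward_cpu(bottom, &top);
  std::printf("%g %g\n", top[0], top[1]);  // prints: 0 2
  return 0;
}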
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 2ffb7f8..3a535f5 100644 (file)
@@ -43,6 +43,57 @@ template PoolingLayer<float>* GetPoolingLayer(const string& name,
 template PoolingLayer<double>* GetPoolingLayer(const string& name,
     const LayerParameter& param);
 
+// Get relu layer according to engine.
+template <typename Dtype>
+ReLULayer<Dtype>* GetReLULayer(const string& name,
+    const LayerParameter& param) {
+  ReLUParameter_Engine engine = param.relu_param().engine();
+  if (engine == ReLUParameter_Engine_CAFFE) {
+    return new CaffeReLULayer<Dtype>(param);
+  } else {
+    LOG(FATAL) << "Layer " << name << " has unknown engine.";
+  }
+}
+
+template ReLULayer<float>* GetReLULayer(const string& name,
+    const LayerParameter& param);
+template ReLULayer<double>* GetReLULayer(const string& name,
+    const LayerParameter& param);
+
+// Get sigmoid layer according to engine.
+template <typename Dtype>
+SigmoidLayer<Dtype>* GetSigmoidLayer(const string& name,
+    const LayerParameter& param) {
+  SigmoidParameter_Engine engine = param.sigmoid_param().engine();
+  if (engine == SigmoidParameter_Engine_CAFFE) {
+    return new CaffeSigmoidLayer<Dtype>(param);
+  } else {
+    LOG(FATAL) << "Layer " << name << " has unknown engine.";
+  }
+}
+
+template SigmoidLayer<float>* GetSigmoidLayer(const string& name,
+    const LayerParameter& param);
+template SigmoidLayer<double>* GetSigmoidLayer(const string& name,
+    const LayerParameter& param);
+
+// Get tanh layer according to engine.
+template <typename Dtype>
+TanHLayer<Dtype>* GetTanHLayer(const string& name,
+    const LayerParameter& param) {
+  TanHParameter_Engine engine = param.tanh_param().engine();
+  if (engine == TanHParameter_Engine_CAFFE) {
+    return new CaffeTanHLayer<Dtype>(param);
+  } else {
+    LOG(FATAL) << "Layer " << name << " has unknown engine.";
+  }
+}
+
+template TanHLayer<float>* GetTanHLayer(const string& name,
+    const LayerParameter& param);
+template TanHLayer<double>* GetTanHLayer(const string& name,
+    const LayerParameter& param);
+
 // A function to get a specific layer from the specification given in
 // LayerParameter. Ideally this would be replaced by a factory pattern,
 // but we will leave it this way for now.
@@ -102,11 +153,11 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   case LayerParameter_LayerType_POWER:
     return new PowerLayer<Dtype>(param);
   case LayerParameter_LayerType_RELU:
-    return new ReLULayer<Dtype>(param);
+    return GetReLULayer<Dtype>(name, param);
   case LayerParameter_LayerType_SILENCE:
     return new SilenceLayer<Dtype>(param);
   case LayerParameter_LayerType_SIGMOID:
-    return new SigmoidLayer<Dtype>(param);
+    return GetSigmoidLayer<Dtype>(name, param);
   case LayerParameter_LayerType_SIGMOID_CROSS_ENTROPY_LOSS:
     return new SigmoidCrossEntropyLossLayer<Dtype>(param);
   case LayerParameter_LayerType_SLICE:
@@ -118,7 +169,7 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   case LayerParameter_LayerType_SPLIT:
     return new SplitLayer<Dtype>(param);
   case LayerParameter_LayerType_TANH:
-    return new TanHLayer<Dtype>(param);
+    return GetTanHLayer<Dtype>(name, param);
   case LayerParameter_LayerType_WINDOW_DATA:
     return new WindowDataLayer<Dtype>(param);
   case LayerParameter_LayerType_NONE:
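The factory functions follow the existing GetPoolingLayer pattern: read the engine from the layer's parameter, construct the matching subclass, and LOG(FATAL) on anything else (LOG(FATAL) aborts, so that branch never needs to return). A minimal sketch of the same dispatch, with illustrative stand-ins for the generated protobuf types and a plain exception in place of glog:

#include <stdexcept>
#include <string>

// Illustrative stand-ins for the generated protobuf enum and message.
enum ReLUEngine { RELU_ENGINE_DEFAULT = 0, RELU_ENGINE_CAFFE = 1 };
struct LayerParameter { ReLUEngine relu_engine; };

struct ReLULayer { virtual ~ReLULayer() {} };
struct CaffeReLULayer : public ReLULayer {
  explicit CaffeReLULayer(const LayerParameter&) {}
};

// Same shape as GetReLULayer above: choose an implementation by engine.
ReLULayer* GetReLULayer(const std::string& name, const LayerParameter& param) {
  if (param.relu_engine == RELU_ENGINE_CAFFE) {
    return new CaffeReLULayer(param);
  }
  // The real code calls LOG(FATAL), which aborts; throw instead here.
  throw std::runtime_error("Layer " + name + " has unknown engine.");
}

int main() {
  LayerParameter param;
  param.relu_engine = RELU_ENGINE_CAFFE;
  ReLULayer* relu = GetReLULayer("relu1", param);
  delete relu;
  return 0;
}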
similarity index 84%
rename from src/caffe/layers/relu_layer.cpp
rename to src/caffe/layers/caffe_relu_layer.cpp
index b50352f..d708b3f 100644 (file)
@@ -7,7 +7,7 @@
 namespace caffe {
 
 template <typename Dtype>
-void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -20,7 +20,7 @@ void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+void CaffeReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down[0]) {
@@ -36,12 +36,10 @@ void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   }
 }
 
-
 #ifdef CPU_ONLY
-STUB_GPU(ReLULayer);
+STUB_GPU(CaffeReLULayer);
 #endif
 
-INSTANTIATE_CLASS(ReLULayer);
-
+INSTANTIATE_CLASS(CaffeReLULayer);
 
 }  // namespace caffe
similarity index 91%
rename from src/caffe/layers/relu_layer.cu
rename to src/caffe/layers/caffe_relu_layer.cu
index def2bbc..ad89968 100644 (file)
@@ -15,7 +15,7 @@ __global__ void ReLUForward(const int n, const Dtype* in, Dtype* out,
 }
 
 template <typename Dtype>
-void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -42,7 +42,7 @@ __global__ void ReLUBackward(const int n, const Dtype* in_diff,
 }
 
 template <typename Dtype>
-void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void CaffeReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down[0]) {
@@ -58,8 +58,6 @@ void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   }
 }
 
-
-INSTANTIATE_CLASS(ReLULayer);
-
+INSTANTIATE_CLASS(CaffeReLULayer);
 
 }  // namespace caffe
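Per the ReLULayer documentation above, the forward pass computes y = max(0, x) + ν·min(0, x) with ν the optional negative_slope, and the backward pass scales the top diff by 1 where x > 0 and by ν elsewhere. A scalar sketch of that element-wise arithmetic (standalone, not Caffe code):

#include <algorithm>
#include <cstdio>

// Forward: y = max(0, x) + negative_slope * min(0, x)  (plain ReLU when the slope is 0).
float relu_forward(float x, float negative_slope) {
  return std::max(x, 0.f) + negative_slope * std::min(x, 0.f);
}

// Backward: dE/dx = dE/dy * (1 if x > 0, else negative_slope).
float relu_backward(float x, float top_diff, float negative_slope) {
  return top_diff * ((x > 0) + negative_slope * (x <= 0));
}

int main() {
  std::printf("%g %g\n", relu_forward(-2.f, 0.01f),
              relu_backward(-2.f, 1.f, 0.01f));  // prints: -0.02 0.01
  return 0;
}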
similarity index 82%
rename from src/caffe/layers/sigmoid_layer.cpp
rename to src/caffe/layers/caffe_sigmoid_layer.cpp
index d7bba7f..b5bb0e3 100644 (file)
@@ -13,7 +13,7 @@ inline Dtype sigmoid(Dtype x) {
 }
 
 template <typename Dtype>
-void SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeSigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -24,7 +24,7 @@ void SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void SigmoidLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+void CaffeSigmoidLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down[0]) {
@@ -40,10 +40,10 @@ void SigmoidLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
 }
 
 #ifdef CPU_ONLY
-STUB_GPU(SigmoidLayer);
+STUB_GPU(CaffeSigmoidLayer);
 #endif
 
-INSTANTIATE_CLASS(SigmoidLayer);
+INSTANTIATE_CLASS(CaffeSigmoidLayer);
 
 
 }  // namespace caffe
similarity index 90%
rename from src/caffe/layers/sigmoid_layer.cu
rename to src/caffe/layers/caffe_sigmoid_layer.cu
index e1ebb1f..030d34e 100644 (file)
@@ -15,7 +15,7 @@ __global__ void SigmoidForward(const int n, const Dtype* in, Dtype* out) {
 }
 
 template <typename Dtype>
-void SigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -41,7 +41,7 @@ __global__ void SigmoidBackward(const int n, const Dtype* in_diff,
 }
 
 template <typename Dtype>
-void SigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void CaffeSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down[0]) {
@@ -56,7 +56,7 @@ void SigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   }
 }
 
-INSTANTIATE_CLASS(SigmoidLayer);
+INSTANTIATE_CLASS(CaffeSigmoidLayer);
 
 
 }  // namespace caffe
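The sigmoid pair computes y = 1 / (1 + e^(-x)) on the forward pass and dE/dx = dE/dy · y · (1 − y) on the backward pass, reusing the stored forward output. A scalar sketch of the math (standalone, not Caffe code):

#include <cmath>
#include <cstdio>

// Forward: y = 1 / (1 + exp(-x)).
double sigmoid_forward(double x) {
  return 1.0 / (1.0 + std::exp(-x));
}

// Backward: dE/dx = dE/dy * y * (1 - y), reusing the stored forward output y.
double sigmoid_backward(double y, double top_diff) {
  return top_diff * y * (1.0 - y);
}

int main() {
  double y = sigmoid_forward(0.0);
  std::printf("%g %g\n", y, sigmoid_backward(y, 1.0));  // prints: 0.5 0.25
  return 0;
}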
similarity index 77%
rename from src/caffe/layers/tanh_layer.cpp
rename to src/caffe/layers/caffe_tanh_layer.cpp
index 8dae005..a743339 100644 (file)
@@ -1,6 +1,3 @@
-// TanH neuron activation function layer.
-// Adapted from ReLU layer code written by Yangqing Jia
-
 #include <algorithm>
 #include <vector>
 
@@ -10,7 +7,7 @@
 namespace caffe {
 
 template <typename Dtype>
-void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeTanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -23,7 +20,7 @@ void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+void CaffeTanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down[0]) {
@@ -40,9 +37,9 @@ void TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
 }
 
 #ifdef CPU_ONLY
-STUB_GPU(TanHLayer);
+STUB_GPU(CaffeTanHLayer);
 #endif
 
-INSTANTIATE_CLASS(TanHLayer);
+INSTANTIATE_CLASS(CaffeTanHLayer);
 
 }  // namespace caffe
similarity index 85%
rename from src/caffe/layers/tanh_layer.cu
rename to src/caffe/layers/caffe_tanh_layer.cu
index bdb7a94..f2096e6 100644 (file)
@@ -1,6 +1,3 @@
-// TanH neuron activation function layer.
-// Adapted from ReLU layer code written by Yangqing Jia
-
 #include <algorithm>
 #include <vector>
 
@@ -18,7 +15,7 @@ __global__ void TanHForward(const int n, const Dtype* in, Dtype* out) {
 }
 
 template <typename Dtype>
-void TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void CaffeTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -39,7 +36,7 @@ __global__ void TanHBackward(const int n, const Dtype* in_diff,
 }
 
 template <typename Dtype>
-void TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+void CaffeTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   if (propagate_down[0]) {
@@ -54,7 +51,6 @@ void TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   }
 }
 
-INSTANTIATE_CLASS(TanHLayer);
-
+INSTANTIATE_CLASS(CaffeTanHLayer);
 
 }  // namespace caffe
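The tanh pair computes y = tanh(x), equivalently (e^(2x) − 1) / (e^(2x) + 1), and the backward pass scales the top diff by 1 − y². A scalar sketch (standalone, not Caffe code):

#include <cmath>
#include <cstdio>

// Forward: y = tanh(x), equal to (exp(2x) - 1) / (exp(2x) + 1).
double tanh_forward(double x) {
  return std::tanh(x);
}

// Backward: dE/dx = dE/dy * (1 - y * y), reusing the stored forward output y.
double tanh_backward(double y, double top_diff) {
  return top_diff * (1.0 - y * y);
}

int main() {
  double y = tanh_forward(1.0);
  std::printf("%.6f %.6f\n", y, tanh_backward(y, 1.0));  // prints: 0.761594 0.419974
  return 0;
}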
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 29dcec5..322e497 100644 (file)
@@ -96,7 +96,7 @@ TYPED_TEST(NeuronLayerTest, TestAbsGradient) {
 TYPED_TEST(NeuronLayerTest, TestReLU) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  ReLULayer<Dtype> layer(layer_param);
+  CaffeReLULayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
@@ -111,7 +111,7 @@ TYPED_TEST(NeuronLayerTest, TestReLU) {
 TYPED_TEST(NeuronLayerTest, TestReLUGradient) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  ReLULayer<Dtype> layer(layer_param);
+  CaffeReLULayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
   checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
@@ -121,7 +121,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
   layer_param.ParseFromString("relu_param{negative_slope:0.01}");
-  ReLULayer<Dtype> layer(layer_param);
+  CaffeReLULayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
@@ -137,7 +137,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
   layer_param.ParseFromString("relu_param{negative_slope:0.01}");
-  ReLULayer<Dtype> layer(layer_param);
+  CaffeReLULayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
   checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
@@ -146,7 +146,7 @@ TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) {
 TYPED_TEST(NeuronLayerTest, TestSigmoid) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  SigmoidLayer<Dtype> layer(layer_param);
+  CaffeSigmoidLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
   layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
   // Now, check values
@@ -163,12 +163,44 @@ TYPED_TEST(NeuronLayerTest, TestSigmoid) {
 TYPED_TEST(NeuronLayerTest, TestSigmoidGradient) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  SigmoidLayer<Dtype> layer(layer_param);
+  CaffeSigmoidLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
   checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
 }
 
+TYPED_TEST(NeuronLayerTest, TestTanH) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CaffeTanHLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Test exact values
+  for (int i = 0; i < this->blob_bottom_->num(); ++i) {
+    for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+      for (int k = 0; k < this->blob_bottom_->height(); ++k) {
+        for (int l = 0; l < this->blob_bottom_->width(); ++l) {
+          EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
+          EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestTanHGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CaffeTanHLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
 TYPED_TEST(NeuronLayerTest, TestDropoutHalf) {
   const float kDropoutRatio = 0.5;
   this->TestDropoutForward(kDropoutRatio);
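The migrated TestTanH case brackets each output between (e^(2x) − 1) / (e^(2x) + 1) ± 1e-4, i.e. it checks the tanh identity to a fixed tolerance. A standalone sketch of the same check, with plain asserts standing in for EXPECT_GE/EXPECT_LE:

#include <cassert>
#include <cmath>

int main() {
  const double kTolerance = 1e-4;
  for (double x = -3.0; x <= 3.0; x += 0.25) {
    const double y = std::tanh(x);  // what the layer produces
    const double reference = (std::exp(2 * x) - 1) / (std::exp(2 * x) + 1);
    assert(y + kTolerance >= reference);  // mirrors EXPECT_GE(top + 1e-4, ...)
    assert(y - kTolerance <= reference);  // mirrors EXPECT_LE(top - 1e-4, ...)
  }
  return 0;
}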
diff --git a/src/caffe/test/test_tanh_layer.cpp b/src/caffe/test/test_tanh_layer.cpp
deleted file mode 100644 (file)
index 9b8e745..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-// Adapted from other test files
-
-#include <cmath>
-#include <cstring>
-#include <vector>
-
-#include "gtest/gtest.h"
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/vision_layers.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-#include "caffe/test/test_gradient_check_util.hpp"
-
-namespace caffe {
-
-template <typename TypeParam>
-class TanHLayerTest : public MultiDeviceTest<TypeParam> {
-  typedef typename TypeParam::Dtype Dtype;
- protected:
-  TanHLayerTest()
-      : blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
-        blob_top_(new Blob<Dtype>()) {
-    // fill the values
-    FillerParameter filler_param;
-    GaussianFiller<Dtype> filler(filler_param);
-    filler.Fill(this->blob_bottom_);
-    blob_bottom_vec_.push_back(blob_bottom_);
-    blob_top_vec_.push_back(blob_top_);
-  }
-  virtual ~TanHLayerTest() { delete blob_bottom_; delete blob_top_; }
-  Blob<Dtype>* const blob_bottom_;
-  Blob<Dtype>* const blob_top_;
-  vector<Blob<Dtype>*> blob_bottom_vec_;
-  vector<Blob<Dtype>*> blob_top_vec_;
-};
-
-TYPED_TEST_CASE(TanHLayerTest, TestDtypesAndDevices);
-
-TYPED_TEST(TanHLayerTest, TestForward) {
-  typedef typename TypeParam::Dtype Dtype;
-  LayerParameter layer_param;
-  TanHLayer<Dtype> layer(layer_param);
-  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
-  // Test exact values
-  for (int i = 0; i < this->blob_bottom_->num(); ++i) {
-    for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
-      for (int k = 0; k < this->blob_bottom_->height(); ++k) {
-        for (int l = 0; l < this->blob_bottom_->width(); ++l) {
-          EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
-             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
-             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
-          EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
-             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
-             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
-        }
-      }
-    }
-  }
-}
-
-TYPED_TEST(TanHLayerTest, TestGradient) {
-  typedef typename TypeParam::Dtype Dtype;
-  LayerParameter layer_param;
-  TanHLayer<Dtype> layer(layer_param);
-  GradientChecker<Dtype> checker(1e-2, 1e-3);
-  checker.CheckGradientEltwise(&layer, &(this->blob_bottom_vec_),
-      &(this->blob_top_vec_));
-}
-
-}  // namespace caffe