[ RNN ] Set tanh as the default activation function
author: jijoong.moon <jijoong.moon@samsung.com>
Wed, 24 Mar 2021 01:35:58 +0000 (10:35 +0900)
committer: Jijoong Moon <jijoong.moon@samsung.com>
Mon, 29 Mar 2021 12:51:56 +0000 (21:51 +0900)
Set tanh as the default activation function.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/layers/rnn.h
test/unittest/unittest_nntrainer_layers.cpp

index 19ec7d100a8450cb42c6467505f2a74bb641c95e..84b6de476154134fd35100a9b6747bd13c36367c 100644 (file)
@@ -264,6 +264,9 @@ int NetworkGraph::realizeActivationType(Layer &current) {
 
   if (current.getType() == RNNLayer::type) {
     // No need to add activation layer for RNN Layer
+    // Default activation is tanh
+    if (act == ActivationType::ACT_NONE)
+      act = ActivationType::ACT_TANH;
     current.setActivation(act);
     return status;
   }
index 16e5c0be5678ed2a4fb96d91b2b6ac228a96bab0..68136bcc5dcfc6c411359a33097e278c21b70458 100644 (file)
@@ -30,9 +30,11 @@ public:
    * @brief     Constructor of RNNLayer
    */
   template <typename... Args>
-  RNNLayer(unsigned int unit_ = 0, Args... args) :
-    Layer(args...),
-    unit(unit_) {}
+  RNNLayer(unsigned int unit_ = 0, Args... args) : Layer(args...), unit(unit_) {
+    /* Default Activation Type is tanh */
+    if (getActivationType() == ActivationType::ACT_NONE)
+      setActivation(ActivationType::ACT_TANH);
+  }
 
   /**
    * @brief     Destructor of RNNLayer
index 2d9b4e16211b8033b617652967796284b66ebcad..f4ec5b74e33bb33d1d4a218090bcc8c8d19abe81 100644 (file)
@@ -2337,8 +2337,7 @@ protected:
   typedef nntrainer_abstractLayer<nntrainer::RNNLayer> super;
 
   virtual void prepareLayer() {
-    int status =
-      setProperty("unit=3 | weight_initializer=ones | activation=tanh");
+    int status = setProperty("unit=3 | weight_initializer=ones");
     EXPECT_EQ(status, ML_ERROR_NONE);
     setInputDim("2:1:3:3");
     setBatch(2);