[acti_func] implement quick gelu
author hyeonseok <hs89.lee@samsung.com>
Fri, 7 Jun 2024 04:14:46 +0000 (13:14 +0900)
committer MyungJoo Ham <myungjoo.ham@samsung.com>
Sat, 8 Jun 2024 04:26:36 +0000 (13:26 +0900)
 - Implemented quick GELU activation function.
   Please note that quickGeluPrime, which computes the derivative of the quickGelu function, is not yet implemented.

Signed-off-by: hyeonseok <hs89.lee@samsung.com>
nntrainer/layers/acti_func.h
nntrainer/layers/common_properties.h

index c39eaecdcf5d0061a210e00b7a6a7ed59893d581..07fc13696a99a55317a552dfa0c13e0d43a3115b 100644 (file)
@@ -78,6 +78,10 @@ public:
       in_place = false;
       this->setActivation<Tensor>(gelu<T>, geluPrime<T>);
       break;
+    case ActivationType::ACT_QUICK_GELU:
+      in_place = false;
+      this->setActivation<Tensor>(quickGelu<T>, quickGeluPrime<T>);
+      break;
     case ActivationType::ACT_ELU:
       this->setActivation<T>(elu<T>, eluPrime<T>);
       break;
@@ -457,6 +461,34 @@ public:
     return outgoing_derivative;
   }
 
+  /**
+   * @brief     quick gelu activation function (gelu approximation)
+   * @param[in] t_in input tensor
+   * @param[in] t_out output tensor
+   */
+  template <typename T = float>
+  static Tensor &quickGelu(Tensor const &t_in, Tensor &t_out) {
+    t_in.apply<T>(
+      [&](T x) {
+        return static_cast<T>(x * sigmoid<T>(static_cast<T>(1.702 * x)));
+      },
+      t_out);
+    return t_out;
+  }
+
+  /**
+   * @brief     derivative of quick gelu function (not yet implemented)
+   * @param[in] t_in input tensor
+   * @param[in] t_out output tensor
+   * @param[out] outgoing_derivative outgoing derivative
+   * @param[in] incoming_derivative incoming derivative
+   */
+  template <typename T = float>
+  static Tensor &quickGeluPrime(Tensor const &t_in, Tensor const &t_out,
+                                Tensor &outgoing_derivative,
+                                Tensor const &incoming_derivative = Tensor()) {
+    // NYI
+    ml_logw("quickGeluPrime, which computes the derivative of the quickGelu "
+            "function, is not yet implemented");
+    return outgoing_derivative;
+  }
+
   /**
    * @brief elu function
    * @note alpha parameter is needed for elu, but supporting property on
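Note: quickGeluPrime above is left as NYI. For reference, the backward pass would follow the analytic derivative of x * sigmoid(1.702 * x), i.e. s + 1.702 * x * s * (1 - s) with s = sigmoid(1.702 * x). The scalar C++ sketch below only illustrates that math; it deliberately avoids nntrainer's Tensor API, and the helper names are hypothetical.

#include <cmath>

// Scalar illustration of quick GELU and its analytic derivative.
// quickGelu(x)  = x * s                with s = sigmoid(1.702 * x)
// quickGelu'(x) = s + 1.702 * x * s * (1 - s)
static float scalar_sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

static float scalar_quick_gelu(float x) {
  return x * scalar_sigmoid(1.702f * x);
}

static float scalar_quick_gelu_prime(float x) {
  float s = scalar_sigmoid(1.702f * x);
  return s + 1.702f * x * s * (1.0f - s);
}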
index 7db9b0caa0b0d36ec66d26f5af97426348463226..3c5d73300df9ce27a9e5781a3898c5589458b92e 100644 (file)
@@ -35,6 +35,7 @@ enum class ActivationType {
   ACT_RELU,       /**< ReLU */
   ACT_SWISH,      /**< Swish */
   ACT_GELU,       /**< GELU */
+  ACT_QUICK_GELU, /**< Quick GELU */
   ACT_SOFTMAX,    /**< softmax */
   ACT_SOFTPLUS,   /**< softplus */
   ACT_LEAKY_RELU, /**< Leaky ReLU */
@@ -865,11 +866,13 @@ struct ActivationTypeInfo {
   static constexpr std::initializer_list<Enum> EnumList = {
     Enum::ACT_TANH,    Enum::ACT_SIGMOID,    Enum::ACT_RELU,
     Enum::ACT_SOFTMAX, Enum::ACT_LEAKY_RELU, Enum::ACT_SWISH,
-    Enum::ACT_GELU,    Enum::ACT_NONE,       Enum::ACT_UNKNOWN};
+    Enum::ACT_GELU,    Enum::ACT_QUICK_GELU, Enum::ACT_NONE,
+    Enum::ACT_UNKNOWN};
 
   static constexpr const char *EnumStr[] = {"tanh",    "sigmoid",    "relu",
                                             "softmax", "leaky_relu", "swish",
-                                            "gelu",    "none",       "unknown"};
+                                            "gelu",    "quick_gelu", "none",
+                                            "unknown"};
 };
 
 /**
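As the second hunk shows, EnumList and EnumStr in ActivationTypeInfo are matched positionally, so ACT_QUICK_GELU and "quick_gelu" must be inserted at the same index in both arrays. A minimal standalone sketch of that positional lookup (not nntrainer's actual property-converter code) could look like:

#include <cstddef>
#include <cstring>
#include <stdexcept>

// Trimmed-down illustration: the enum array and the string array are kept in
// the same order, and a name is resolved to an enum value by index.
enum class ActivationType { ACT_GELU, ACT_QUICK_GELU, ACT_NONE };

static constexpr ActivationType kEnumList[] = {
  ActivationType::ACT_GELU, ActivationType::ACT_QUICK_GELU,
  ActivationType::ACT_NONE};
static constexpr const char *kEnumStr[] = {"gelu", "quick_gelu", "none"};

ActivationType from_string(const char *name) {
  for (std::size_t i = 0; i < sizeof(kEnumStr) / sizeof(kEnumStr[0]); ++i)
    if (std::strcmp(kEnumStr[i], name) == 0)
      return kEnumList[i];
  throw std::invalid_argument("unknown activation type");
}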