[weight] Support weight decay
author     Parichay Kapoor <pk.kapoor@samsung.com>
Mon, 24 Jan 2022 03:58:43 +0000 (12:58 +0900)
committer  Jijoong Moon <jijoong.moon@samsung.com>
Wed, 9 Feb 2022 09:34:12 +0000 (18:34 +0900)
Add support for the weight decay property, which enables decaying the weights
each time the gradient is applied.
Weight decay can be enabled individually for the weight and the bias.
It is kept separate from the regularizer, as the two behave differently.
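For reference, a minimal standalone sketch of the update rule that
Weight::applyGradient() performs after this change. The plain-float loop, the
example numbers and the main() wrapper are illustrative only; the
decay-then-gradient ordering and the epsilon guard mirror the weight.h hunks
below.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      // illustrative values, not taken from any model
      std::vector<float> weight = {0.5f, -0.2f, 0.8f};
      std::vector<float> grad = {0.1f, 0.05f, -0.3f};

      const float lr = 0.01f;    // learning rate passed to applyGradient(lr)
      const float decay = 1e-4f; // value of the weight_decay / bias_decay property

      for (std::size_t i = 0; i < weight.size(); ++i) {
        // decay step: shrink the parameter directly, independent of the loss
        // (guarded by decay > epsilon_decay, cf. Weight::isWeightDecay())
        if (decay > 1e-8f)
          weight[i] -= decay * weight[i];
        // ordinary gradient step, applied afterwards
        weight[i] -= lr * grad[i];
      }

      for (float w : weight)
        std::printf("%f\n", w);
      return 0;
    }

A weight regularizer, by contrast, contributes to the loss
(cf. getRegularizationLoss()) and reaches the weight only through the computed
gradient, which is why decay is exposed through the separate weight_decay and
bias_decay properties (both defaulting to 0.0f, so existing models are
unaffected) handled in LayerImpl::setProperty(), instead of being folded into
weight_regularizer_constant.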

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/common_properties.cpp
nntrainer/layers/common_properties.h
nntrainer/layers/layer_impl.cpp
nntrainer/layers/layer_impl.h
nntrainer/layers/lstmcell_core.cpp
nntrainer/layers/lstmcell_core.h
nntrainer/tensor/weight.cpp
nntrainer/tensor/weight.h
nntrainer/utils/node_exporter.cpp
nntrainer/utils/node_exporter.h

nntrainer/layers/common_properties.cpp
index c0d6c2c..96e17e5 100644
@@ -232,11 +232,14 @@ std::array<unsigned int, 2> Padding1D::compute(const TensorDim &input,
   return {0, 0};
 }
 
-WeightRegularizerConstant::WeightRegularizerConstant(float value) {
-  set(value);
-}
+BasicRegularizerConstant::BasicRegularizerConstant(float value) { set(value); }
 
-bool WeightRegularizerConstant::isValid(const float &value) const {
+WeightRegularizerConstant::WeightRegularizerConstant(float value) :
+  BasicRegularizerConstant(value) {}
+WeightDecay::WeightDecay(float value) : BasicRegularizerConstant(value) {}
+BiasDecay::BiasDecay(float value) : BasicRegularizerConstant(value) {}
+
+bool BasicRegularizerConstant::isValid(const float &value) const {
   return value >= 0.0f;
 }
 
@@ -270,11 +273,14 @@ BNPARAMS_BETA_INIT::BNPARAMS_BETA_INIT(Tensor::Initializer value) {
   set(value);
 }
 
-WeightRegularizer::WeightRegularizer(nntrainer::WeightRegularizer value) {
+BasicRegularizer::BasicRegularizer(nntrainer::WeightRegularizer value) {
   set(value);
 }
 
-bool WeightRegularizer::isValid(
+WeightRegularizer::WeightRegularizer(nntrainer::WeightRegularizer value) :
+  BasicRegularizer(value) {}
+
+bool BasicRegularizer::isValid(
   const nntrainer::WeightRegularizer &value) const {
   return value != nntrainer::WeightRegularizer::UNKNOWN;
 }
nntrainer/layers/common_properties.h
index 13cf6a5..e67c5c3 100644
@@ -644,20 +644,20 @@ public:
 };
 
 /**
- * @brief WeightRegularizerConstant property, this defines how much regularize
+ * @brief BasicRegularizerConstant property, this defines how much to regularize
  * the weight
  *
  */
-class WeightRegularizerConstant : public nntrainer::Property<float> {
+class BasicRegularizerConstant : public nntrainer::Property<float> {
 
 public:
   /**
-   * @brief Construct a new WeightRegularizerConstant object
+   * @brief Construct a new BasicRegularizerConstant object
    *
    */
-  WeightRegularizerConstant(float value = 1.0f);
+  BasicRegularizerConstant(float value = 1.0f);
   static constexpr const char *key =
-    "weight_regularizer_constant"; /**< unique key to access */
+    "basic_regularizer_constant";  /**< unique key to access */
   using prop_tag = float_prop_tag; /**< property type */
 
   /**
@@ -670,6 +670,56 @@ public:
 };
 
 /**
+ * @brief WeightRegularizerConstant property, this defines how much to regularize
+ * the weight
+ *
+ */
+class WeightRegularizerConstant final : public BasicRegularizerConstant {
+
+public:
+  /**
+   * @brief Construct a new WeightRegularizerConstant object
+   *
+   */
+  WeightRegularizerConstant(float value = 1.0f);
+  static constexpr const char *key =
+    "weight_regularizer_constant"; /**< unique key to access */
+};
+
+/**
+ * @brief WeightDecay property, this defines how much to decay
+ * the weight
+ *
+ */
+class WeightDecay final : public BasicRegularizerConstant {
+
+public:
+  /**
+   * @brief Construct a new WeightDecay object
+   *
+   */
+  WeightDecay(float value = 0.0f);
+  static constexpr const char *key =
+    "weight_decay"; /**< unique key to access */
+};
+
+/**
+ * @brief BiasDecay property, this defines how much to decay
+ * the bias
+ *
+ */
+class BiasDecay final : public BasicRegularizerConstant {
+
+public:
+  /**
+   * @brief Construct a new BiasDecay object
+   *
+   */
+  BiasDecay(float value = 0.0f);
+  static constexpr const char *key = "bias_decay"; /**< unique key to access */
+};
+
+/**
  * @brief Output Layer name property which saves a single connection
  * (practically, std::vector<InputLayers> is used)
  *
@@ -887,21 +937,20 @@ struct RegularizerInfo {
 };
 
 /**
- * @brief WeightRegularizer Regularization Enumeration Information
+ * @brief BasicRegularizer Regularization Enumeration Information
  *
  */
-class WeightRegularizer final : public EnumProperty<RegularizerInfo> {
+class BasicRegularizer : public EnumProperty<RegularizerInfo> {
 public:
   /**
-   * @brief Construct a WeightRegularizer object
+   * @brief Construct a BasicRegularizer object
    */
-  WeightRegularizer(
-    nntrainer::WeightRegularizer value = nntrainer::WeightRegularizer::NONE);
+  BasicRegularizer(nntrainer::WeightRegularizer value);
   using prop_tag = enum_class_prop_tag;
-  static constexpr const char *key = "weight_regularizer";
+  static constexpr const char *key = "basic_regularizer";
 
   /**
-   * @brief WeightRegularizer validator
+   * @brief BasicRegularizer validator
    *
    * @param value nntrainer::WeightRegularizer to validate
    * @retval true if value is not nntrainer::WeightRegularizer::UNKNOWN
@@ -911,6 +960,20 @@ public:
 };
 
 /**
+ * @brief WeightRegularizer Regularization Enumeration Information
+ *
+ */
+class WeightRegularizer final : public BasicRegularizer {
+public:
+  /**
+   * @brief Construct a WeightRegularizer object
+   */
+  WeightRegularizer(
+    nntrainer::WeightRegularizer value = nntrainer::WeightRegularizer::NONE);
+  static constexpr const char *key = "weight_regularizer";
+};
+
+/**
  * @brief     Enumeration of pooling type
  */
 struct PoolingTypeInfo {
nntrainer/layers/layer_impl.cpp
index 83d92dd..6893406 100644
@@ -26,8 +26,8 @@ LayerImpl::LayerImpl() :
   layer_impl_props(
     std::make_unique<
       std::tuple<props::WeightRegularizer, props::WeightRegularizerConstant,
-                 props::WeightInitializer, props::BiasInitializer,
-                 props::DisableBias>>()) {}
+                 props::WeightInitializer, props::WeightDecay, props::BiasDecay,
+                 props::BiasInitializer, props::DisableBias>>()) {}
 
 void LayerImpl::setProperty(const std::vector<std::string> &values) {
   auto remain_props = loadProperties(values, *layer_impl_props);
nntrainer/layers/layer_impl.h
index 4614846..b741a91 100644
@@ -35,6 +35,8 @@ namespace props {
 class WeightRegularizer;
 class WeightRegularizerConstant;
 class WeightInitializer;
+class WeightDecay;
+class BiasDecay;
 class BiasInitializer;
 class DisableBias;
 } // namespace props
@@ -83,9 +85,10 @@ public:
                         const ExportMethods &method) const override;
 
 protected:
-  std::unique_ptr<std::tuple<
-    props::WeightRegularizer, props::WeightRegularizerConstant,
-    props::WeightInitializer, props::BiasInitializer, props::DisableBias>>
+  std::unique_ptr<
+    std::tuple<props::WeightRegularizer, props::WeightRegularizerConstant,
+               props::WeightInitializer, props::WeightDecay, props::BiasDecay,
+               props::BiasInitializer, props::DisableBias>>
     layer_impl_props; /**< layer_impl_props */
 };
 
nntrainer/layers/lstmcell_core.cpp
index 2993f79..5cf5b73 100644
@@ -70,10 +70,10 @@ void lstmcell_calcGradient(
   ActiFunc &recurrent_acti_func, const Tensor &input,
   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
   const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
-  const Tensor &d_hidden_state, const Tensor &cell_state, const Tensor &d_cell_state,
-  Tensor &d_weight_ih, const Tensor &weight_hh, Tensor &d_weight_hh,
-  Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &ifgo,
-  Tensor &d_ifgo) {
+  const Tensor &d_hidden_state, const Tensor &cell_state,
+  const Tensor &d_cell_state, Tensor &d_weight_ih, const Tensor &weight_hh,
+  Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh,
+  const Tensor &ifgo, Tensor &d_ifgo) {
   Tensor input_forget_gate =
     ifgo.getSharedDataTensor({batch_size, 1, 1, unit * 2}, 0, false);
   Tensor input_gate =
nntrainer/layers/lstmcell_core.h
index bf05390..b39d2e8 100644
 
 namespace nntrainer {
 
+/**
+ * @brief lstm cell forwarding implementation
+ *
+ */
 void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
                          const bool disable_bias, const bool integrate_bias,
                          ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
@@ -28,18 +32,28 @@ void lstmcell_forwarding(const unsigned int unit, const unsigned int batch_size,
                          const Tensor &weight_hh, const Tensor &bias_h,
                          const Tensor &bias_ih, const Tensor &bias_hh,
                          Tensor &ifgo);
+
+/**
+ * @brief lstm cell calculate derivative implementation
+ *
+ */
 void lstmcell_calcDerivative(const Tensor &d_ifgo, const Tensor &weight_ih,
                              Tensor &outgoing_derivative);
+
+/**
+ * @brief lstm cell calculate gradient implementation
+ *
+ */
 void lstmcell_calcGradient(
   const unsigned int unit, const unsigned int batch_size,
   const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
   ActiFunc &recurrent_acti_func, const Tensor &input,
   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
   const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
-  const Tensor &d_hidden_state, const Tensor &cell_state, const Tensor &d_cell_state,
-  Tensor &d_weight_ih, const Tensor &weight_hh, Tensor &d_weight_hh,
-  Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &ifgo,
-  Tensor &d_ifgo);
+  const Tensor &d_hidden_state, const Tensor &cell_state,
+  const Tensor &d_cell_state, Tensor &d_weight_ih, const Tensor &weight_hh,
+  Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh,
+  const Tensor &ifgo, Tensor &d_ifgo);
 
 } // namespace nntrainer
 
nntrainer/tensor/weight.cpp
index 3890f34..2c53970 100644
@@ -25,6 +25,7 @@ Weight::Weight(const TensorDim &dim, const Tensor::Initializer init,
   Var_Grad(dim, init, train, alloc_now_, name),
   regularizer(reg),
   regularizer_constant(reg_const),
+  decay(0.0f),
   clip_by_global_norm(max_norm) {
   if (init == Tensor::Initializer::NONE)
     throw std::invalid_argument("Weight initializer cannot be none");
nntrainer/tensor/weight.h
index d13ae1b..d5f2eeb 100644
@@ -43,6 +43,7 @@ public:
     Var_Grad(),
     regularizer(WeightRegularizer::UNKNOWN),
     regularizer_constant(1.0f),
+    decay(0.0f),
     clip_by_global_norm(0.0f) {}
 
   /**
@@ -98,6 +99,7 @@ public:
     Var_Grad(v, g, n, is_dependent),
     regularizer(WeightRegularizer::NONE),
     regularizer_constant(1.0f),
+    decay(0.0f),
     clip_by_global_norm(0.0f) {}
 
   /**
@@ -114,6 +116,8 @@ public:
     Var_Grad(v, g, is_dependent),
     regularizer(reg),
     regularizer_constant(reg_const),
+    // TODO: set properly
+    decay(0.0f),
     clip_by_global_norm(max_norm) {}
 
   /**
@@ -128,6 +132,7 @@ public:
     swap(static_cast<Var_Grad &>(lhs), static_cast<Var_Grad &>(rhs));
     swap(lhs.regularizer, rhs.regularizer);
     swap(lhs.regularizer_constant, rhs.regularizer_constant);
+    swap(lhs.decay, rhs.decay);
     swap(lhs.clip_by_global_norm, rhs.clip_by_global_norm);
     swap(lhs.opt_vars, rhs.opt_vars);
   }
@@ -212,6 +217,12 @@ public:
   }
 
   /**
+   * @brief     check if weight decay is enabled
+   * @return    true if weight decay is enabled else false
+   */
+  bool isWeightDecay() { return decay > epsilon_decay; }
+
+  /**
    * @brief     Get loss from the regularization of the weight
    */
   float getRegularizationLoss() {
@@ -232,7 +243,12 @@ public:
   /**
    * @brief     Apply the gradient to the weight
    */
-  void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
+  void applyGradient(double lr) {
+    if (isWeightDecay())
+      applyWeightDecay();
+
+    var->add_i(*grad.get(), -lr);
+  }
 
   /**
    * @brief Check if the gradient is supposed to be clipped by global norm with
@@ -268,11 +284,19 @@ public:
 
 private:
   static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */
+  static constexpr float epsilon_decay =
+    1e-8; /**< epsilon for zero comparison of decay */
 
   WeightRegularizer regularizer; /**< regularizer for this variable */
   float regularizer_constant;    /**< constant factor for regularization */
+  float decay;                   /**< constant factor for the weight decay */
   float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */
   std::vector<Tensor *> opt_vars; /**< optimizer variables */
+
+  /**
+   * @brief     Apply weight decay to the weight
+   */
+  void applyWeightDecay() { var->add_i(*var.get(), -decay); }
 };
 
 } // namespace nntrainer
nntrainer/utils/node_exporter.cpp
index 6a11c4b..ee4a5aa 100644
@@ -66,8 +66,9 @@ void Exporter::saveTflResult(
 template <>
 void Exporter::saveTflResult(
   const std::tuple<props::WeightRegularizer, props::WeightRegularizerConstant,
-                   props::WeightInitializer, props::BiasInitializer,
-                   props::DisableBias> &props,
+                   props::WeightInitializer, props::WeightDecay,
+                   props::BiasDecay, props::BiasInitializer, props::DisableBias>
+    &props,
   const LayerImpl *self) { /// layer impl has nothing to serialize so do nothing
 }
 
nntrainer/utils/node_exporter.h
index a057b31..2a65dd5 100644
@@ -212,6 +212,8 @@ class InputShape;
 class WeightRegularizer;
 class WeightRegularizerConstant;
 class WeightInitializer;
+class WeightDecay;
+class BiasDecay;
 class BiasInitializer;
 class SharedFrom;
 class InputConnection;
@@ -241,8 +243,9 @@ class LayerImpl;
 template <>
 void Exporter::saveTflResult(
   const std::tuple<props::WeightRegularizer, props::WeightRegularizerConstant,
-                   props::WeightInitializer, props::BiasInitializer,
-                   props::DisableBias> &props,
+                   props::WeightInitializer, props::WeightDecay,
+                   props::BiasDecay, props::BiasInitializer, props::DisableBias>
+    &props,
   const LayerImpl *self);
 
 class FullyConnectedLayer;