[layers] Support for weight decay to layers
authorParichay Kapoor <pk.kapoor@samsung.com>
Mon, 24 Jan 2022 09:54:32 +0000 (18:54 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 9 Feb 2022 09:34:12 +0000 (18:34 +0900)
Add support for weight decay to the layers with weights.
Also update the requestWeight API of the layer context to accept a
weight decay value, and update its usage in the manager.
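
For illustration, a layer finalize() now passes the decay values explicitly.
A minimal sketch, mirroring the fc_layer.cpp change below (the call site is
representative only):

    auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
    auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);

    // decay is passed between the regularizer constant and the name
    weight_idx[FCParams::weight] = context.requestWeight(
      weight_dim, weight_initializer, weight_regularizer,
      weight_regularizer_constant, weight_decay, "weight", true);
    weight_idx[FCParams::bias] = context.requestWeight(
      bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
      "bias", true);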

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
19 files changed:
nntrainer/graph/network_graph.cpp
nntrainer/layers/bn_layer.cpp
nntrainer/layers/centroid_knn.cpp
nntrainer/layers/conv2d_layer.cpp
nntrainer/layers/embedding.cpp
nntrainer/layers/fc_layer.cpp
nntrainer/layers/gru.cpp
nntrainer/layers/grucell.cpp
nntrainer/layers/layer_context.h
nntrainer/layers/lstm.cpp
nntrainer/layers/lstmcell.cpp
nntrainer/layers/mol_attention_layer.cpp
nntrainer/layers/rnn.cpp
nntrainer/layers/rnncell.cpp
nntrainer/layers/zoneout_lstmcell.cpp
nntrainer/tensor/manager.cpp
nntrainer/tensor/tensor_wrap_specs.h
nntrainer/tensor/weight.cpp
nntrainer/tensor/weight.h

index 6101a84..2b22070 100644 (file)
@@ -792,7 +792,7 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
 
     const auto &w_specs = init_context.getWeightsSpec();
     for (auto i = 0u; i < w_specs.size(); ++i) {
-      shared_weight_names.emplace_back(std::get<6>(w_specs.at(i)));
+      shared_weight_names.emplace_back(std::get<7>(w_specs.at(i)));
     }
   }
 
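The std::get<6> -> std::get<7> change above follows from the expanded
WeightSpec tuple (see tensor_wrap_specs.h below): decay is inserted at
index 4, shifting the name field to index 7. Field layout after this commit:

    // 0: TensorDim           - dimension
    // 1: Tensor::Initializer - initializer
    // 2: WeightRegularizer   - regularizer
    // 3: float               - regularizer constant
    // 4: float               - weight decay constant (new)
    // 5: float               - clip-by-global-norm constant
    // 6: bool                - need_gradient / trainable
    // 7: std::string         - name
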
index f65d042..107bcd8 100644 (file)
@@ -99,14 +99,17 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
     }
   }
 
-  wt_idx[BNParams::mu] = context.requestWeight(
-    dim, bnparams_mu, WeightRegularizer::NONE, 1.0f, "moving_mean", false);
-  wt_idx[BNParams::var] = context.requestWeight(
-    dim, bnparams_var, WeightRegularizer::NONE, 1.0f, "moving_variance", false);
+  wt_idx[BNParams::mu] =
+    context.requestWeight(dim, bnparams_mu, WeightRegularizer::NONE, 1.0f, 0.0f,
+                          "moving_mean", false);
+  wt_idx[BNParams::var] =
+    context.requestWeight(dim, bnparams_var, WeightRegularizer::NONE, 1.0f,
+                          0.0f, "moving_variance", false);
+  // TODO: set up decay for gamma and beta
   wt_idx[BNParams::gamma] = context.requestWeight(
-    dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f, "gamma", true);
+    dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f, 0.0f, "gamma", true);
   wt_idx[BNParams::beta] = context.requestWeight(
-    dim, bnparams_beta, WeightRegularizer::NONE, 1.0f, "beta", true);
+    dim, bnparams_beta, WeightRegularizer::NONE, 1.0f, 0.0f, "beta", true);
 
   /**
    * caches the deviation -> input - avg(input)
index 4ebc5da..ac439ab 100644 (file)
@@ -63,11 +63,11 @@ void CentroidKNN::finalize(nntrainer::InitLayerContext &context) {
 
   weight_idx[KNNParams::map] = context.requestWeight(
     map_dim, nntrainer::Tensor::Initializer::ZEROS,
-    nntrainer::WeightRegularizer::NONE, 1.0f, "map", false);
+    nntrainer::WeightRegularizer::NONE, 1.0f, 0.0f, "map", false);
 
   weight_idx[KNNParams::num_samples] = context.requestWeight(
     samples_seen, nntrainer::Tensor::Initializer::ZEROS,
-    nntrainer::WeightRegularizer::NONE, 1.0f, "num_samples", false);
+    nntrainer::WeightRegularizer::NONE, 1.0f, 0.0f, "num_samples", false);
 }
 
 void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
index 41f4d25..27f47e2 100644 (file)
@@ -265,6 +265,8 @@ void Conv2DLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
   auto &weight_initializer =
     std::get<props::WeightInitializer>(*layer_impl_props);
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
   auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
 
@@ -280,13 +282,14 @@ void Conv2DLayer::finalize(InitLayerContext &context) {
   padding = std::get<props::Padding2D>(conv_props)
               .compute(in_dim, dim, {stride[0], stride[1]});
 
-  wt_idx[ConvParams::weight] =
-    context.requestWeight(dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "filter", true);
+  wt_idx[ConvParams::weight] = context.requestWeight(
+    dim, weight_initializer, weight_regularizer, weight_regularizer_constant,
+    weight_decay, "filter", true);
 
   if (disable_bias.empty() || disable_bias.get() == false) {
-    wt_idx[ConvParams::bias] = context.requestWeight(
-      bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, "bias", true);
+    wt_idx[ConvParams::bias] =
+      context.requestWeight(bias_dim, bias_initializer, WeightRegularizer::NONE,
+                            1.0f, bias_decay, "bias", true);
   }
 
   // this output_dim must be the same with dimension of hidden
index 053b6c4..315ad9f 100644 (file)
@@ -47,6 +47,7 @@ void EmbeddingLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
   auto &weight_initializer =
     std::get<props::WeightInitializer>(*layer_impl_props);
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
 
   unsigned int in_dim = std::get<props::InDim>(embedding_props);
   unsigned int out_dim = std::get<props::OutDim>(embedding_props);
@@ -62,9 +63,9 @@ void EmbeddingLayer::finalize(InitLayerContext &context) {
   dim.width(out_dim);
   dim.batch(1);
 
-  weight_idx =
-    context.requestWeight(dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "Embedding", true);
+  weight_idx = context.requestWeight(
+    dim, weight_initializer, weight_regularizer, weight_regularizer_constant,
+    weight_decay, "Embedding", true);
 }
 
 void EmbeddingLayer::setProperty(const std::vector<std::string> &values) {
index cf6d0b6..5fa4f42 100644 (file)
@@ -49,6 +49,8 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
   auto &weight_initializer =
     std::get<props::WeightInitializer>(*layer_impl_props);
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
   auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
 
@@ -75,13 +77,14 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   TensorDim bias_dim(1, 1, 1, unit, 0b0001);
   TensorDim weight_dim(1, 1, in_dim.width(), unit, 0b0011);
 
-  weight_idx[FCParams::weight] =
-    context.requestWeight(weight_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight", true);
+  weight_idx[FCParams::weight] = context.requestWeight(
+    weight_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight", true);
 
   if (disable_bias.empty() || disable_bias.get() == false) {
-    weight_idx[FCParams::bias] = context.requestWeight(
-      bias_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, "bias", true);
+    weight_idx[FCParams::bias] =
+      context.requestWeight(bias_dim, bias_initializer, WeightRegularizer::NONE,
+                            1.0f, bias_decay, "bias", true);
   }
 }
 
index e874582..c589921 100644 (file)
@@ -72,6 +72,8 @@ void GRULayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizer>(*layer_impl_props).get();
   const float weight_regularizer_constant =
     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -110,36 +112,36 @@ void GRULayer::finalize(InitLayerContext &context) {
   // - weight_ih ( input to hidden )
   // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[GRUParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+  wt_idx[GRUParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // - weight_hh ( hidden to hidden )
   // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  wt_idx[GRUParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[GRUParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // - bias_h ( input bias, hidden bias are integrate to 1 bias )
       // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
       TensorDim bias_h_dim({NUM_GATE * unit});
-      wt_idx[GRUParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[GRUParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // - bias_ih ( input bias )
       // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
       TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[GRUParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      wt_idx[GRUParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // - bias_hh ( hidden bias )
       // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
       TensorDim bias_hh_dim({NUM_GATE * unit});
-      wt_idx[GRUParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+      wt_idx[GRUParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
index 08e8be9..d50b181 100644 (file)
@@ -278,6 +278,8 @@ void GRUCellLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizer>(*layer_impl_props).get();
   const float weight_regularizer_constant =
     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -325,36 +327,36 @@ void GRUCellLayer::finalize(InitLayerContext &context) {
   // - weight_ih ( input to hidden )
   // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[GRUCellParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+  wt_idx[GRUCellParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // - weight_hh ( hidden to hidden )
   // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  wt_idx[GRUCellParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[GRUCellParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // - bias_h ( input bias, hidden bias are integrate to 1 bias )
       // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
       TensorDim bias_h_dim({NUM_GATE * unit});
-      wt_idx[GRUCellParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[GRUCellParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // - bias_ih ( input bias )
       // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
       TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[GRUCellParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      wt_idx[GRUCellParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // - bias_hh ( hidden bias )
       // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
       TensorDim bias_hh_dim({NUM_GATE * unit});
-      wt_idx[GRUCellParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+      wt_idx[GRUCellParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
index ef2bf26..38b01f7 100644 (file)
@@ -128,9 +128,11 @@ public:
   unsigned int requestWeight(const TensorDim &dim,
                              const Tensor::Initializer init,
                              const WeightRegularizer reg, const float reg_const,
-                             const std::string &name, bool trainable = true) {
-    weights_spec.emplace_back(dim, init, reg, reg_const, clip_by_global_norm,
-                              trainable, prefix + ":" + name);
+                             const float decay, const std::string &name,
+                             bool trainable = true) {
+    weights_spec.emplace_back(dim, init, reg, reg_const, decay,
+                              clip_by_global_norm, trainable,
+                              prefix + ":" + name);
     return weights_spec.size() - 1;
   }
 
index b0a35ab..78bbcb3 100644 (file)
@@ -149,6 +149,8 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizer>(*layer_impl_props).get();
   const float weight_regularizer_constant =
     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -200,33 +202,32 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
   wt_idx[LSTMParams::weight_ih] =
     context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+                          weight_regularizer_constant, weight_decay, "weight_ih", true);
   // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
   // f, g, o
   const TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  wt_idx[LSTMParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[LSTMParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // bias_h ( input bias, hidden bias are integrate to 1 bias ) : [ 1,
       // 1, 1, NUM_GATE * unit ] -> i, f, g, o
       const TensorDim bias_h_dim({NUM_GATE * unit});
-      wt_idx[LSTMParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[LSTMParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
       const TensorDim bias_ih_dim({NUM_GATE * unit});
       wt_idx[LSTMParams::bias_ih] =
         context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
-      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g,
-      // o
-      const TensorDim bias_hh_dim({NUM_GATE * unit});
-      wt_idx[LSTMParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+                              WeightRegularizer::NONE, 1.0f, bias_decay, "bias_ih", true);
+      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
+      const TensorDim bias_hh_dim({NUM_GATE * unit});
+      wt_idx[LSTMParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
index 32c3b76..6e84220 100644 (file)
@@ -51,6 +51,8 @@ void LSTMCellLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizer>(*layer_impl_props).get();
   const float weight_regularizer_constant =
     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -118,36 +120,36 @@ void LSTMCellLayer::finalize(InitLayerContext &context) {
   // - weight_ih ( input to hidden )
   //  : [ 1, 1, feature_size, NUM_GATE x unit ] -> i, f, g, o
   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[LSTMCellParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+  wt_idx[LSTMCellParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // - weight_hh ( hidden to hidden )
   //  : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g, o
   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  wt_idx[LSTMCellParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[LSTMCellParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // - bias_h ( input bias, hidden bias are integrate to 1 bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
       TensorDim bias_h_dim({NUM_GATE * unit});
-      wt_idx[LSTMCellParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[LSTMCellParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // - bias_ih ( input bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
       TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[LSTMCellParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      wt_idx[LSTMCellParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // - bias_hh ( hidden bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
       TensorDim bias_hh_dim({NUM_GATE * unit});
-      wt_idx[LSTMCellParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+      wt_idx[LSTMCellParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
index 2453f98..0848c58 100644 (file)
@@ -91,20 +91,22 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   auto &weight_initializer =
     std::get<props::WeightInitializer>(*layer_impl_props);
   auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
 
   TensorDim fc_w_dim = {query_dim.width(), unit};
-  wt_idx[AttentionParams::fc_w] =
-    context.requestWeight(fc_w_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "fc_w", true);
+  wt_idx[AttentionParams::fc_w] = context.requestWeight(
+    fc_w_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "fc_w", true);
   TensorDim fc_bias_dim = {unit};
-  wt_idx[AttentionParams::fc_bias] =
-    context.requestWeight(fc_bias_dim, bias_initializer, weight_regularizer,
-                          weight_regularizer_constant, "fc_bias", true);
+  wt_idx[AttentionParams::fc_bias] = context.requestWeight(
+    fc_bias_dim, bias_initializer, weight_regularizer,
+    weight_regularizer_constant, bias_decay, "fc_bias", true);
 
   TensorDim fc_proj_w_dim = {unit, 3 * mol_k};
-  wt_idx[AttentionParams::fc_proj_w] =
-    context.requestWeight(fc_proj_w_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "fc_proj_w", true);
+  wt_idx[AttentionParams::fc_proj_w] = context.requestWeight(
+    fc_proj_w_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "fc_proj_w", true);
 
   TensorDim fc_out_dim = query_dim;
   fc_out_dim.width(fc_w_dim.width());
index 2a9cef2..c0aea09 100644 (file)
@@ -57,6 +57,8 @@ void RNNLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightInitializer>(*layer_impl_props);
   const Tensor::Initializer bias_initializer =
     std::get<props::BiasInitializer>(*layer_impl_props);
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -90,32 +92,32 @@ void RNNLayer::finalize(InitLayerContext &context) {
 
   // weight_ih_dim : [ 1, 1, feature_size, unit ]
   const TensorDim weight_ih_dim({feature_size, unit});
-  wt_idx[RNNParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+  wt_idx[RNNParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // weight_hh_dim : [ 1, 1, unit, unit ]
   const TensorDim weight_hh_dim({unit, unit});
-  wt_idx[RNNParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[RNNParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // bias_h_dim : [ 1, 1, 1, unit ]
       const TensorDim bias_h_dim({unit});
-      wt_idx[RNNParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[RNNParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // bias_ih_dim : [ 1, 1, 1, unit ]
       const TensorDim bias_ih_dim({unit});
-      wt_idx[RNNParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      wt_idx[RNNParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // bias_hh_dim : [ 1, 1, 1, unit ]
       const TensorDim bias_hh_dim({unit});
-      wt_idx[RNNParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+      wt_idx[RNNParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
index 7a8b903..645962b 100644 (file)
@@ -60,6 +60,8 @@ void RNNCellLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightInitializer>(*layer_impl_props);
   const Tensor::Initializer bias_initializer =
     std::get<props::BiasInitializer>(*layer_impl_props);
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -96,32 +98,32 @@ void RNNCellLayer::finalize(InitLayerContext &context) {
 
   // weight_ih_dim : [ 1, 1, feature_size, unit ]
   const TensorDim weight_ih_dim({feature_size, unit});
-  wt_idx[RNNCellParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+  wt_idx[RNNCellParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // weight_hh_dim : [ 1, 1, unit, unit ]
   const TensorDim weight_hh_dim({unit, unit});
-  wt_idx[RNNCellParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[RNNCellParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // bias_h_dim : [ 1, 1, 1, unit ]
       const TensorDim bias_h_dim({unit});
-      wt_idx[RNNCellParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[RNNCellParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // bias_ih_dim : [ 1, 1, 1, unit ]
       const TensorDim bias_ih_dim({unit});
-      wt_idx[RNNCellParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      wt_idx[RNNCellParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // bias_hh_dim : [ 1, 1, 1, unit ]
       const TensorDim bias_hh_dim({unit});
-      wt_idx[RNNCellParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+      wt_idx[RNNCellParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
index 9620541..b0a7845 100644 (file)
@@ -74,6 +74,8 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     std::get<props::WeightRegularizer>(*layer_impl_props).get();
   const float weight_regularizer_constant =
     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
+  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
+  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
   const bool disable_bias =
     std::get<props::DisableBias>(*layer_impl_props).get();
 
@@ -150,16 +152,16 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
   //  : [ 1, 1, feature_size, NUM_GATE x unit ] ->
   //  i, f, g, o
   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
-  wt_idx[ZoneoutLSTMParams::weight_ih] =
-    context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_ih", true);
+  wt_idx[ZoneoutLSTMParams::weight_ih] = context.requestWeight(
+    weight_ih_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_ih", true);
   // - weight_hh ( hidden to hidden )
   //  : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g,
   //  o
   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
-  wt_idx[ZoneoutLSTMParams::weight_hh] =
-    context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
-                          weight_regularizer_constant, "weight_hh", true);
+  wt_idx[ZoneoutLSTMParams::weight_hh] = context.requestWeight(
+    weight_hh_dim, weight_initializer, weight_regularizer,
+    weight_regularizer_constant, weight_decay, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
       // - bias_h ( input bias, hidden bias are
@@ -167,24 +169,24 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
       //  o
       TensorDim bias_h_dim({NUM_GATE * unit});
-      wt_idx[ZoneoutLSTMParams::bias_h] =
-        context.requestWeight(bias_h_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_h", true);
+      wt_idx[ZoneoutLSTMParams::bias_h] = context.requestWeight(
+        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+        "bias_h", true);
     } else {
       // - bias_ih ( input bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
       //  o
       TensorDim bias_ih_dim({NUM_GATE * unit});
-      wt_idx[ZoneoutLSTMParams::bias_ih] =
-        context.requestWeight(bias_ih_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_ih", true);
+      wt_idx[ZoneoutLSTMParams::bias_ih] = context.requestWeight(
+        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_ih", true);
       // - bias_hh ( hidden bias )
       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
       //  o
       TensorDim bias_hh_dim({NUM_GATE * unit});
-      wt_idx[ZoneoutLSTMParams::bias_hh] =
-        context.requestWeight(bias_hh_dim, bias_initializer,
-                              WeightRegularizer::NONE, 1.0f, "bias_hh", true);
+      wt_idx[ZoneoutLSTMParams::bias_hh] = context.requestWeight(
+        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+        bias_decay, "bias_hh", true);
     }
   }
 
@@ -210,7 +212,7 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
     wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
       context.requestWeight(hidden_state_zoneout_mask_dim,
                             Tensor::Initializer::NONE, WeightRegularizer::NONE,
-                            1.0f, "hidden_state_zoneout_mask", false);
+                            1.0f, 0.0f, "hidden_state_zoneout_mask", false);
   } else {
     wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
       context.requestTensor(
@@ -224,7 +226,7 @@ void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
   if (test) {
     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight(
       cell_state_zoneout_mask_dim, Tensor::Initializer::NONE,
-      WeightRegularizer::NONE, 1.0f, "cell_state_zoneout_mask", false);
+      WeightRegularizer::NONE, 1.0f, 0.0f, "cell_state_zoneout_mask", false);
   } else {
     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
       cell_state_zoneout_mask_dim, "cell_state_zoneout_mask",
index fb3fdc2..7b212e0 100644 (file)
@@ -367,7 +367,7 @@ std::vector<Weight *> Manager::requestWeights(
   size_t current_size = weights_v2.size();
 
   for (unsigned int i = 0; i < weights_spec.size(); ++i) {
-    auto &[dim, t_initializer, w_reg, w_reg_const, clip_by_global_norm,
+    auto &[dim, t_initializer, w_reg, w_reg_const, decay, clip_by_global_norm,
            need_gradient, name] = weights_spec.at(i);
     auto grad_exec_order = default_grad_exec_order;
     /**
@@ -404,7 +404,7 @@ std::vector<Weight *> Manager::requestWeights(
     }
 
     weights_v2.emplace_back(std::make_unique<Weight>(
-      var, grad, w_reg, w_reg_const, is_dependent, clip_by_global_norm));
+      var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm));
   }
 
   std::transform(weights_v2.begin() + current_size, weights_v2.end(),
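
With this, the manager forwards each weight's decay value into its Weight
object. How the stored decay is later consumed is outside this diff; the
conventional L2 weight-decay step it enables is sketched below (illustration
only, not code from this commit; grad and weight_var stand for the weight's
gradient and variable tensors):

    // conceptually, before the optimizer applies the gradient:
    //   dW <- dW + decay * W
    if (decay > 0.0f)
      grad.add_i(weight_var, decay); // grad += decay * weight_var
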
index 504215e..354a37b 100644 (file)
@@ -61,11 +61,11 @@ enum class TensorLifespan {
  * @brief Specification of the Weight as a tensor wrapper
  *
  * @details The tuple values are dimension, initializer, regularizer,
- * regularizer_constant, clip gradient constant, need_gradient property amd name
- * of the tensor object.
+ * regularizer_constant, decay, clip gradient constant, need_gradient property
+ * and name of the tensor object.
  */
 typedef std::tuple<TensorDim, Tensor::Initializer, WeightRegularizer, float,
-                   float, bool, const std::string>
+                   float, float, bool, const std::string>
   WeightSpec;
 
 /**
@@ -188,6 +188,7 @@ struct WeightSpecV2 {
   VarGradSpecV2 vg_spec; /**< variable + graident specification */
   WeightRegularizer regularizer = WeightRegularizer::NONE; /**< regularizer */
   float regularizer_constant = 0.0f; /**< regularizer constant */
+  float decay = 0.0f;                /**< decay constant */
   float clip_by_global_norm = 0.0f;  /**< clip the gradient by norm */
 };
 
index 2c53970..c69a7c5 100644 (file)
@@ -20,12 +20,12 @@ namespace nntrainer {
 
 Weight::Weight(const TensorDim &dim, const Tensor::Initializer init,
                const WeightRegularizer reg, const float reg_const,
-               const float max_norm, bool train, bool alloc_now_,
-               std::string name) :
+               const float decay_const, const float max_norm, bool train,
+               bool alloc_now_, std::string name) :
   Var_Grad(dim, init, train, alloc_now_, name),
   regularizer(reg),
   regularizer_constant(reg_const),
-  decay(0.0f),
+  decay(decay_const),
   clip_by_global_norm(max_norm) {
   if (init == Tensor::Initializer::NONE)
     throw std::invalid_argument("Weight initializer cannot be none");
index d5f2eeb..a8b0525 100644 (file)
@@ -61,8 +61,9 @@ public:
     const TensorDim &dim,
     const Tensor::Initializer init = Tensor::Initializer::XAVIER_UNIFORM,
     const WeightRegularizer reg = WeightRegularizer::NONE,
-    const float reg_const = 1.0f, const float clip_by_global_norm = 0.0f,
-    bool ng = true, bool alloc_now = false, std::string name = "");
+    const float reg_const = 1.0f, const float decay = 0.0f,
+    const float clip_by_global_norm = 0.0f, bool ng = true,
+    bool alloc_now = false, std::string name = "");
 
   /**
    * @brief Construct a new Weight object
@@ -74,10 +75,11 @@ public:
            std::get<1>(spec), // Tensor::Initializer
            std::get<2>(spec), // WeightRegularizer
            std::get<3>(spec), // WeightRegularizerConstant
-           std::get<4>(spec), // MaxNorm for clipping
-           std::get<5>(spec), // need_gradient
+           std::get<4>(spec), // weight decay constant
+           std::get<5>(spec), // MaxNorm for clipping
+           std::get<6>(spec), // need_gradient
            alloc_now,
-           std::get<6>(spec) // Name
+           std::get<7>(spec) // Name
     ) {}
 
   /**
@@ -111,13 +113,12 @@ public:
    * @param reg_const Constant multiplier for regularizer
    */
   explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg,
-                  const float reg_const, bool is_dependent = false,
-                  const float max_norm = 0.0f) :
+                  const float reg_const, const float decay,
+                  bool is_dependent = false, const float max_norm = 0.0f) :
     Var_Grad(v, g, is_dependent),
     regularizer(reg),
     regularizer_constant(reg_const),
-    // TODO: set properly
-    decay(0.0f),
+    decay(decay),
     clip_by_global_norm(max_norm) {}
 
   /**