[weight-decay] Enable for batch normalization
author Parichay Kapoor <pk.kapoor@samsung.com>
Thu, 10 Feb 2022 02:05:15 +0000 (11:05 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Thu, 10 Feb 2022 04:29:44 +0000 (13:29 +0900)
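Enable weight decay for batch normalization: the gamma weight is now requested with the WeightDecay property and the beta weight with the BiasDecay property, instead of a hard-coded decay of 0.0f.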

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/bn_layer.cpp
nntrainer/layers/bn_layer.h

index 107bcd8..109e491 100644 (file)
@@ -50,7 +50,8 @@ BatchNormalizationLayer::BatchNormalizationLayer() :
   divider(0),
   bn_props(props::Epsilon(), props::BNPARAMS_MU_INIT(),
            props::BNPARAMS_VAR_INIT(), props::BNPARAMS_BETA_INIT(),
-           props::BNPARAMS_GAMMA_INIT(), props::Momentum(), props::Axis()) {
+           props::BNPARAMS_GAMMA_INIT(), props::Momentum(), props::Axis(),
+           props::WeightDecay(), props::BiasDecay()) {
   wt_idx.fill(std::numeric_limits<unsigned>::max());
 }
 
@@ -65,6 +66,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   auto &bnparams_var = std::get<props::BNPARAMS_VAR_INIT>(bn_props);
   auto &bnparams_beta = std::get<props::BNPARAMS_BETA_INIT>(bn_props);
   auto &bnparams_gamma = std::get<props::BNPARAMS_GAMMA_INIT>(bn_props);
+  auto &weight_decay = std::get<props::WeightDecay>(bn_props);
+  auto &bias_decay = std::get<props::BiasDecay>(bn_props);
 
   std::vector<TensorDim> output_dims(1);
 
@@ -105,11 +108,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   wt_idx[BNParams::var] =
     context.requestWeight(dim, bnparams_var, WeightRegularizer::NONE, 1.0f,
                           0.0f, "moving_variance", false);
-  // TODO: setup decay for gamma and beta
-  wt_idx[BNParams::gamma] = context.requestWeight(
-    dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f, 0.0f, "gamma", true);
-  wt_idx[BNParams::beta] = context.requestWeight(
-    dim, bnparams_beta, WeightRegularizer::NONE, 1.0f, 0.0f, "beta", true);
+  wt_idx[BNParams::gamma] =
+    context.requestWeight(dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f,
+                          weight_decay, "gamma", true);
+  wt_idx[BNParams::beta] =
+    context.requestWeight(dim, bnparams_beta, WeightRegularizer::NONE, 1.0f,
+                          bias_decay, "beta", true);
 
   /**
    * caches the deviation -> input - avg(input)
index 5ecee09..281b80f 100644 (file)
@@ -125,7 +125,7 @@ private:
   std::array<unsigned int, 9> wt_idx; /**< indices of the weights and tensors */
   std::tuple<props::Epsilon, props::BNPARAMS_MU_INIT, props::BNPARAMS_VAR_INIT,
              props::BNPARAMS_BETA_INIT, props::BNPARAMS_GAMMA_INIT,
-             props::Momentum, props::Axis>
+             props::Momentum, props::Axis, props::WeightDecay, props::BiasDecay>
     bn_props;
 };