From: Parichay Kapoor <pk.kapoor@samsung.com>
Date: Thu, 10 Feb 2022 02:05:15 +0000 (+0900)
Subject: [weight-decay] Enable for batch normalization
X-Git-Tag: accepted/tizen/unified/20220323.062643~25
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d95930d8d44fa64e8f34321657ec7eb9cbd96a73;p=platform%2Fcore%2Fml%2Fnntrainer.git

[weight-decay] Enable for batch normalization

Enable weight decay for batch normalization.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
---

diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index 107bcd8..109e491 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -50,7 +50,8 @@ BatchNormalizationLayer::BatchNormalizationLayer() :
   divider(0),
   bn_props(props::Epsilon(), props::BNPARAMS_MU_INIT(),
            props::BNPARAMS_VAR_INIT(), props::BNPARAMS_BETA_INIT(),
-           props::BNPARAMS_GAMMA_INIT(), props::Momentum(), props::Axis()) {
+           props::BNPARAMS_GAMMA_INIT(), props::Momentum(), props::Axis(),
+           props::WeightDecay(), props::BiasDecay()) {
   wt_idx.fill(std::numeric_limits<unsigned>::max());
 }
 
@@ -65,6 +66,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   auto &bnparams_var = std::get<props::BNPARAMS_VAR_INIT>(bn_props);
   auto &bnparams_beta = std::get<props::BNPARAMS_BETA_INIT>(bn_props);
   auto &bnparams_gamma = std::get<props::BNPARAMS_GAMMA_INIT>(bn_props);
+  auto &weight_decay = std::get<props::WeightDecay>(bn_props);
+  auto &bias_decay = std::get<props::BiasDecay>(bn_props);
 
   std::vector<TensorDim> output_dims(1);
 
@@ -105,11 +108,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   wt_idx[BNParams::var] =
     context.requestWeight(dim, bnparams_var, WeightRegularizer::NONE, 1.0f,
                           0.0f, "moving_variance", false);
-  // TODO: setup decay for gamma and beta
-  wt_idx[BNParams::gamma] = context.requestWeight(
-    dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f, 0.0f, "gamma", true);
-  wt_idx[BNParams::beta] = context.requestWeight(
-    dim, bnparams_beta, WeightRegularizer::NONE, 1.0f, 0.0f, "beta", true);
+  wt_idx[BNParams::gamma] =
+    context.requestWeight(dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f,
+                          weight_decay, "gamma", true);
+  wt_idx[BNParams::beta] =
+    context.requestWeight(dim, bnparams_beta, WeightRegularizer::NONE, 1.0f,
+                          bias_decay, "beta", true);
 
   /**
    * caches the deviation -> input - avg(input)
diff --git a/nntrainer/layers/bn_layer.h b/nntrainer/layers/bn_layer.h
index 5ecee09..281b80f 100644
--- a/nntrainer/layers/bn_layer.h
+++ b/nntrainer/layers/bn_layer.h
@@ -125,7 +125,7 @@ private:
   std::array<unsigned int, 9> wt_idx; /**< indices of the weights and tensors */
   std::tuple<props::Epsilon, props::BNPARAMS_MU_INIT, props::BNPARAMS_VAR_INIT,
              props::BNPARAMS_BETA_INIT, props::BNPARAMS_GAMMA_INIT,
-             props::Momentum, props::Axis>
+             props::Momentum, props::Axis, props::WeightDecay, props::BiasDecay>
     bn_props;
 };