model_graph.applyGradients(
node.get(), [iteration, opt_ = opt.get()](Weight &w) {
w.calcRegularizationGradient();
+ w.calcWeightDecayGradient();
RunOptimizerContext opt_context(&w, iteration);
opt_->applyGradient(opt_context);
});
}
/**
- * @brief Calculate gradient from the regularizaiton of the weight
+ * @brief Calculate gradient from the regularization of the weight
*/
void calcRegularizationGradient() {
if (isWeightRegularizerL2Norm())
}
/**
- * @brief Apply the gradient to the weight
+ * @brief Calculate gradient from the decay of the weight
+ *
+ * @note Calls applyWeightDecay() only when isWeightDecay() returns true;
+ * otherwise this function does nothing.
*/
- void applyGradient(double lr) {
+ void calcWeightDecayGradient() {
if (isWeightDecay())
applyWeightDecay();
-
- var->add_i(*grad.get(), -lr);
}
/**
+ * @brief Apply the gradient to the weight
+ *
+ * @param lr value that is negated and passed, together with the gradient,
+ * to var->add_i() (presumably the learning rate -- confirm with callers)
+ */
+ void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
+
+ /**
* @brief Check if the gradient is supposed to be clipped by global norm with
* the given max_norm value
*