};
protected:
+
+/**
+ * @brief check whether the current layer's weight decay type is L2 norm
+ * @return true if the weight decay type is L2 norm, false otherwise
+ */
+ bool isWeightDecayL2Norm() const {
+ return weight_decay.type == WeightDecayType::l2norm;
+ }
/**
* @brief Input Tensor
*/
Tensor input;
+
/**
 * @brief Hidden Layer Tensor which stores the
 *        forwarding result
 */
Tensor hidden;

/**
 * @brief calculate the optimizer update and apply it to Weight & Bias
* @param[in] dJdW Weight derivative
* @param[in] dJdB Bias derivative
- * @param[in] Weight Weight Tensor
- * @param[in] Bias Bais Tensor
+ * @param[in,out] Weight Weight Tensor
+ * @param[in,out] Bias Bias Tensor
 * @param[in] iteration current epoch number
 * @param[in] init_zero true if the bias is initialized to zero
- * @param[in] weight_decay weight decay type & lambda
*/
- void calculate(Tensor &djdw, Tensor &djdb, Tensor &weight, Tensor &bias,
- int iteration, bool init_zero, WeightDecayParam weight_decay);
+ void calculate(const Tensor &djdw, const Tensor &djdb, Tensor &weight,
+ Tensor &bias, int iteration, bool init_zero);
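Note: the weight_decay parameter can be dropped here because the L2-norm penalty is now folded into djdw inside the layer's backwarding (see the chained call further down); with the decay applied before the call, djdw and djdb can be taken by const reference.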
/**
 * @brief Property Enumeration
 */

Tensor FullyConnectedLayer::backwarding(Tensor derivative, int iteration) {
Tensor ret = derivative.dot(weight.transpose("0:2:1"));
- Tensor djdw = input.transpose("0:2:1").dot(derivative);
-
- opt.calculate(djdw, derivative, weight, bias, iteration, this->init_zero,
- weight_decay);
+ Tensor djdw = input.chain()
+ .transpose("0:2:1")
+ .dot(derivative)
+ .applyIf(this->isWeightDecayL2Norm(), _LIFT(add_i), weight,
+ weight_decay.lambda)
+ .run();
+
+ opt.calculate(djdw, derivative, weight, bias, iteration, this->init_zero);
return ret;
}
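For context on the chained call above: chain() presumably returns a lazy wrapper that queues tensor operations and only executes them on run(). Below is a minimal sketch of how such a wrapper, its applyIf combinator, and the _LIFT lifting macro could look. The LazyTensor class, its members, and this _LIFT definition are assumptions for illustration, not the patch's actual implementation; Tensor stands for nntrainer's tensor class.

// Minimal sketch, not the actual implementation: a LazyTensor that queues
// operations against a copy of the source tensor and flushes them on run().
#include <functional>
#include <utility>
#include <vector>

// assumed definition: lift a Tensor member function X into a callable whose
// first argument is the tensor being operated on
#define _LIFT(X)                                         \
  [](auto &&t, auto &&...args) -> decltype(auto) {       \
    return (t).X(std::forward<decltype(args)>(args)...); \
  }

class LazyTensor {
public:
  explicit LazyTensor(const Tensor &from) : target(from) {}

  // queue fn(target, args...) only when condition holds, so optional steps
  // (like the L2-norm term above) need no branching at the call site
  template <typename F, typename... Args>
  LazyTensor &applyIf(bool condition, F fn, Args... args) {
    if (condition) {
      call_chain.push_back(
        [fn, args...](Tensor &t) mutable { fn(t, args...); });
    }
    return *this;
  }

  // run the queued operations in order and return the resulting tensor
  Tensor run() {
    for (auto &f : call_chain)
      f(target);
    return target;
  }

private:
  Tensor target;
  std::vector<std::function<void(Tensor &)>> call_chain;
};

With this shape, backwarding reads as one pipeline, and the weight-decay step costs nothing when the layer has no L2 decay configured.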
return status;
}
-void Optimizer::calculate(Tensor &djdw, Tensor &djdb, Tensor &weight,
- Tensor &bias, int iteration, bool init_zero,
- WeightDecayParam weight_decay) {
+void Optimizer::calculate(const Tensor &djdw, const Tensor &djdb,
+ Tensor &weight, Tensor &bias, int iteration,
+ bool init_zero) {
Tensor djdwAvg, djdbAvg;
- if (weight_decay.type == WeightDecayType::l2norm) {
- djdw = djdw.add(weight.multiply(weight_decay.lambda));
- }
-
float ll = popt.learning_rate;
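+ // decay the learning rate exponentially every decay_steps iterations:
+ // ll = learning_rate * decay_rate^(iteration / decay_steps), integer division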
if (popt.decay_steps != -1) {
ll = ll * pow(popt.decay_rate, (iteration / popt.decay_steps));
}
- bool isL2norm = weight_decay.type == WeightDecayType::l2norm;
-
- djdwAvg = djdw.chain()
- .applyIf(isL2norm, _LIFT(add_i), weight, weight_decay.lambda)
- .average()
- .run();
-
+ djdwAvg = djdw.average();
djdbAvg = djdb.average();
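+ // feed the averaged gradients into the optimizer-type-specific update below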
switch (type) {