Tensor &beta = weightAt(static_cast<int>(BNParams::beta)).getVariableRef();
input = *in[0];
- /// @todo change trainable #524
+ /// @todo change trainable flag to an explicit train/eval mode #524
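// In training mode, compute the batch mean over axes_to_reduce and cache the
// deviation (input - mean); deviation is reused below when computing dgamma.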
if (trainable) {
Tensor cmu = input.average(axes_to_reduce);
deviation = input.subtract(cmu);
Tensor deriv = *derivative[0];
int N = 1;
-
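// N: number of elements each statistic is averaged over, i.e. the product of
// the input dimensions along axes_to_reduce.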
for (auto &axis : axes_to_reduce) {
N *= input_dim[0].getTensorDim(axis);
}
- dbeta = deriv.sum(axes_to_reduce);
- dgamma = deviation.multiply(invstd).multiply(deriv).sum(axes_to_reduce);
-
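// Input gradient as computed here: dx = gamma * invstd * (N * dL/dy - sum(dL/dy)) / N,
// with the sum taken over axes_to_reduce; invstd is assumed to be the cached
// 1 / sqrt(variance + epsilon) from forwarding.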
Tensor dx_1 = gamma.multiply(invstd);
Tensor dx_2 = deriv.multiply(N);
dx_2.subtract_i(deriv.sum(axes_to_reduce));
Tensor dx = dx_2.multiply(dx_1);
dx.divide_i(N);
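// Parameter gradients (dbeta, dgamma) and the optimizer update are only needed
// when the layer is trainable; dx is always computed and returned so that
// earlier layers still receive their gradient.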
- opt->apply_gradients(weight_list, num_weights, iteration);
+ if (trainable) {
+ dbeta = deriv.sum(axes_to_reduce);
+ dgamma = deviation.multiply(invstd).multiply(deriv).sum(axes_to_reduce);
+ opt->apply_gradients(weight_list, num_weights, iteration);
+ }
return {MAKE_SHARED_TENSOR(std::move(dx))};
}