Tensor cmu = input_.average(axes_to_reduce);
input_.subtract(cmu, deviation);
- cvar = deviation.pow(2.0f).average(axes_to_reduce);
+ Tensor t1 = deviation.pow(2.0f);
+ cvar = t1.average(axes_to_reduce);
mu.multiply_i(momentum);
mu.add_i(cmu, 1 - momentum);
var.add_i(cvar, 1 - momentum);
cvar.add_i(epsilon);
- invstd = cvar.pow(-0.5f);
+ cvar.pow(-0.5f, invstd);
} else {
- deviation = input_.subtract(mu);
- invstd = var.add(epsilon);
+ input_.subtract(mu, deviation);
+ var.add(epsilon, invstd);
invstd.pow_i(-0.5f);
}
- hidden_ = deviation.multiply(invstd, hidden_);
+ deviation.multiply(invstd, hidden_);
hidden_.multiply_i(gamma);
hidden_.add_i(beta);
}
Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
- int N = 1;
- const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
- const TensorDim &in_dim = input.getDim();
- for (auto &axis : axes_to_reduce) {
- N *= in_dim.getTensorDim(axis);
- }
-
Tensor dx_1 = gamma.multiply(invstd);
- Tensor dx_2 = deriv.multiply(N);
- dx_2.subtract_i(deriv.sum(axes_to_reduce));
- dx_2.subtract_i(deviation.divide(cvar).multiply(
- deviation.multiply(deriv).sum(axes_to_reduce)));
+ Tensor dx_2 = deriv.subtract(deriv.average(axes_to_reduce));
+
+ Tensor t1 = deviation.multiply(deriv);
+ Tensor t2 = t1.average(axes_to_reduce);
+ deviation.divide_i(cvar);
+ deviation.multiply_i(t2);
+ dx_2.subtract_i(deviation);
- dx = dx_2.multiply(dx_1, dx);
- dx.divide_i(N);
+ dx_2.multiply(dx_1, dx);
}
void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
- dbeta = deriv.sum(axes_to_reduce);
- Tensor dev = deviation.multiply(invstd);
- dev.multiply_i(deriv);
- dgamma = dev.sum(axes_to_reduce);
+ deriv.sum(axes_to_reduce, dbeta);
+ Tensor dev = deviation.multiply(deriv);
+ dev.multiply_i(invstd);
+ dev.sum(axes_to_reduce, dgamma);
}
void BatchNormalizationLayer::exportTo(Exporter &exporter,