[batchnorm] Optimize batch normalization implementation
author    Parichay Kapoor <pk.kapoor@samsung.com>
          Fri, 24 Sep 2021 09:08:38 +0000 (18:08 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
          Fri, 1 Oct 2021 04:28:07 +0000 (13:28 +0900)
This patch optimizes the batch normalization implementation. It reduces
temporary memory allocations and speeds up layer execution. With this
patch, resnet18 runtime improves by approximately 5%.
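
The gain comes from preferring the in-place and out-parameter variants
of the tensor ops (e.g. input_.subtract(mu, deviation) instead of
deviation = input_.subtract(mu)), so results land in tensors the layer
context already owns instead of freshly allocated temporaries. Below is
a minimal sketch of the pattern; MiniTensor is a hypothetical stand-in,
not nntrainer's actual Tensor class:

  // Sketch of the allocation-avoiding pattern adopted in this patch.
  #include <cstddef>
  #include <vector>

  struct MiniTensor {
    std::vector<float> data;

    explicit MiniTensor(std::size_t n) : data(n, 0.0f) {}

    // Out-parameter variant: writes into a caller-provided tensor, so a
    // buffer the layer context already owns can be reused every pass.
    MiniTensor &subtract(const MiniTensor &other, MiniTensor &out) const {
      for (std::size_t i = 0; i < data.size(); ++i)
        out.data[i] = data[i] - other.data[i];
      return out;
    }

    // Allocating variant: creates a fresh temporary on every call.
    MiniTensor subtract(const MiniTensor &other) const {
      MiniTensor out(data.size());
      subtract(other, out);
      return out;
    }
  };

  int main() {
    MiniTensor input(1024), mu(1024), deviation(1024);

    MiniTensor tmp = input.subtract(mu); // before: new allocation per pass
    input.subtract(mu, deviation);       // after: reuses `deviation`

    (void)tmp;
    return 0;
  }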

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/bn_layer.cpp

index 9a553ec050e076d60b25dd9ac0250bf100d9294b..b2ecc18544ed308ee06f10cd35408dd95848ac36 100644
@@ -121,7 +121,8 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
     Tensor cmu = input_.average(axes_to_reduce);
     input_.subtract(cmu, deviation);
 
-    cvar = deviation.pow(2.0f).average(axes_to_reduce);
+    Tensor t1 = deviation.pow(2.0f);
+    cvar = t1.average(axes_to_reduce);
 
     mu.multiply_i(momentum);
     mu.add_i(cmu, 1 - momentum);
@@ -129,14 +130,14 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
     var.add_i(cvar, 1 - momentum);
 
     cvar.add_i(epsilon);
-    invstd = cvar.pow(-0.5f);
+    cvar.pow(-0.5f, invstd);
   } else {
-    deviation = input_.subtract(mu);
-    invstd = var.add(epsilon);
+    input_.subtract(mu, deviation);
+    var.add(epsilon, invstd);
     invstd.pow_i(-0.5f);
   }
 
-  hidden_ = deviation.multiply(invstd, hidden_);
+  deviation.multiply(invstd, hidden_);
   hidden_.multiply_i(gamma);
   hidden_.add_i(beta);
 }
@@ -148,21 +149,16 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
   Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
 
-  int N = 1;
-  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
-  const TensorDim &in_dim = input.getDim();
-  for (auto &axis : axes_to_reduce) {
-    N *= in_dim.getTensorDim(axis);
-  }
-
   Tensor dx_1 = gamma.multiply(invstd);
-  Tensor dx_2 = deriv.multiply(N);
-  dx_2.subtract_i(deriv.sum(axes_to_reduce));
-  dx_2.subtract_i(deviation.divide(cvar).multiply(
-    deviation.multiply(deriv).sum(axes_to_reduce)));
+  Tensor dx_2 = deriv.subtract(deriv.average(axes_to_reduce));
+
+  Tensor t1 = deviation.multiply(deriv);
+  Tensor t2 = t1.average(axes_to_reduce);
+  deviation.divide_i(cvar);
+  deviation.multiply_i(t2);
+  dx_2.subtract_i(deviation);
 
-  dx = dx_2.multiply(dx_1, dx);
-  dx.divide_i(N);
+  dx_2.multiply(dx_1, dx);
 }
 
 void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
@@ -172,10 +168,10 @@ void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
   Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
 
-  dbeta = deriv.sum(axes_to_reduce);
-  Tensor dev = deviation.multiply(invstd);
-  dev.multiply_i(deriv);
-  dgamma = dev.sum(axes_to_reduce);
+  deriv.sum(axes_to_reduce, dbeta);
+  Tensor dev = deviation.multiply(deriv);
+  dev.multiply_i(invstd);
+  dev.sum(axes_to_reduce, dgamma);
 }
 
 void BatchNormalizationLayer::exportTo(Exporter &exporter,
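
For reference, the calcDerivative rewrite above is algebraically
equivalent to the removed code: the explicit factor N (the product of
the reduced axes' sizes) cancelled against the final divide_i(N), which
is the same as folding 1/N into average(). A sketch of the equivalence
in LaTeX notation, where dy is the incoming derivative, the bar denotes
the average over axes_to_reduce, the sums run over the same axes, and
sigma^2 + epsilon is what cvar holds after forwarding():

  \frac{\partial L}{\partial x}
    = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}} \cdot \frac{1}{N}
      \Big( N\,dy - \sum dy
            - \frac{x-\mu}{\sigma^2+\epsilon} \sum dy\,(x-\mu) \Big)
    = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}
      \Big( dy - \overline{dy}
            - \frac{x-\mu}{\sigma^2+\epsilon}\, \overline{dy\,(x-\mu)} \Big)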