[batchnorm] Optimize batch norm layer
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 24 Sep 2021 11:58:10 +0000 (20:58 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 1 Oct 2021 04:28:07 +0000 (13:28 +0900)
This patch optimizes batch norm layer and tries to share the
calculations performed in calcGradient and calcDerivative.
- reuse dbeta and dgamma calculations
- reduce number of required temporary variables
- create all the required tensor variables with context
- add support for checking if the layer is trainable or not via run
context
- support average operation with the output tensor already allocated
- this patch reduces as much as memory as possible without sacrificing
speed. more memory optimization is possible at the expense of speed but
has been ommitted for now.

Note: this patch has slight improvement in performance, and adds no
extra operations.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/bn_layer.cpp
nntrainer/layers/bn_layer.h
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/layers/layer_node.cpp
nntrainer/layers/time_dist.cpp
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
test/unittest/layers/layers_golden_tests.cpp

index b2ecc18..443a811 100644 (file)
@@ -32,11 +32,22 @@ namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
-enum BNParams { mu, var, gamma, beta, deviation };
+enum BNParams {
+  mu,
+  var,
+  gamma,
+  beta,
+  deviation,
+  invstd,
+  t_full_fw,
+  t_full_bw,
+  t_reduced
+};
 
 BatchNormalizationLayer::BatchNormalizationLayer(int axis_) :
   Layer(),
   axis(axis_),
+  divider(0),
   wt_idx({0}),
   bn_props(props::Epsilon(), props::BNPARAMS_MU_INIT(),
            props::BNPARAMS_VAR_INIT(), props::BNPARAMS_BETA_INIT(),
@@ -66,11 +77,21 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   if (axis == -1)
     axis = in_dim.channel() > 1 ? 1 : 3;
 
+  /**
+   * @todo This can be speedup by employing transpose for convolution. With
+   * transpose, the channel dimension can be made last, and the remaining
+   * dimensions can be squeezed. This would allow the sum and average to be
+   * faster, and no temporary allocations inside them.
+   */
+
   dim.setTensorDim(axis, in_dim.getTensorDim(axis));
 
+  divider = 1;
   for (int i = 0; i < 4; ++i) {
-    if (axis != i)
+    if (axis != i) {
       axes_to_reduce.push_back(i);
+      divider *= in_dim.getTensorDim(i);
+    }
   }
 
   wt_idx[BNParams::mu] =
@@ -86,9 +107,33 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
     context.requestWeight(dim, bnparams_beta, WeightRegularizer::NONE, 1.0f,
                           context.getName() + ":beta", true);
 
+  /**
+   * caches the deviation -> input - avg(input)
+   * @todo check if avoiding this storage and adding dependency on input (no
+   * more in-place calculation) can save memory during memory optimization.
+   */
   wt_idx[BNParams::deviation] = context.requestTensor(
     in_dim, context.getName() + ":deviation", Tensor::Initializer::NONE, false,
     TensorLifespan::ITERATION_LIFESPAN);
+  /** caches the inverse standard deviation */
+  wt_idx[BNParams::invstd] = context.requestTensor(
+    dim, context.getName() + ":invstd", Tensor::Initializer::NONE, false,
+    TensorLifespan::ITERATION_LIFESPAN);
+  /**
+   * Temporary tensor to store the reduced tensors along the axes_to_reduce.
+   * This is further used to cache variance + epsilon as well.
+   */
+  wt_idx[BNParams::t_reduced] = context.requestTensor(
+    dim, context.getName() + ":tensor_reduced", Tensor::Initializer::NONE,
+    false, TensorLifespan::ITERATION_LIFESPAN);
+  /** Temporary tensor to store the full sized tensors in forwarding. */
+  wt_idx[BNParams::t_full_fw] = context.requestTensor(
+    in_dim, context.getName() + ":tensor_full_fw", Tensor::Initializer::NONE,
+    false, TensorLifespan::FORWARD_FUNC_LIFESPAN);
+  /** Temporary tensor to store the full sized tensors in backwarding. */
+  wt_idx[BNParams::t_full_bw] = context.requestTensor(
+    in_dim, context.getName() + ":tensor_full_back", Tensor::Initializer::NONE,
+    false, TensorLifespan::BACKWARD_FUNC_LIFESPAN);
 }
 
 void BatchNormalizationLayer::setProperty(
@@ -112,20 +157,23 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
   Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
+  Tensor &invstd = context.getTensor(wt_idx[BNParams::invstd]);
 
-  if (training) {
-    /**
-     * @todo support average with preallocated tensors,
-     * and then register cmu as a temporary tensor
-     */
-    Tensor cmu = input_.average(axes_to_reduce);
-    input_.subtract(cmu, deviation);
+  /** @todo these are not needed for inference, support optimizing these */
+  Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
+  Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full_fw]);
+  Tensor cvar = t_reduced; /** cache the variance in this tensor for backward */
 
-    Tensor t1 = deviation.pow(2.0f);
-    cvar = t1.average(axes_to_reduce);
+  if (training) {
+    input_.average(axes_to_reduce, t_reduced);
+    input_.subtract(t_reduced, deviation);
 
     mu.multiply_i(momentum);
-    mu.add_i(cmu, 1 - momentum);
+    mu.add_i(t_reduced, 1 - momentum);
+
+    t_full = deviation.pow(2.0f);
+    t_full.average(axes_to_reduce, cvar);
+
     var.multiply_i(momentum);
     var.add_i(cvar, 1 - momentum);
 
@@ -133,6 +181,7 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
     cvar.pow(-0.5f, invstd);
   } else {
     input_.subtract(mu, deviation);
+    /** @todo do below 2 lines only for first iteration */
     var.add(epsilon, invstd);
     invstd.pow_i(-0.5f);
   }
@@ -148,30 +197,49 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {
   Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
   Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
   Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
+  Tensor &invstd = context.getTensor(wt_idx[BNParams::invstd]);
+
+  Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
+  Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full_bw]);
+  Tensor cvar = t_reduced;
+
+  if (context.getTrainable()) {
+    /**
+     * This implementation depends on the pre-calculated dbeta calculated.
+     */
+    Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
+    t_reduced = dbeta.divide(divider);
+  } else {
+    t_reduced = deriv.average(axes_to_reduce);
+  }
+
+  deriv.subtract(t_reduced, dx);
 
-  Tensor dx_1 = gamma.multiply(invstd);
-  Tensor dx_2 = deriv.subtract(deriv.average(axes_to_reduce));
+  deviation.multiply(deriv, t_full);
+  t_full.average(axes_to_reduce, t_reduced);
+  t_reduced.divide_i(cvar);
+  deviation.multiply_i(t_reduced);
+  dx.subtract_i(deviation);
 
-  Tensor t1 = deviation.multiply(deriv);
-  Tensor t2 = t1.average(axes_to_reduce);
-  deviation.divide_i(cvar);
-  deviation.multiply_i(t2);
-  dx_2.subtract_i(deviation);
+  if (context.getTrainable()) {
+    /**
+     * This calculates dgamma tensor.
+     */
+    Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
+    t_full.multiply_i(invstd);
+    t_full.sum(axes_to_reduce, dgamma);
+  }
 
-  dx_2.multiply(dx_1, dx);
+  invstd.multiply_i(gamma);
+  dx.multiply_i(invstd);
 }
 
 void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
-
-  Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
+  /** dgamma is calculated in calcDerivative. dbeta is calculated here */
   Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
   Tensor &deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
 
   deriv.sum(axes_to_reduce, dbeta);
-  Tensor dev = deviation.multiply(deriv);
-  dev.multiply_i(invstd);
-  dev.sum(axes_to_reduce, dgamma);
 }
 
 void BatchNormalizationLayer::exportTo(Exporter &exporter,
index 4e7d0f8..c985564 100644 (file)
@@ -113,14 +113,11 @@ public:
   inline static const std::string type = "batch_normalization";
 
 private:
-  Tensor cvar; /**< training variance saved in bn_layer::forwarding and used in
-                    bn_layer::calcDerivative */
-  Tensor invstd; /**<  inversed training std for backward pass */
-
-  int axis; /**< Target axis, axis inferred at initialize when -1 */
+  int axis;      /**< Target axis, axis inferred at initialize when -1 */
+  float divider; /**< size of the axes of the reduced */
 
   std::vector<unsigned int> axes_to_reduce; /**< target axes to reduce */
-  std::array<unsigned int, 5> wt_idx; /**< indices of the weights and tensors */
+  std::array<unsigned int, 9> wt_idx; /**< indices of the weights and tensors */
   std::tuple<props::Epsilon, props::BNPARAMS_MU_INIT, props::BNPARAMS_VAR_INIT,
              props::BNPARAMS_BETA_INIT, props::BNPARAMS_GAMMA_INIT,
              props::Momentum>
index 99c8400..40c85e9 100644 (file)
@@ -15,8 +15,8 @@
 #include <weight.h>
 
 namespace nntrainer {
-RunLayerContext::RunLayerContext(const std::string &name, float l,
-                                 const std::vector<Weight *> &w,
+RunLayerContext::RunLayerContext(const std::string &name, bool trainable,
+                                 float l, const std::vector<Weight *> &w,
                                  const std::vector<Var_Grad *> &in,
                                  const std::vector<Var_Grad *> &out,
                                  const std::vector<Var_Grad *> &t) :
@@ -26,6 +26,7 @@ RunLayerContext::RunLayerContext(const std::string &name, float l,
   outputs(out),
   tensors(t) {
   std::get<props::Name>(props).set(name);
+  std::get<props::Trainable>(props).set(trainable);
   NNTR_THROW_IF(!readyToUse(), std::invalid_argument)
     << "run context is not ready to use upon creation";
 }
index a09a040..db27f5d 100644 (file)
@@ -295,6 +295,9 @@ private:
  * structures with memory allocated or support to allocate any new memory, but
  * rather only support storing specifications based on which memory will be
  * allocated later.
+ *
+ * @todo Check the caller of the getTensor() and set restrictions on the tensors
+ * to be accessed based on which function is requesting it.
  */
 class RunLayerContext {
 public:
@@ -307,7 +310,7 @@ public:
    * @param out outputs of the layer
    * @param t extra tensors of the layer
    */
-  RunLayerContext(const std::string &name, float l,
+  RunLayerContext(const std::string &name, bool trainable, float l,
                   const std::vector<Weight *> &w,
                   const std::vector<Var_Grad *> &in,
                   const std::vector<Var_Grad *> &out,
@@ -573,6 +576,13 @@ public:
   const std::string &getName() const { return std::get<props::Name>(props); }
 
   /**
+   * @brief   get trainable by the layer
+   *
+   * @return trainable of the layer
+   */
+  bool getTrainable() const { return std::get<props::Trainable>(props); }
+
+  /**
    * @brief   check if run context is set and is ready to use
    *
    * @return true if ready, else false
@@ -580,8 +590,8 @@ public:
   bool readyToUse() const;
 
 private:
-  std::tuple<props::Name> props; /**< props of the layer */
-  float loss;                    /**< loss of the layer */
+  std::tuple<props::Name, props::Trainable> props; /**< props of the layer */
+  float loss;                                      /**< loss of the layer */
 
   std::vector<Weight *> weights;   /**< weights of the layer */
   std::vector<Var_Grad *> inputs;  /**< inputs of the layer */
index 8b0fb8b..9d64f01 100644 (file)
@@ -485,7 +485,8 @@ void LayerNode::calcDerivative() {
  */
 void LayerNode::calcGradient() {
   START_PROFILE(calc_grad_event_key);
-  layer->calcGradient(*run_context);
+  if (getTrainable())
+    layer->calcGradient(*run_context);
   END_PROFILE(calc_grad_event_key);
 }
 
@@ -537,8 +538,8 @@ void LayerNode::configureRunContext(const std::vector<Weight *> &weights,
                                     const std::vector<Var_Grad *> &inputs,
                                     const std::vector<Var_Grad *> &outputs,
                                     const std::vector<Var_Grad *> &tensors) {
-  run_context = std::make_unique<RunLayerContext>(getName(), 0.0f, weights,
-                                                  inputs, outputs, tensors);
+  run_context = std::make_unique<RunLayerContext>(
+    getName(), getTrainable(), 0.0f, weights, inputs, outputs, tensors);
 }
 
 /**
index 80e9894..bb7799c 100644 (file)
@@ -248,9 +248,9 @@ void TimeDistLayer::forwarding(RunLayerContext &context, bool training) {
       out_var.initializeGradient(label_iter);
     }
 
-    RunLayerContext dist_context(context.getName(), context.getLoss(),
-                                 getWeightsForContext(), {&in_var}, {&out_var},
-                                 getTensorsForContext());
+    RunLayerContext dist_context(context.getName(), context.getTrainable(),
+                                 context.getLoss(), getWeightsForContext(),
+                                 {&in_var}, {&out_var}, getTensorsForContext());
 
     dist_layer->forwarding(dist_context, training);
   }
@@ -294,9 +294,9 @@ void TimeDistLayer::calcDerivative(RunLayerContext &context) {
     out_var.initializeGradient(d_iter);
     out_var.initializeVariable(hval_iter);
 
-    RunLayerContext dist_context(context.getName(), context.getLoss(),
-                                 getWeightsForContext(), {&in_var}, {&out_var},
-                                 getTensorsForContext());
+    RunLayerContext dist_context(context.getName(), context.getTrainable(),
+                                 context.getLoss(), getWeightsForContext(),
+                                 {&in_var}, {&out_var}, getTensorsForContext());
 
     dist_layer->calcDerivative(dist_context);
   }
@@ -346,9 +346,9 @@ void TimeDistLayer::calcGradient(RunLayerContext &context) {
     in_var.initializeVariable(in_iter);
     out_var.initializeGradient(d_iter);
 
-    RunLayerContext dist_context(context.getName(), context.getLoss(),
-                                 getWeightsForContext(), {&in_var}, {&out_var},
-                                 getTensorsForContext());
+    RunLayerContext dist_context(context.getName(), context.getTrainable(),
+                                 context.getLoss(), getWeightsForContext(),
+                                 {&in_var}, {&out_var}, getTensorsForContext());
 
     dist_layer->calcGradient(dist_context);
   }
@@ -389,9 +389,9 @@ void TimeDistLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     fillWeightsFromContext(context);
     fillTensorsFromContext(context);
 
-    RunLayerContext dist_context(context.getName(), context.getLoss(),
-                                 getWeightsForContext(), {&in_var}, {&out_var},
-                                 getTensorsForContext());
+    RunLayerContext dist_context(context.getName(), context.getTrainable(),
+                                 context.getLoss(), getWeightsForContext(),
+                                 {&in_var}, {&out_var}, getTensorsForContext());
 
     dist_layer->setBatch(dist_context, batch);
 
index 9abc409..6072034 100644 (file)
@@ -1180,20 +1180,36 @@ void Tensor::read(std::ifstream &file) {
  * @brief Calculate average value according to the axis.
  */
 Tensor Tensor::average(unsigned int axis) const {
+  Tensor t;
+  return average(axis, t);
+}
+
+/**
+ * @brief Calculate average value according to the axis.
+ */
+Tensor &Tensor::average(unsigned int axis, Tensor &output) const {
   if (axis >= TensorDim::MAXDIM)
     throw std::out_of_range(
       "negative axis or axis more then MAXDIM is invalid");
 
   unsigned int axis_size = dim.getDim()[axis];
   if (axis_size == 1)
-    return this->clone();
+    output.copy(*this);
+  else
+    this->sum(axis, output, 1.0 / ((float)axis_size));
 
-  return this->sum(axis, 1.0 / ((float)axis_size));
+  return output;
 }
 
 Tensor Tensor::average(const std::vector<unsigned int> &axes) const {
+  Tensor t;
+  return average(axes, t);
+}
+
+Tensor &Tensor::average(const std::vector<unsigned int> &axes,
+                        Tensor &output) const {
   if (axes.empty())
-    return this->average();
+    return this->average(output);
 
   TensorDim ret_shape;
   for (const auto &idx : axes) {
@@ -1203,7 +1219,7 @@ Tensor Tensor::average(const std::vector<unsigned int> &axes) const {
     ret_shape.setTensorDim(idx, dim.getTensorDim(idx));
   }
 
-  return this->sum(axes, 1.0 / (float)ret_shape.getDataLen());
+  return this->sum(axes, output, 1.0 / (float)ret_shape.getDataLen());
 }
 
 /**
@@ -1215,6 +1231,15 @@ Tensor Tensor::average() const {
   return result.average(3);
 }
 
+/**
+ * @brief Calculate average value according to the axis.
+ */
+Tensor &Tensor::average(Tensor &output) const {
+  Tensor result = *this;
+  result.reshape({1, 1, 1, dim.getDataLen()});
+  return result.average(3, output);
+}
+
 void Tensor::setValue(float val) {
   float *data = getData();
   std::fill(data, data + size(), val);
index 7c9a64c..5dc1e92 100644 (file)
@@ -597,6 +597,7 @@ public:
    */
   Tensor &sum(const std::vector<unsigned int> &axes, Tensor &output,
               float alpha = 1.0) const;
+
   /**
    * @brief     Averaging the Tensor elements according to the axis
    *            0 : batch direction
@@ -606,6 +607,12 @@ public:
    * @retval    Calculated Tensor
    */
   Tensor average(unsigned int axis) const;
+  /**
+   * @brief     Averaging the Tensor elements according to the axis
+   *
+   * @retval    Calculated Tensor
+   */
+  Tensor &average(unsigned int axis, Tensor &output) const;
 
   /**
    * @brief average all the Tensor by multiple axes
@@ -616,12 +623,27 @@ public:
   Tensor average(const std::vector<unsigned int> &axes) const;
 
   /**
+   * @brief average all the Tensor by multiple axes
+   *
+   * @param axes axes to sum along
+   * @param output output tensor
+   * @return Tensor
+   */
+  Tensor &average(const std::vector<unsigned int> &axes, Tensor &output) const;
+
+  /**
    * @brief     Averaging the Tensor elements by all axis
    * @retval    Calculated Tensor
    */
   Tensor average() const;
 
   /**
+   * @brief     Averaging the Tensor elements by all axis
+   * @retval    Calculated Tensor
+   */
+  Tensor &average(Tensor &output) const;
+
+  /**
    * @brief     Anchor a starting point to defer following evaluation
    * @retval    LazyTensor class that can be used with run();
    */
index dac2641..eef15e9 100644 (file)
@@ -131,8 +131,8 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) {
   };
 
   auto rc =
-    RunLayerContext("golden", 0.0f, create_view(weights), create_view(ins),
-                    create_view(outs), create_view(tensors));
+    RunLayerContext("golden", true, 0.0f, create_view(weights),
+                    create_view(ins), create_view(outs), create_view(tensors));
 
   auto num_outputs = rc.getNumOutputs();