[layer normalization] implement layer normalization
author  hyeonseok lee <hs89.lee@samsung.com>
Wed, 27 Jul 2022 10:02:51 +0000 (19:02 +0900)
committer  Jijoong Moon <jijoong.moon@samsung.com>
Wed, 7 Sep 2022 14:18:57 +0000 (23:18 +0900)
 - implement layer normalization layer based on the existing batch normalization layer (a sketch of the forward computation follows below)

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
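The layer normalizes each sample over the user-specified axes, i.e. with per-sample
statistics, unlike batch normalization which averages over the batch dimension. A
minimal sketch of the forward computation, matching what forwarding() in
layer_normalization_layer.cpp computes (expectations taken over the normalize axes,
gamma and beta the trainable weights):

    \mu = \mathrm{E}[x], \qquad \sigma^2 = \mathrm{E}\left[(x - \mu)^2\right]

    y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta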
api/ccapi/include/layer.h
nntrainer/app_context.cpp
nntrainer/layers/layer_normalization_layer.cpp [new file with mode: 0644]
nntrainer/layers/layer_normalization_layer.h [new file with mode: 0644]
nntrainer/layers/meson.build

index 2f41e99..2d2b105 100644 (file)
@@ -66,9 +66,11 @@ enum LayerType {
   LAYER_MULTI_HEAD_ATTENTION =
     ML_TRAIN_LAYER_TYPE_MULTI_HEAD_ATTENTION, /**< Multi Head Attention Layer
                                                  type */
-  LAYER_POSITIONAL_ENCODING =
-    ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING, /**< Positional Encoding Layer type
+  LAYER_LAYER_NORMALIZATION =
+    ML_TRAIN_LAYER_TYPE_LAYER_NORMALIZATION, /**< Layer Normalization Layer type
                                               */
+  LAYER_POSITIONAL_ENCODING =
+    ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING, /**< Positional Encoding Layer type */
   LAYER_PREPROCESS_FLIP =
     ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP, /**< Preprocess flip Layer type */
   LAYER_PREPROCESS_TRANSLATE =
@@ -248,6 +250,14 @@ BatchNormalization(const std::vector<std::string> &properties = {}) {
 }
 
 /**
+ * @brief Helper function to create layer normalization layer
+ */
+inline std::unique_ptr<Layer>
+LayerNormalization(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_LAYER_NORMALIZATION, properties);
+}
+
+/**
  * @brief Helper function to create convolution 2d layer
  */
 inline std::unique_ptr<Layer>
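A usage sketch for the helper added above. This is illustrative only: the
ml::train::layer namespace and the "axis"/"epsilon" property keys are assumptions
inferred from the neighboring helpers and from the properties parsed by the layer
implementation, not something stated by this commit.

#include <layer.h>
#include <memory>

// Hypothetical example: build a layer normalization layer that normalizes over
// the last (width) axis with a custom epsilon. Property keys are assumed.
std::unique_ptr<ml::train::Layer> make_layer_norm() {
  return ml::train::layer::LayerNormalization(
    {"name=ln0", "axis=3", "epsilon=0.001"});
}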
index 02cf3db..cc59e61 100644 (file)
@@ -49,6 +49,7 @@
 #include <grucell.h>
 #include <identity_layer.h>
 #include <input_layer.h>
+#include <layer_normalization_layer.h>
 #include <lr_scheduler_constant.h>
 #include <lr_scheduler_exponential.h>
 #include <lr_scheduler_step.h>
@@ -241,6 +242,9 @@ static void add_default_object(AppContext &ac) {
                      FullyConnectedLayer::type, LayerType::LAYER_FC);
   ac.registerFactory(nntrainer::createLayer<BatchNormalizationLayer>,
                      BatchNormalizationLayer::type, LayerType::LAYER_BN);
+  ac.registerFactory(nntrainer::createLayer<LayerNormalizationLayer>,
+                     LayerNormalizationLayer::type,
+                     LayerType::LAYER_LAYER_NORMALIZATION);
   ac.registerFactory(nntrainer::createLayer<Conv2DLayer>, Conv2DLayer::type,
                      LayerType::LAYER_CONV2D);
   ac.registerFactory(nntrainer::createLayer<Conv1DLayer>, Conv1DLayer::type,
diff --git a/nntrainer/layers/layer_normalization_layer.cpp b/nntrainer/layers/layer_normalization_layer.cpp
new file mode 100644 (file)
index 0000000..bd9127a
--- /dev/null
@@ -0,0 +1,225 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 hyeonseok Lee <hs89.lee@samsung.com>
+ *
+ * @file   layer_normalization_layer.cpp
+ * @date   25 July 2022
+ * @see    https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/abs/1607.06450
+ * @author hyeonseok Lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Layer Normalization Layer Class for Neural Network
+ *
+ */
+
+#include <algorithm>
+#include <numeric>
+
+#include <layer_context.h>
+#include <layer_normalization_layer.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <util_func.h>
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+enum LNParams {
+  gamma,
+  beta,
+  deviation,
+  variance,
+  inv_std_dev,
+  temp_origin_size,
+  temp_normalized_size,
+};
+
+LayerNormalizationLayer::LayerNormalizationLayer() :
+  Layer(),
+  layer_normalization_props(
+    std::vector<props::Axis>(), props::Epsilon(), props::BNPARAMS_GAMMA_INIT(),
+    props::BNPARAMS_BETA_INIT(), props::WeightDecay(), props::BiasDecay()) {
+  wt_idx.fill(std::numeric_limits<unsigned>::max());
+}
+
+void LayerNormalizationLayer::finalize(InitLayerContext &context) {
+  if (context.getNumInputs() != 1) {
+    throw std::invalid_argument(
+      "Only one input is allowed for layer normalization layer");
+  }
+
+  auto gamma_initializer =
+    std::get<props::BNPARAMS_GAMMA_INIT>(layer_normalization_props).get();
+  auto beta_initializer =
+    std::get<props::BNPARAMS_BETA_INIT>(layer_normalization_props).get();
+  auto weight_decay = std::get<props::WeightDecay>(layer_normalization_props);
+  auto bias_decay = std::get<props::BiasDecay>(layer_normalization_props);
+
+  auto const &input_dim = context.getInputDimensions()[0];
+  context.setOutputDimensions({input_dim});
+
+  std::vector<props::Axis> axes_prop =
+    std::get<std::vector<props::Axis>>(layer_normalization_props);
+
+  NNTR_THROW_IF(axes_prop.empty(), std::invalid_argument)
+    << "[Layer normalization]axis property is empty";
+
+  normalize_axes.insert(normalize_axes.end(), axes_prop.begin(),
+                        axes_prop.end());
+  std::sort(normalize_axes.begin(), normalize_axes.end());
+  normalize_axes.erase(
+    std::unique(normalize_axes.begin(), normalize_axes.end()),
+    normalize_axes.end());
+
+  TensorDim normalize_dim;
+  for (unsigned int axis : normalize_axes) {
+    normalize_dim.setTensorDim(axis, input_dim.getTensorDim(axis));
+  }
+
+  wt_idx[LNParams::gamma] = context.requestWeight(
+    normalize_dim, gamma_initializer, WeightRegularizer::NONE, 1.0f,
+    weight_decay, "gamma", true);
+  wt_idx[LNParams::beta] = context.requestWeight(
+    normalize_dim, beta_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
+    "beta", true);
+
+  TensorDim remain_dim;
+  std::vector<unsigned int> total_axes;
+  total_axes.resize(ml::train::TensorDim::MAXDIM);
+  std::iota(total_axes.begin(), total_axes.end(), 0u);
+  std::set_difference(total_axes.begin(), total_axes.end(),
+                      normalize_axes.begin(), normalize_axes.end(),
+                      std::back_inserter(remain_axes));
+  for (unsigned int axis : remain_axes) {
+    remain_dim.setTensorDim(axis, input_dim.getTensorDim(axis));
+  }
+
+  /** caches the deviation -> input - avg(input) */
+  wt_idx[LNParams::deviation] =
+    context.requestTensor(input_dim, "deviation", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+  /** caches variance + epsilon as well */
+  wt_idx[LNParams::variance] =
+    context.requestTensor(remain_dim, "variance", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+  /** caches the inverse standard deviation */
+  wt_idx[LNParams::inv_std_dev] =
+    context.requestTensor(remain_dim, "inv_std_dev", Tensor::Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
+
+  /** temporary tensor (origin size) */
+  wt_idx[LNParams::temp_origin_size] = context.requestTensor(
+    input_dim, "temp_origin_size", Tensor::Initializer::NONE, false,
+    TensorLifespan::CALC_DERIV_LIFESPAN);
+  /** temporary tensor (normalized size) */
+  wt_idx[LNParams::temp_normalized_size] = context.requestTensor(
+    remain_dim, "temp_normalized_size", Tensor::Initializer::NONE, false,
+    TensorLifespan::CALC_DERIV_LIFESPAN);
+}
+
+void LayerNormalizationLayer::setProperty(
+  const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, layer_normalization_props);
+  NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
+    << "[Layer Normalization Layer] Unknown Layer Properties count " +
+         std::to_string(values.size());
+}
+
+void LayerNormalizationLayer::forwarding(RunLayerContext &context,
+                                         bool training) {
+  const float epsilon =
+    std::get<props::Epsilon>(layer_normalization_props).get();
+
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+
+  Tensor &gamma = context.getWeight(wt_idx[LNParams::gamma]);
+  Tensor &beta = context.getWeight(wt_idx[LNParams::beta]);
+
+  Tensor &deviation = context.getTensor(wt_idx[LNParams::deviation]);
+  Tensor &variance = context.getTensor(wt_idx[LNParams::variance]);
+  Tensor &inv_std_dev = context.getTensor(wt_idx[LNParams::inv_std_dev]);
+
+  Tensor &temp_full_size = output;
+  Tensor &temp_norm_size = inv_std_dev;
+
+  input.average(normalize_axes, temp_norm_size);
+  input.subtract(temp_norm_size, deviation);
+
+  deviation.pow(2.0f, temp_full_size);
+  temp_full_size.average(normalize_axes, variance);
+
+  variance.add_i(epsilon);
+  variance.pow(-0.5f, inv_std_dev);
+
+  deviation.multiply(inv_std_dev, output);
+  output.multiply_i(gamma);
+  output.add_i(beta);
+}
+
+void LayerNormalizationLayer::calcDerivative(RunLayerContext &context) {
+  const bool trainable = context.getTrainable();
+  Tensor empty;
+
+  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  const Tensor &gamma = context.getWeight(wt_idx[LNParams::gamma]);
+  Tensor &d_gamma =
+    trainable ? context.getWeightGrad(wt_idx[LNParams::gamma]) : empty;
+
+  Tensor &deviation = context.getTensor(wt_idx[LNParams::deviation]);
+  Tensor &variance = context.getTensor(wt_idx[LNParams::variance]);
+  Tensor &inv_std_dev = context.getTensor(wt_idx[LNParams::inv_std_dev]);
+
+  Tensor &temp_origin_size =
+    context.getTensor(wt_idx[LNParams::temp_origin_size]);
+  Tensor &temp_normalized_size =
+    context.getTensor(wt_idx[LNParams::temp_normalized_size]);
+
+  incoming_derivative.multiply(deviation, temp_origin_size);
+  temp_origin_size.average(normalize_axes, temp_normalized_size);
+  temp_normalized_size.divide_i(variance);
+  deviation.multiply_i(temp_normalized_size);
+
+  if (trainable) {
+    /** calculate d_gamma */
+    temp_origin_size.multiply_i(inv_std_dev);
+    temp_origin_size.sum(remain_axes, d_gamma);
+  }
+  incoming_derivative.average(normalize_axes, temp_normalized_size);
+  incoming_derivative.subtract(temp_normalized_size, outgoing_derivative);
+  outgoing_derivative.subtract_i(deviation);
+
+  inv_std_dev.multiply_i(gamma);
+  outgoing_derivative.multiply_i(inv_std_dev);
+}
+
+void LayerNormalizationLayer::calcGradient(RunLayerContext &context) {
+  /** d_gamma is calculated in calcDerivative. d_beta is calculated here */
+
+  const Tensor &incoming_derivative =
+    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  Tensor &d_beta = context.getWeightGrad(wt_idx[LNParams::beta]);
+
+  incoming_derivative.sum(remain_axes, d_beta);
+}
+
+void LayerNormalizationLayer::exportTo(
+  Exporter &exporter, const ml::train::ExportMethods &method) const {
+  exporter.saveResult(layer_normalization_props, method, this);
+}
+
+void LayerNormalizationLayer::setBatch(RunLayerContext &context,
+                                       unsigned int batch) {
+  context.updateTensor(wt_idx[LNParams::deviation], batch);
+  context.updateTensor(wt_idx[LNParams::variance], batch);
+  context.updateTensor(wt_idx[LNParams::inv_std_dev], batch);
+  context.updateTensor(wt_idx[LNParams::temp_origin_size], batch);
+  context.updateTensor(wt_idx[LNParams::temp_normalized_size], batch);
+}
+
+} /* namespace nntrainer */
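For reference, the parameter gradients computed above follow the standard layer
normalization formulas. A sketch, writing \hat{x} = (x - \mu) / \sqrt{\sigma^2 + \epsilon}
(the cached deviation scaled by inv_std_dev) and summing over the remaining,
non-normalized axes as calcDerivative()/calcGradient() do:

    \frac{\partial L}{\partial \gamma} = \sum_{\mathrm{remain}} \frac{\partial L}{\partial y} \odot \hat{x},
    \qquad
    \frac{\partial L}{\partial \beta} = \sum_{\mathrm{remain}} \frac{\partial L}{\partial y}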
diff --git a/nntrainer/layers/layer_normalization_layer.h b/nntrainer/layers/layer_normalization_layer.h
new file mode 100644 (file)
index 0000000..da513e8
--- /dev/null
@@ -0,0 +1,129 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 hyeonseok Lee <hs89.lee@samsung.com>
+ *
+ * @file   layer_normalization_layer.h
+ * @date   25 July 2022
+ * @see    https://github.com/nnstreamer/nntrainer
+ *         https://arxiv.org/abs/1607.06450
+ * @author hyeonseok Lee <hs89.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Layer Normalization Layer Class for Neural Network
+ *
+ */
+
+#ifndef __LAYER_NORMALIZATION_LAYER_H__
+#define __LAYER_NORMALIZATION_LAYER_H__
+#ifdef __cplusplus
+
+#include <array>
+#include <functional>
+#include <vector>
+
+#include <common_properties.h>
+#include <layer_devel.h>
+
+namespace nntrainer {
+
+/**
+ * @class   LayerNormalizationLayer
+ * @brief   Layer Normalization Layer
+ */
+class LayerNormalizationLayer : public Layer {
+public:
+  /**
+   * @brief     Constructor of LayerNormalizationLayer
+   */
+  LayerNormalizationLayer();
+
+  /**
+   * @brief     Destructor of LayerNormalizationLayer
+   */
+  ~LayerNormalizationLayer() {}
+
+  /**
+   * @brief  Move constructor of LayerNormalizationLayer
+   * @param[in] rhs LayerNormalizationLayer to be moved
+   */
+  LayerNormalizationLayer(LayerNormalizationLayer &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator
+   * @param[in] rhs LayerNormalizationLayer to be moved
+   */
+  LayerNormalizationLayer &operator=(LayerNormalizationLayer &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::calcGradient(RunLayerContext &context)
+   */
+  void calcGradient(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, const ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const override;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override {
+    return LayerNormalizationLayer::type;
+  };
+
+  /**
+   * @copydoc Layer::supportBackwarding()
+   */
+  bool supportBackwarding() const override { return true; }
+
+  using Layer::setProperty;
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::supportInPlace()
+   */
+  bool supportInPlace() const override { return true; }
+
+  /**
+   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+   */
+  void setBatch(RunLayerContext &context, unsigned int batch) override;
+
+  inline static const std::string type = "layer_normalization";
+
+private:
+  std::vector<unsigned int> normalize_axes; /**< normalize axes */
+  std::vector<unsigned int>
+    remain_axes; /**< remained axes (exclusive with normalize axes) */
+
+  std::array<unsigned int, 7> wt_idx; /**< indices of the weights and tensors */
+  std::tuple<std::vector<props::Axis>, props::Epsilon,
+             props::BNPARAMS_GAMMA_INIT, props::BNPARAMS_BETA_INIT,
+             props::WeightDecay, props::BiasDecay>
+    layer_normalization_props;
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __LAYER_NORMALIZATION_LAYER_H__ */
index 400a631..0433730 100644 (file)
@@ -10,6 +10,7 @@ layer_sources = [
   'multi_head_attention_layer.cpp',
   'concat_layer.cpp',
   'bn_layer.cpp',
+  'layer_normalization_layer.cpp',
   'conv2d_layer.cpp',
   'conv1d_layer.cpp',
   'fc_layer.cpp',