[bn layer] Maintain bn layer property with props
authorhyeonseok lee <hs89.lee@samsung.com>
Thu, 9 Sep 2021 23:55:40 +0000 (08:55 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 13 Sep 2021 11:59:57 +0000 (20:59 +0900)
 - All the property will be maintain with props

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
nntrainer/layers/bn_layer.cpp
nntrainer/layers/bn_layer.h
nntrainer/layers/common_properties.cpp
nntrainer/layers/common_properties.h

index 3cbe8c6054f9d602338d531668f8a1927b11fb40..c78749b5f9187efbd7afe2366773a2d687707ba6 100644 (file)
@@ -25,6 +25,7 @@
 #include <lazy_tensor.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
+#include <node_exporter.h>
 #include <parse_util.h>
 #include <util_func.h>
 
@@ -34,6 +35,14 @@ static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 enum BNParams { mu, var, gamma, beta, deviation };
 
+BatchNormalizationLayer::BatchNormalizationLayer(int axis_) :
+  Layer(),
+  axis(axis_),
+  wt_idx({0}),
+  bn_props(props::Epsilon(), props::BNPARAMS_MU_INIT(),
+           props::BNPARAMS_VAR_INIT(), props::BNPARAMS_BETA_INIT(),
+           props::BNPARAMS_GAMMA_INIT(), props::Momentum()) {}
+
 /// @todo add multiple axis support
 void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   if (context.getNumInputs() != 1) {
@@ -41,6 +50,11 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
       "Only one input is allowed for batch normalization layer");
   }
 
+  auto &bnparams_mu = std::get<props::BNPARAMS_MU_INIT>(bn_props);
+  auto &bnparams_var = std::get<props::BNPARAMS_VAR_INIT>(bn_props);
+  auto &bnparams_beta = std::get<props::BNPARAMS_BETA_INIT>(bn_props);
+  auto &bnparams_gamma = std::get<props::BNPARAMS_GAMMA_INIT>(bn_props);
+
   std::vector<TensorDim> output_dims(1);
 
   /** set output dimensions */
@@ -60,18 +74,18 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
       axes_to_reduce.push_back(i);
   }
 
-  wt_idx[BNParams::mu] = context.requestWeight(
-    dim, initializers[BNParams::mu], WeightRegularizer::NONE, 1.0f,
-    context.getName() + ":moving_mean", false);
-  wt_idx[BNParams::var] = context.requestWeight(
-    dim, initializers[BNParams::var], WeightRegularizer::NONE, 1.0f,
-    context.getName() + ":moving_variance", false);
-  wt_idx[BNParams::gamma] = context.requestWeight(
-    dim, initializers[BNParams::gamma], WeightRegularizer::NONE, 1.0f,
-    context.getName() + ":gamma", true);
-  wt_idx[BNParams::beta] = context.requestWeight(
-    dim, initializers[BNParams::beta], WeightRegularizer::NONE, 1.0f,
-    context.getName() + ":beta", true);
+  wt_idx[BNParams::mu] =
+    context.requestWeight(dim, bnparams_mu, WeightRegularizer::NONE, 1.0f,
+                          context.getName() + ":moving_mean", false);
+  wt_idx[BNParams::var] =
+    context.requestWeight(dim, bnparams_var, WeightRegularizer::NONE, 1.0f,
+                          context.getName() + ":moving_variance", false);
+  wt_idx[BNParams::gamma] =
+    context.requestWeight(dim, bnparams_gamma, WeightRegularizer::NONE, 1.0f,
+                          context.getName() + ":gamma", true);
+  wt_idx[BNParams::beta] =
+    context.requestWeight(dim, bnparams_beta, WeightRegularizer::NONE, 1.0f,
+                          context.getName() + ":beta", true);
 
   wt_idx[BNParams::deviation] =
     context.requestTensor(in_dim, context.getName() + ":deviation",
@@ -80,68 +94,17 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
 
 void BatchNormalizationLayer::setProperty(
   const std::vector<std::string> &values) {
-  /// @todo: deprecate this in favor of loadProperties
-  for (unsigned int i = 0; i < values.size(); ++i) {
-    std::string key;
-    std::string value;
-    std::stringstream ss;
-
-    if (getKeyValue(values[i], key, value) != ML_ERROR_NONE) {
-      throw std::invalid_argument("Error parsing the property: " + values[i]);
-    }
-
-    if (value.empty()) {
-      ss << "value is empty: key: " << key << ", value: " << value;
-      throw std::invalid_argument(ss.str());
-    }
-
-    /// @note this calls derived setProperty if available
-    setProperty(key, value);
-  }
-}
-
-void BatchNormalizationLayer::setProperty(const std::string &type_str,
-                                          const std::string &value) {
-  using PropertyType = nntrainer::Layer::PropertyType;
-  int status = ML_ERROR_NONE;
-  nntrainer::Layer::PropertyType type =
-    static_cast<nntrainer::Layer::PropertyType>(parseLayerProperty(type_str));
-
-  switch (type) {
-  case PropertyType::epsilon:
-    status = setFloat(epsilon, value);
-    throw_status(status);
-    break;
-  case PropertyType::moving_mean_initializer:
-    initializers[BNParams::mu] =
-      (Tensor::Initializer)parseType(value, TOKEN_WEIGHT_INIT);
-    break;
-  case PropertyType::moving_variance_initializer:
-    initializers[BNParams::var] =
-      (Tensor::Initializer)parseType(value, TOKEN_WEIGHT_INIT);
-    break;
-  case PropertyType::beta_initializer:
-    initializers[BNParams::beta] =
-      (Tensor::Initializer)parseType(value, TOKEN_WEIGHT_INIT);
-    break;
-  case PropertyType::gamma_initializer:
-    initializers[BNParams::gamma] =
-      (Tensor::Initializer)parseType(value, TOKEN_WEIGHT_INIT);
-    break;
-  case PropertyType::momentum:
-    status = setFloat(momentum, value);
-    throw_status(status);
-    break;
-  default:
-    std::string msg =
-      "[BatchNormalizationLayer] Unknown Layer Property Key for value " +
-      std::string(value);
-    throw exception::not_supported(msg);
-  }
+  auto remain_props = loadProperties(values, bn_props);
+  NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
+    << "[BNLayer] Unknown Layer Properties count " +
+         std::to_string(values.size());
 }
 
 void BatchNormalizationLayer::forwarding(RunLayerContext &context,
                                          bool training) {
+  float epsilon = std::get<props::Epsilon>(bn_props);
+  float momentum = std::get<props::Momentum>(bn_props);
+
   Tensor &mu = context.getWeight(wt_idx[BNParams::mu]);
   Tensor &var = context.getWeight(wt_idx[BNParams::var]);
   Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);
@@ -216,4 +179,9 @@ void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
   dgamma = dev.sum(axes_to_reduce);
 }
 
+void BatchNormalizationLayer::exportTo(Exporter &exporter,
+                                       const ExportMethods &method) const {
+  exporter.saveResult(bn_props, method, this);
+}
+
 } /* namespace nntrainer */
index d5592d58c1c99bf48730678670ddf9702e8b5e2f..43407c499385eee9138a0bc7cda5ece7e7f0defc 100644 (file)
@@ -41,19 +41,7 @@ public:
   /**
    * @brief     Constructor of Batch Noramlization Layer
    */
-  BatchNormalizationLayer(
-    int axis = -1, float momentum = 0.99, float epsilon = 0.001,
-    Tensor::Initializer moving_mean_initializer = Tensor::Initializer::ZEROS,
-    Tensor::Initializer moving_variance_initializer = Tensor::Initializer::ONES,
-    Tensor::Initializer gamma_initializer = Tensor::Initializer::ONES,
-    Tensor::Initializer beta_initializer = Tensor::Initializer::ZEROS) :
-    Layer(),
-    epsilon(epsilon),
-    momentum(momentum),
-    axis(axis),
-    initializers{moving_mean_initializer, moving_variance_initializer,
-                 gamma_initializer, beta_initializer},
-    wt_idx({0}) {}
+  BatchNormalizationLayer(int axis_ = -1);
 
   /**
    * @brief     Destructor of BatchNormalizationLayer
@@ -95,10 +83,7 @@ public:
   /**
    * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
    */
-  void exportTo(Exporter &exporter,
-                const ExportMethods &method) const override {
-    Layer::exportTo(exporter, method);
-  }
+  void exportTo(Exporter &exporter, const ExportMethods &method) const override;
 
   /**
    * @copydoc Layer::getType()
@@ -132,23 +117,14 @@ private:
                     bn_layer::calcDerivative */
   Tensor invstd; /**<  inversed training std for backward pass */
 
-  float epsilon;  /**< epsilon */
-  float momentum; /**< momentum */
-  int axis;       /**< Target axis, axis inferred at initialize when -1 */
+  int axis; /**< Target axis, axis inferred at initialize when -1 */
 
-  std::vector<unsigned int> axes_to_reduce;        /**< target axes to reduce */
-  std::array<Tensor::Initializer, 4> initializers; /**< weight initializers */
+  std::vector<unsigned int> axes_to_reduce; /**< target axes to reduce */
   std::array<unsigned int, 5> wt_idx; /**< indices of the weights and tensors */
-
-  /**
-   * @brief setProperty by type and value separated
-   * @param[in] type property type to be passed
-   * @param[in] value value to be passed
-   * @exception exception::not_supported     when property type is not valid for
-   * the particular layer
-   * @exception std::invalid_argument invalid argument
-   */
-  void setProperty(const std::string &type, const std::string &value);
+  std::tuple<props::Epsilon, props::BNPARAMS_MU_INIT, props::BNPARAMS_VAR_INIT,
+             props::BNPARAMS_BETA_INIT, props::BNPARAMS_GAMMA_INIT,
+             props::Momentum>
+    bn_props;
 };
 
 } // namespace nntrainer
index 4ce910c8a43bd24d96993630aaad2dff99160a56..b237032f53fecbbf1eb501bd0de1ca87499917d1 100644 (file)
@@ -106,6 +106,16 @@ bool InputSpec::isValid(const ConnectionSpec &v) const {
   return v.getLayerIds().size() > 0;
 }
 
+Epsilon::Epsilon(float value) { set(value); }
+
+bool Epsilon::isValid(const float &value) const { return value > 0.0f; }
+
+Momentum::Momentum(float value) { set(value); }
+
+bool Momentum::isValid(const float &value) const {
+  return value > 0.0f && value < 1.0f;
+}
+
 bool SplitDimension::isValid(const unsigned int &value) const {
   return value > 0 && value < ml::train::TensorDim::MAXDIM;
 }
@@ -204,6 +214,18 @@ RecurrentActivation::RecurrentActivation(ActivationTypeInfo::Enum value) {
   set(value);
 };
 
+BNPARAMS_MU_INIT::BNPARAMS_MU_INIT(Tensor::Initializer value) { set(value); }
+
+BNPARAMS_VAR_INIT::BNPARAMS_VAR_INIT(Tensor::Initializer value) { set(value); }
+
+BNPARAMS_GAMMA_INIT::BNPARAMS_GAMMA_INIT(Tensor::Initializer value) {
+  set(value);
+}
+
+BNPARAMS_BETA_INIT::BNPARAMS_BETA_INIT(Tensor::Initializer value) {
+  set(value);
+}
+
 } // namespace props
 
 static const std::vector<std::pair<char, std::string>>
index 135284958d1812f84d029bd37c6227b44a39dd54..1305c8e00788d0c85c02b7c063b369849cab1f03 100644 (file)
@@ -18,6 +18,7 @@
 #include <string>
 
 #include <base_properties.h>
+#include <tensor.h>
 
 namespace nntrainer {
 
@@ -241,6 +242,57 @@ public:
   bool isValid(const ConnectionSpec &v) const override;
 };
 
+/**
+ * @brief Epsilon property, this is used to avoid divide by zero
+ *
+ */
+class Epsilon : public nntrainer::Property<float> {
+
+public:
+  /**
+   * @brief Construct a new Epsilon object with a default value 0.001
+   *
+   */
+  Epsilon(float value = 0.001);
+  static constexpr const char *key = "epsilon"; /**< unique key to access */
+  using prop_tag = float_prop_tag;              /**< property type */
+
+  /**
+   * @brief Epsilon validator
+   *
+   * @param value float to validate
+   * @retval true if it is greater or equal than 0.0
+   * @retval false if it is samller than 0.0
+   */
+  bool isValid(const float &value) const override;
+};
+
+/**
+ * @brief Momentum property, moving average in batch normalization layer
+ *
+ */
+class Momentum : public nntrainer::Property<float> {
+
+public:
+  /**
+   * @brief Construct a new Momentum object with a default value 0.99
+   *
+   */
+  Momentum(float value = 0.99);
+  static constexpr const char *key = "momentum"; /**< unique key to access */
+  using prop_tag = float_prop_tag;               /**< property type */
+
+  /**
+   * @brief Momentum validator
+   *
+   * @param value float to validate
+   * @retval true if it is greater than 0.0 and smaller than 1.0
+   * @retval false if it is samller or equal than 0.0
+   * or greater or equal than 1.0
+   */
+  bool isValid(const float &value) const override;
+};
+
 /**
  * @brief SplitDimension property, dimension along which to split the input
  *
@@ -558,6 +610,78 @@ public:
   static constexpr const char *key = "recurrent_activation";
 };
 
+/**
+ * @brief     Enumeration of tensor initialization type
+ */
+struct InitializerInfo {
+  using Enum = Tensor::Initializer;
+  static constexpr std::initializer_list<Enum> EnumList = {
+    Enum::ZEROS,         Enum::ONES,          Enum::LECUN_NORMAL,
+    Enum::LECUN_UNIFORM, Enum::XAVIER_NORMAL, Enum::XAVIER_UNIFORM,
+    Enum::HE_NORMAL,     Enum::HE_UNIFORM,    Enum::NONE};
+
+  static constexpr const char *EnumStr[] = {
+    "zeros",         "ones",          "lecun_normal",
+    "lecun_uniform", "xavier_normal", "xavier_uniform",
+    "he_normal",     "he_uniform",    "none"};
+};
+
+/**
+ * @brief BNPARAMS_MU_INIT Initialization Enumeration Information
+ *
+ */
+class BNPARAMS_MU_INIT final : public EnumProperty<InitializerInfo> {
+public:
+  /**
+   * @brief Construct a BNPARAMS_MU_INIT object
+   */
+  BNPARAMS_MU_INIT(Tensor::Initializer value = Tensor::Initializer::ZEROS);
+  using prop_tag = enum_class_prop_tag;
+  static constexpr const char *key = "moving_mean_initializer";
+};
+
+/**
+ * @brief BNPARAMS_VAR_INIT Initialization Enumeration Information
+ *
+ */
+class BNPARAMS_VAR_INIT final : public EnumProperty<InitializerInfo> {
+public:
+  /**
+   * @brief Construct a BNPARAMS_VAR_INIT object
+   */
+  BNPARAMS_VAR_INIT(Tensor::Initializer value = Tensor::Initializer::ONES);
+  using prop_tag = enum_class_prop_tag;
+  static constexpr const char *key = "moving_variance_initializer";
+};
+
+/**
+ * @brief BNPARAMS_GAMMA_INIT Initialization Enumeration Information
+ *
+ */
+class BNPARAMS_GAMMA_INIT final : public EnumProperty<InitializerInfo> {
+public:
+  /**
+   * @brief Construct a BNPARAMS_GAMMA_INIT object
+   */
+  BNPARAMS_GAMMA_INIT(Tensor::Initializer value = Tensor::Initializer::ONES);
+  using prop_tag = enum_class_prop_tag;
+  static constexpr const char *key = "gamma_initializer";
+};
+
+/**
+ * @brief BNPARAMS_BETA_INIT Initialization Enumeration Information
+ *
+ */
+class BNPARAMS_BETA_INIT final : public EnumProperty<InitializerInfo> {
+public:
+  /**
+   * @brief Construct a BNPARAMS_BETA_INIT object
+   */
+  BNPARAMS_BETA_INIT(Tensor::Initializer value = Tensor::Initializer::ZEROS);
+  using prop_tag = enum_class_prop_tag;
+  static constexpr const char *key = "beta_initializer";
+};
+
 /**
  * @brief     Enumeration of pooling type
  */