[weight/var_grad] Remove exposure of weight/var_grad
authorParichay Kapoor <pk.kapoor@samsung.com>
Mon, 26 Jul 2021 08:27:06 +0000 (17:27 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 3 Aug 2021 05:26:04 +0000 (14:26 +0900)
This patch updates the usage of weight and var_grad headers in order to
hide them internally and not expose them in the devel headers.
The below changes are made to support this:
- weight and var_grad specs are declared separately in tensor wrapper
specs header.
- layer_context is made indepedent of the weight and var_grad
definition. The usages and implementation are moved to layer_context
souce file.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
12 files changed:
debian/nntrainer-dev.install
jni/Android.mk
nntrainer/layers/layer_context.cpp [new file with mode: 0644]
nntrainer/layers/layer_context.h
nntrainer/layers/layer_node.h
nntrainer/layers/meson.build
nntrainer/layers/time_dist.cpp
nntrainer/tensor/meson.build
nntrainer/tensor/tensor_wrap_specs.h [new file with mode: 0644]
nntrainer/tensor/var_grad.h
nntrainer/tensor/weight.h
packaging/nntrainer.spec

index 7807d46..e8bffe3 100644 (file)
@@ -17,6 +17,7 @@
 /usr/include/nntrainer/layer_devel.h
 /usr/include/nntrainer/neuralnet.h
 /usr/include/nntrainer/tensor.h
+/usr/include/nntrainer/tensor_wrap_specs.h
 /usr/include/nntrainer/optimizer_devel.h
 /usr/include/nntrainer/optimizer_impl.h
 /usr/include/nntrainer/profiler.h
index d9dc480..970989c 100644 (file)
@@ -140,6 +140,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/tensor_dim.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/blas_interface.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/layer_node.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/layer_context.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/input_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/multiout_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/fc_layer.cpp \
@@ -173,7 +174,6 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_impl.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/adam.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/sgd.cpp \
-                  $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_factory.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/utils/util_func.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/utils/ini_wrapper.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/utils/parse_util.cpp \
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
new file mode 100644 (file)
index 0000000..0f3f70e
--- /dev/null
@@ -0,0 +1,265 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   layer_context.cpp
+ * @date   26 July 2021
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is the layer context for each layer
+ */
+
+#include <layer_context.h>
+#include <var_grad.h>
+#include <weight.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Get the Weight tensor object
+ *
+ * @param idx Identifier of the weight
+ * @return Tensor& Reference to the weight tensor
+ */
+Tensor &RunLayerContext::getWeight(unsigned int idx) const {
+  return weights[idx]->getVariableRef();
+}
+
+/**
+ * @brief Get the Weight Gradient tensor object
+ *
+ * @param idx Identifier of the weight
+ * @return Tensor& Reference to the weight grad tensor
+ */
+Tensor &RunLayerContext::getWeightGrad(unsigned int idx) const {
+  if (!weights[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable weight.");
+  return weights[idx]->getGradientRef();
+}
+
+/**
+ * @brief Get regularization loss for the weight
+ *
+ * @param idx Identifier of the weight
+ * @return float Value of the loss
+ */
+float RunLayerContext::getWeightRegularizationLoss(unsigned int idx) const {
+  return weights[idx]->getRegularizationLoss();
+}
+
+/**
+ * @brief Get the Weight name
+ *
+ * @param idx Identifier of the weight
+ * @return name of the weight
+ */
+const std::string &RunLayerContext::getWeightName(unsigned int idx) const {
+  return weights[idx]->getName();
+}
+
+/**
+ * @brief check if the weight has gradient
+ *
+ * @param idx Identifier of the weight
+ * @return true if weight has gradient, else false
+ */
+bool RunLayerContext::weightHasGradient(unsigned int idx) const {
+  return weights[idx]->hasGradient();
+}
+
+/**
+ * @brief Get the Output tensor object
+ *
+ * @param idx Identifier of the output
+ * @return Tensor& Reference to the output tensor
+ */
+Tensor &RunLayerContext::getOutput(unsigned int idx) {
+  return outputs[idx]->getVariableRef();
+}
+
+/**
+ * @brief Get the Output Grad tensor object
+ *
+ * @param idx Identifier of the output
+ * @return Tensor& Reference to the output grad tensor
+ */
+Tensor &RunLayerContext::getOutputGrad(unsigned int idx) {
+  if (!outputs[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable tensor.");
+  return getOutputGradUnsafe(idx);
+}
+
+/**
+ * @brief Get the Output Grad tensor object
+ *
+ * @param idx Identifier of the output
+ * @return Tensor& Reference to the output grad tensor
+ *
+ * @note recommended to NOT use this function as a layer developer but rather
+ * use getOutputGrad().
+ */
+Tensor &RunLayerContext::getOutputGradUnsafe(unsigned int idx) {
+  return outputs[idx]->getGradientRef();
+}
+
+/**
+ * @brief Get the incoming Derivative tensor object
+ *
+ * @param idx Identifier of the output
+ * @return Tensor& Reference to the output derivative tensor
+ */
+Tensor &RunLayerContext::getIncomingDerivative(unsigned int idx) {
+  return getOutputGrad(idx);
+}
+
+/**
+ * @brief Get the Input tensor object
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the input grad tensor
+ */
+Tensor &RunLayerContext::getInput(unsigned int idx) {
+  return inputs[idx]->getVariableRef();
+}
+
+/**
+ * @brief Get the Input Grad tensor object
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the input grad tensor
+ */
+Tensor &RunLayerContext::getInputGrad(unsigned int idx) {
+  if (!inputs[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable tensor.");
+  return inputs[idx]->getGradientRef();
+}
+
+/**
+ * @brief Get the outgoing Derivative tensor object
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the input derivative tensor
+ */
+Tensor &RunLayerContext::getOutgoingDerivative(unsigned int idx) {
+  return getInputGrad(idx);
+}
+
+/**
+ * @brief Get the Tensor object
+ *
+ * @param idx Identifier of the tensor
+ * @return Tensor& Reference to the tensor
+ */
+Tensor &RunLayerContext::getTensor(unsigned int idx) {
+  return tensors[idx]->getVariableRef();
+}
+
+/**
+ * @brief Get the Tensor Grad object
+ *
+ * @param idx Identifier of the tensor
+ * @return Tensor& Reference to the tensor grad tensor
+ */
+Tensor &RunLayerContext::getTensorGrad(unsigned int idx) {
+  if (!tensors[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable tensor.");
+  return tensors[idx]->getGradientRef();
+}
+
+/**
+ * @brief check if the tensor has gradient
+ *
+ * @param idx Identifier of the tensor
+ * @return true if tensor has gradient, else false
+ */
+bool RunLayerContext::tensorHasGradient(unsigned int idx) const {
+  return tensors[idx]->hasGradient();
+}
+
+/**
+ * @brief Get the tensor name
+ *
+ * @param idx Identifier of the tensor
+ * @return name of the tensor
+ */
+const std::string &RunLayerContext::getTensorName(unsigned int idx) const {
+  return tensors[idx]->getName();
+}
+
+/**
+ * @brief Set the batch for the run context
+ *
+ * @param batch Update batch size
+ */
+void RunLayerContext::setBatch(unsigned int batch) {
+  for (auto &vg : inputs)
+    vg->setBatchSize(batch);
+  for (auto &vg : outputs)
+    vg->setBatchSize(batch);
+}
+
+/**
+ * @brief Update the dimensions for a requested tensor
+ *
+ * @param idx index of the tensor (identifier)
+ * @param batch Updated batch size
+ */
+void RunLayerContext::updateTensor(unsigned int idx, unsigned int batch) {
+  tensors[idx]->setBatchSize(batch);
+}
+
+/**
+ * @brief   Get weight object for the weights
+ *
+ * @param idx index of the weight (identifier)
+ * @return weight object
+ */
+Weight &RunLayerContext::getWeightObject(unsigned int idx) {
+  return *weights[idx];
+}
+
+/**
+ * @brief   check if the label is available
+ *
+ * @param idx Identifier of the input
+ * @return true if label is available else false
+ */
+bool RunLayerContext::isLabelAvailable(unsigned int idx) const {
+  return !outputs[idx]->getGradientRef().uninitialized();
+}
+
+/**
+ * @brief   Get label tensor
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the label tensor
+ */
+Tensor &RunLayerContext::getLabel(unsigned int idx) {
+  if (isLabelAvailable(idx))
+    return outputs[idx]->getGradientRef();
+  else
+    throw std::invalid_argument("Request tensor which does not exist");
+}
+
+/**
+ * @brief   check if run context is set and is ready to use
+ *
+ * @return true if ready, else false
+ */
+bool RunLayerContext::readyToUse() const {
+  /**
+   * assumption:
+   * 1. there must be atleast 1 input
+   * 2. the setter set everything at once
+   */
+  if (inputs.empty())
+    return false;
+  return !inputs[0]->getVariable().uninitialized();
+}
+
+} // namespace nntrainer
index 56249b2..a78f684 100644 (file)
 #include <common_properties.h>
 #include <tensor.h>
 #include <tensor_dim.h>
-#include <var_grad.h>
-#include <weight.h>
+#include <tensor_wrap_specs.h>
 
 namespace nntrainer {
+
+class Weight;
+class Var_Grad;
+
 /**
  * @brief define the lifespan of the given tensor to reduce peak memory
  *
@@ -162,7 +165,7 @@ public:
    * @todo Consider providing a guarantee that the returned indices will always
    * start from 0 and will always be incremental.
    */
-  unsigned int requestWeight(const Weight::Spec &spec) {
+  unsigned int requestWeight(const WeightSpec &spec) {
     weights_spec.emplace_back(spec);
     return weights_spec.size() - 1;
   }
@@ -190,7 +193,7 @@ public:
    * @brief Specification of the tensors
    *
    */
-  typedef Var_Grad::Spec TensorSpec;
+  typedef VarGradSpec TensorSpec;
 
   /**
    * @brief Request a new tensor for the layer
@@ -211,9 +214,7 @@ public:
    *
    * @return The current weights spec
    */
-  const std::vector<Weight::Spec> &getWeightsSpec() const {
-    return weights_spec;
-  }
+  const std::vector<WeightSpec> &getWeightsSpec() const { return weights_spec; }
 
   /**
    * @brief Get the number of requested weights
@@ -282,7 +283,7 @@ private:
   std::vector<TensorDim> input_dim;  /**< Input dimensions for the layer */
   std::vector<TensorDim> output_dim; /**< Output dimensions for the layer */
 
-  std::vector<Weight::Spec> weights_spec; /**< Specification for the weights */
+  std::vector<WeightSpec> weights_spec; /**< Specification for the weights */
   std::vector<TensorSpec>
     tensors_spec; /**< Specification for the var_grad (trainable/non-trainable
                      variables) */
@@ -343,9 +344,7 @@ public:
    * @param idx Identifier of the weight
    * @return Tensor& Reference to the weight tensor
    */
-  Tensor &getWeight(unsigned int idx) const {
-    return weights[idx]->getVariableRef();
-  }
+  Tensor &getWeight(unsigned int idx) const;
 
   /**
    * @brief Get the Weight Gradient tensor object
@@ -353,12 +352,7 @@ public:
    * @param idx Identifier of the weight
    * @return Tensor& Reference to the weight grad tensor
    */
-  Tensor &getWeightGrad(unsigned int idx) const {
-    if (!weights[idx]->hasGradient())
-      throw std::invalid_argument(
-        "Requesting gradient for a non-trainable weight.");
-    return weights[idx]->getGradientRef();
-  }
+  Tensor &getWeightGrad(unsigned int idx) const;
 
   /**
    * @brief Get the Weight name
@@ -366,9 +360,7 @@ public:
    * @param idx Identifier of the weight
    * @return name of the weight
    */
-  const std::string &getWeightName(unsigned int idx) const {
-    return weights[idx]->getName();
-  }
+  const std::string &getWeightName(unsigned int idx) const;
 
   /**
    * @brief check if the weight has gradient
@@ -376,9 +368,7 @@ public:
    * @param idx Identifier of the weight
    * @return true if weight has gradient, else false
    */
-  bool weightHasGradient(unsigned int idx) const {
-    return weights[idx]->hasGradient();
-  }
+  bool weightHasGradient(unsigned int idx) const;
 
   /**
    * @brief Get the Output tensor object
@@ -386,7 +376,7 @@ public:
    * @param idx Identifier of the output
    * @return Tensor& Reference to the output tensor
    */
-  Tensor &getOutput(unsigned int idx) { return outputs[idx]->getVariableRef(); }
+  Tensor &getOutput(unsigned int idx);
 
   /**
    * @brief Get the Output Grad tensor object
@@ -394,12 +384,7 @@ public:
    * @param idx Identifier of the output
    * @return Tensor& Reference to the output grad tensor
    */
-  Tensor &getOutputGrad(unsigned int idx) {
-    if (!outputs[idx]->hasGradient())
-      throw std::invalid_argument(
-        "Requesting gradient for a non-trainable tensor.");
-    return getOutputGradUnsafe(idx);
-  }
+  Tensor &getOutputGrad(unsigned int idx);
 
   /**
    * @brief Get the Output Grad tensor object
@@ -410,9 +395,7 @@ public:
    * @note recommended to NOT use this function as a layer developer but rather
    * use getOutputGrad().
    */
-  Tensor &getOutputGradUnsafe(unsigned int idx) {
-    return outputs[idx]->getGradientRef();
-  }
+  Tensor &getOutputGradUnsafe(unsigned int idx);
 
   /**
    * @brief Get the incoming Derivative tensor object
@@ -420,7 +403,7 @@ public:
    * @param idx Identifier of the output
    * @return Tensor& Reference to the output derivative tensor
    */
-  Tensor &getIncomingDerivative(unsigned int idx) { return getOutputGrad(idx); }
+  Tensor &getIncomingDerivative(unsigned int idx);
 
   /**
    * @brief Get the Input tensor object
@@ -428,7 +411,7 @@ public:
    * @param idx Identifier of the input
    * @return Tensor& Reference to the input grad tensor
    */
-  Tensor &getInput(unsigned int idx) { return inputs[idx]->getVariableRef(); }
+  Tensor &getInput(unsigned int idx);
 
   /**
    * @brief Get the Input Grad tensor object
@@ -436,12 +419,7 @@ public:
    * @param idx Identifier of the input
    * @return Tensor& Reference to the input grad tensor
    */
-  Tensor &getInputGrad(unsigned int idx) {
-    if (!inputs[idx]->hasGradient())
-      throw std::invalid_argument(
-        "Requesting gradient for a non-trainable tensor.");
-    return inputs[idx]->getGradientRef();
-  }
+  Tensor &getInputGrad(unsigned int idx);
 
   /**
    * @brief Get the outgoing Derivative tensor object
@@ -449,7 +427,7 @@ public:
    * @param idx Identifier of the input
    * @return Tensor& Reference to the input derivative tensor
    */
-  Tensor &getOutgoingDerivative(unsigned int idx) { return getInputGrad(idx); }
+  Tensor &getOutgoingDerivative(unsigned int idx);
 
   /**
    * @brief Get the Tensor object
@@ -457,7 +435,7 @@ public:
    * @param idx Identifier of the tensor
    * @return Tensor& Reference to the tensor
    */
-  Tensor &getTensor(unsigned int idx) { return tensors[idx]->getVariableRef(); }
+  Tensor &getTensor(unsigned int idx);
 
   /**
    * @brief Get the Tensor Grad object
@@ -465,12 +443,7 @@ public:
    * @param idx Identifier of the tensor
    * @return Tensor& Reference to the tensor grad tensor
    */
-  Tensor &getTensorGrad(unsigned int idx) {
-    if (!tensors[idx]->hasGradient())
-      throw std::invalid_argument(
-        "Requesting gradient for a non-trainable tensor.");
-    return tensors[idx]->getGradientRef();
-  }
+  Tensor &getTensorGrad(unsigned int idx);
 
   /**
    * @brief check if the tensor has gradient
@@ -478,9 +451,7 @@ public:
    * @param idx Identifier of the tensor
    * @return true if tensor has gradient, else false
    */
-  bool tensorHasGradient(unsigned int idx) const {
-    return tensors[idx]->hasGradient();
-  }
+  bool tensorHasGradient(unsigned int idx) const;
 
   /**
    * @brief Get the tensor name
@@ -488,9 +459,7 @@ public:
    * @param idx Identifier of the tensor
    * @return name of the tensor
    */
-  const std::string &getTensorName(unsigned int idx) const {
-    return tensors[idx]->getName();
-  }
+  const std::string &getTensorName(unsigned int idx) const;
 
   /**
    * @brief Get the number of Outputs tensor objects
@@ -525,12 +494,7 @@ public:
    *
    * @param batch Update batch size
    */
-  void setBatch(unsigned int batch) {
-    for (auto &vg : inputs)
-      vg->setBatchSize(batch);
-    for (auto &vg : outputs)
-      vg->setBatchSize(batch);
-  }
+  void setBatch(unsigned int batch);
 
   /**
    * @brief Update the dimensions for a requested tensor
@@ -538,9 +502,7 @@ public:
    * @param idx index of the tensor (identifier)
    * @param batch Updated batch size
    */
-  void updateTensor(unsigned int idx, unsigned int batch) {
-    tensors[idx]->setBatchSize(batch);
-  }
+  void updateTensor(unsigned int idx, unsigned int batch);
 
   /**
    * @brief   Get weight object for the weights
@@ -548,7 +510,7 @@ public:
    * @param idx index of the weight (identifier)
    * @return weight object
    */
-  Weight &getWeightObject(unsigned int idx) { return *weights[idx]; }
+  Weight &getWeightObject(unsigned int idx);
 
   /**
    * @brief   check if the label is available
@@ -556,9 +518,7 @@ public:
    * @param idx Identifier of the input
    * @return true if label is available else false
    */
-  bool isLabelAvailable(unsigned int idx) const {
-    return !outputs[idx]->getGradientRef().uninitialized();
-  }
+  bool isLabelAvailable(unsigned int idx) const;
 
   /**
    * @brief   Get label tensor
@@ -566,12 +526,7 @@ public:
    * @param idx Identifier of the input
    * @return Tensor& Reference to the label tensor
    */
-  Tensor &getLabel(unsigned int idx) {
-    if (isLabelAvailable(idx))
-      return outputs[idx]->getGradientRef();
-    else
-      throw std::invalid_argument("Request tensor which does not exist");
-  }
+  Tensor &getLabel(unsigned int idx);
 
   /**
    * @brief   update loss by the layer
@@ -615,16 +570,7 @@ public:
    *
    * @return true if ready, else false
    */
-  bool readyToUse() const {
-    /**
-     * assumption:
-     * 1. there must be atleast 1 input
-     * 2. the setter set everything at once
-     */
-    if (inputs.empty())
-      return false;
-    return !inputs[0]->getVariable().uninitialized();
-  }
+  bool readyToUse() const;
 
 private:
   std::tuple<props::Name> props; /**< props of the layer */
@@ -641,12 +587,7 @@ private:
    * @param idx Identifier of the weight
    * @return float Value of the loss
    */
-  float getWeightRegularizationLoss(unsigned int idx) const {
-    if (weights[idx]->hasGradient())
-      return weights[idx]->getRegularizationLoss();
-
-    return 0;
-  }
+  float getWeightRegularizationLoss(unsigned int idx) const;
 };
 
 } // namespace nntrainer
index 96e8c86..b243e23 100644 (file)
@@ -33,6 +33,7 @@
 #include <layer.h>
 #include <layer_context.h>
 #include <layer_devel.h>
+#include <weight.h>
 
 namespace nntrainer {
 
index 9624b96..08cb94a 100644 (file)
@@ -26,6 +26,7 @@ layer_sources = [
   'layer_impl.cpp',
   'gru.cpp',
   'dropout.cpp',
+  'layer_context.cpp'
 ]
 
 layer_headers = [
index 9d29aac..9e560e5 100644 (file)
@@ -16,6 +16,7 @@
 #include <parse_util.h>
 #include <time_dist.h>
 #include <util_func.h>
+#include <weight.h>
 
 namespace nntrainer {
 
index 9677ad6..d86987e 100644 (file)
@@ -12,7 +12,8 @@ tensor_headers = [
   'manager.h',
   'tensor.h',
   'weight.h',
-  'var_grad.h'
+  'var_grad.h',
+  'tensor_wrap_specs.h'
 ]
 
 foreach s : tensor_sources
diff --git a/nntrainer/tensor/tensor_wrap_specs.h b/nntrainer/tensor/tensor_wrap_specs.h
new file mode 100644 (file)
index 0000000..fe9fa69
--- /dev/null
@@ -0,0 +1,69 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   tensor_wrap_specs.h
+ * @date   26 July 2021
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is specs for various tensor wrappers
+ *
+ */
+
+#ifndef __TENSOR_WRAP_SPECS_H__
+#define __TENSOR_WRAP_SPECS_H__
+
+#include <tuple>
+
+#include <tensor.h>
+
+namespace nntrainer {
+
+/**
+ * @brief     Enumeration of Weight Regularizer
+ * @todo      Update to TensorRegularizer
+ */
+enum class WeightRegularizer {
+  L2NORM, /**< L2 norm regularization */
+  NONE,   /**< no regularization */
+  UNKNOWN /**< Unknown */
+};
+
+/**
+ * @brief     Enumeration of Weight Initialization Type
+ * @todo      Update to TensorInitializer
+ */
+enum class WeightInitializer {
+  WEIGHT_ZEROS,          /** Zero initialization */
+  WEIGHT_ONES,           /** One initialization */
+  WEIGHT_LECUN_NORMAL,   /** LeCun normal initialization */
+  WEIGHT_LECUN_UNIFORM,  /** uniform initialization */
+  WEIGHT_XAVIER_NORMAL,  /** Xavier normal initialization */
+  WEIGHT_XAVIER_UNIFORM, /** Xavier uniform initialization */
+  WEIGHT_HE_NORMAL,      /** He normal initialization */
+  WEIGHT_HE_UNIFORM,     /** He uniform initialization */
+  WEIGHT_UNKNOWN         /** Unknown */
+};
+
+/**
+ * @brief Specification of the Weight as a tensor wrapper
+ *
+ * @details The tuple values are dimension, initializer, regularizer,
+ * regularizer_constant, need_gradient property amd name of the Weight object.
+ */
+typedef std::tuple<TensorDim, WeightInitializer, WeightRegularizer, float, bool,
+                   const std::string>
+  WeightSpec;
+
+/**
+ * @brief Specification of the Var_Grad (trainable tensor) as a tensor wrapper
+ *
+ * @details The tuple values are dimension, need_gradient property, and the
+ * name of the Var_Grad object.
+ */
+typedef std::tuple<TensorDim, bool, const std::string> VarGradSpec;
+
+} // namespace nntrainer
+
+#endif /** __TENSOR_WRAP_SPECS_H__ */
index 7f98f4c..975b2b8 100644 (file)
@@ -17,6 +17,7 @@
 #include <tuple>
 
 #include <tensor.h>
+#include <tensor_wrap_specs.h>
 
 namespace nntrainer {
 
@@ -32,7 +33,7 @@ public:
    * @details The tuple values are dimension, need_gradient property, and the
    * name of the Var_Grad object.
    */
-  typedef std::tuple<TensorDim, bool, const std::string> Spec;
+  typedef VarGradSpec Spec;
 
   /**
    * @brief Var_Grad default constructor
index 72c9959..dd3558d 100644 (file)
 #include <tuple>
 
 #include <tensor.h>
+#include <tensor_wrap_specs.h>
 #include <var_grad.h>
 
 namespace nntrainer {
 
 /**
- * @brief     Enumeration of Weight Regularizer
- */
-enum class WeightRegularizer {
-  L2NORM, /**< L2 norm regularization */
-  NONE,   /**< no regularization */
-  UNKNOWN /**< Unknown */
-};
-
-/**
- * @brief     Enumeration of Weight Initialization Type
- */
-enum class WeightInitializer {
-  WEIGHT_ZEROS,          /** Zero initialization */
-  WEIGHT_ONES,           /** One initialization */
-  WEIGHT_LECUN_NORMAL,   /** LeCun normal initialization */
-  WEIGHT_LECUN_UNIFORM,  /** uniform initialization */
-  WEIGHT_XAVIER_NORMAL,  /** Xavier normal initialization */
-  WEIGHT_XAVIER_UNIFORM, /** Xavier uniform initialization */
-  WEIGHT_HE_NORMAL,      /** He normal initialization */
-  WEIGHT_HE_UNIFORM,     /** He uniform initialization */
-  WEIGHT_UNKNOWN         /** Unknown */
-};
-
-/**
  * @class   Weight
  * @brief   Weight with gradient, and its corresponding need_gradient property
  */
@@ -57,9 +34,7 @@ public:
    * @details The tuple values are dimension, initializer, regularizer,
    * regularizer_constant, need_gradient property amd name of the Weight object.
    */
-  typedef std::tuple<TensorDim, WeightInitializer, WeightRegularizer, float,
-                     bool, const std::string>
-    Spec;
+  typedef WeightSpec Spec;
 
   /**
    * @brief Weight default constructor
@@ -264,7 +239,7 @@ public:
    * @brief     Get loss from the regularization of the weight
    */
   float getRegularizationLoss() {
-    if (isWeightRegularizerL2Norm())
+    if (hasGradient() && isWeightRegularizerL2Norm())
       return regularizer_constant * 0.5f * var->l2norm();
 
     return 0;
index 5208102..977709b 100644 (file)
@@ -450,6 +450,7 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/
 %{_includedir}/nntrainer/layer_devel.h
 %{_includedir}/nntrainer/neuralnet.h
 %{_includedir}/nntrainer/tensor.h
+%{_includedir}/nntrainer/tensor_wrap_specs.h
 %{_includedir}/nntrainer/optimizer_devel.h
 %{_includedir}/nntrainer/optimizer_impl.h
 %{_includedir}/nntrainer/profiler.h