From 435aa81b8fd5a43906453f8bb9bcf8c5d13dc47f Mon Sep 17 00:00:00 2001
From: Parichay Kapoor <pk.kapoor@samsung.com>
Date: Mon, 26 Jul 2021 17:27:06 +0900
Subject: [PATCH] [weight/var_grad] Remove exposure of weight/var_grad

This patch updates the usage of the weight and var_grad headers in order
to hide them internally and not expose them in the devel headers.
The below changes are made to support this:
- weight and var_grad specs are declared separately in the tensor wrapper
  specs header.
- layer_context is made independent of the weight and var_grad
  definitions. The usages and implementations are moved to the
  layer_context source file.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
---
 debian/nntrainer-dev.install         |   1 +
 jni/Android.mk                       |   2 +-
 nntrainer/layers/layer_context.cpp   | 265 +++++++++++++++++++++++++++++++++++
 nntrainer/layers/layer_context.h     | 121 ++++------------
 nntrainer/layers/layer_node.h        |   1 +
 nntrainer/layers/meson.build         |   1 +
 nntrainer/layers/time_dist.cpp       |   1 +
 nntrainer/tensor/meson.build         |   3 +-
 nntrainer/tensor/tensor_wrap_specs.h |  69 +++++++++
 nntrainer/tensor/var_grad.h          |   3 +-
 nntrainer/tensor/weight.h            |  31 +---
 packaging/nntrainer.spec             |   1 +
 12 files changed, 378 insertions(+), 121 deletions(-)
 create mode 100644 nntrainer/layers/layer_context.cpp
 create mode 100644 nntrainer/tensor/tensor_wrap_specs.h

diff --git a/debian/nntrainer-dev.install b/debian/nntrainer-dev.install
index 7807d46..e8bffe3 100644
--- a/debian/nntrainer-dev.install
+++ b/debian/nntrainer-dev.install
@@ -17,6 +17,7 @@
 /usr/include/nntrainer/layer_devel.h
 /usr/include/nntrainer/neuralnet.h
 /usr/include/nntrainer/tensor.h
+/usr/include/nntrainer/tensor_wrap_specs.h
 /usr/include/nntrainer/optimizer_devel.h
 /usr/include/nntrainer/optimizer_impl.h
 /usr/include/nntrainer/profiler.h
diff --git a/jni/Android.mk b/jni/Android.mk
index d9dc480..970989c 100644
--- a/jni/Android.mk
+++ b/jni/Android.mk
@@ -140,6 +140,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/tensor_dim.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/tensor/blas_interface.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/layer_node.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/layer_context.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/input_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/multiout_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/fc_layer.cpp \
@@ -173,7 +174,6 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_impl.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/adam.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/optimizers/sgd.cpp \
-                  $(NNTRAINER_ROOT)/nntrainer/optimizers/optimizer_factory.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/utils/util_func.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/utils/ini_wrapper.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/utils/parse_util.cpp \
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
new file mode 100644
index 0000000..0f3f70e
--- /dev/null
+++ b/nntrainer/layers/layer_context.cpp
@@ -0,0 +1,265 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file layer_context.cpp
+ * @date 26 July 2021
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is the layer context for each layer
+ */
+
+#include <layer_context.h>
+#include <var_grad.h>
+#include <weight.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Get the Weight tensor object
+ *
+ * @param idx Identifier of the weight
+ * @return Tensor& Reference to the weight tensor
+ */
+Tensor 
&RunLayerContext::getWeight(unsigned int idx) const { + return weights[idx]->getVariableRef(); +} + +/** + * @brief Get the Weight Gradient tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the weight grad tensor + */ +Tensor &RunLayerContext::getWeightGrad(unsigned int idx) const { + if (!weights[idx]->hasGradient()) + throw std::invalid_argument( + "Requesting gradient for a non-trainable weight."); + return weights[idx]->getGradientRef(); +} + +/** + * @brief Get regularization loss for the weight + * + * @param idx Identifier of the weight + * @return float Value of the loss + */ +float RunLayerContext::getWeightRegularizationLoss(unsigned int idx) const { + return weights[idx]->getRegularizationLoss(); +} + +/** + * @brief Get the Weight name + * + * @param idx Identifier of the weight + * @return name of the weight + */ +const std::string &RunLayerContext::getWeightName(unsigned int idx) const { + return weights[idx]->getName(); +} + +/** + * @brief check if the weight has gradient + * + * @param idx Identifier of the weight + * @return true if weight has gradient, else false + */ +bool RunLayerContext::weightHasGradient(unsigned int idx) const { + return weights[idx]->hasGradient(); +} + +/** + * @brief Get the Output tensor object + * + * @param idx Identifier of the output + * @return Tensor& Reference to the output tensor + */ +Tensor &RunLayerContext::getOutput(unsigned int idx) { + return outputs[idx]->getVariableRef(); +} + +/** + * @brief Get the Output Grad tensor object + * + * @param idx Identifier of the output + * @return Tensor& Reference to the output grad tensor + */ +Tensor &RunLayerContext::getOutputGrad(unsigned int idx) { + if (!outputs[idx]->hasGradient()) + throw std::invalid_argument( + "Requesting gradient for a non-trainable tensor."); + return getOutputGradUnsafe(idx); +} + +/** + * @brief Get the Output Grad tensor object + * + * @param idx Identifier of the output + * @return Tensor& Reference to the output grad tensor + * + * @note recommended to NOT use this function as a layer developer but rather + * use getOutputGrad(). 
+ */
+Tensor &RunLayerContext::getOutputGradUnsafe(unsigned int idx) {
+  return outputs[idx]->getGradientRef();
+}
+
+/**
+ * @brief Get the incoming Derivative tensor object
+ *
+ * @param idx Identifier of the output
+ * @return Tensor& Reference to the output derivative tensor
+ */
+Tensor &RunLayerContext::getIncomingDerivative(unsigned int idx) {
+  return getOutputGrad(idx);
+}
+
+/**
+ * @brief Get the Input tensor object
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the input tensor
+ */
+Tensor &RunLayerContext::getInput(unsigned int idx) {
+  return inputs[idx]->getVariableRef();
+}
+
+/**
+ * @brief Get the Input Grad tensor object
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the input grad tensor
+ */
+Tensor &RunLayerContext::getInputGrad(unsigned int idx) {
+  if (!inputs[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable tensor.");
+  return inputs[idx]->getGradientRef();
+}
+
+/**
+ * @brief Get the outgoing Derivative tensor object
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the input derivative tensor
+ */
+Tensor &RunLayerContext::getOutgoingDerivative(unsigned int idx) {
+  return getInputGrad(idx);
+}
+
+/**
+ * @brief Get the Tensor object
+ *
+ * @param idx Identifier of the tensor
+ * @return Tensor& Reference to the tensor
+ */
+Tensor &RunLayerContext::getTensor(unsigned int idx) {
+  return tensors[idx]->getVariableRef();
+}
+
+/**
+ * @brief Get the Tensor Grad object
+ *
+ * @param idx Identifier of the tensor
+ * @return Tensor& Reference to the tensor grad tensor
+ */
+Tensor &RunLayerContext::getTensorGrad(unsigned int idx) {
+  if (!tensors[idx]->hasGradient())
+    throw std::invalid_argument(
+      "Requesting gradient for a non-trainable tensor.");
+  return tensors[idx]->getGradientRef();
+}
+
+/**
+ * @brief check if the tensor has gradient
+ *
+ * @param idx Identifier of the tensor
+ * @return true if tensor has gradient, else false
+ */
+bool RunLayerContext::tensorHasGradient(unsigned int idx) const {
+  return tensors[idx]->hasGradient();
+}
+
+/**
+ * @brief Get the tensor name
+ *
+ * @param idx Identifier of the tensor
+ * @return name of the tensor
+ */
+const std::string &RunLayerContext::getTensorName(unsigned int idx) const {
+  return tensors[idx]->getName();
+}
+
+/**
+ * @brief Set the batch for the run context
+ *
+ * @param batch Updated batch size
+ */
+void RunLayerContext::setBatch(unsigned int batch) {
+  for (auto &vg : inputs)
+    vg->setBatchSize(batch);
+  for (auto &vg : outputs)
+    vg->setBatchSize(batch);
+}
+
+/**
+ * @brief Update the dimensions for a requested tensor
+ *
+ * @param idx index of the tensor (identifier)
+ * @param batch Updated batch size
+ */
+void RunLayerContext::updateTensor(unsigned int idx, unsigned int batch) {
+  tensors[idx]->setBatchSize(batch);
+}
+
+/**
+ * @brief Get weight object for the weights
+ *
+ * @param idx index of the weight (identifier)
+ * @return weight object
+ */
+Weight &RunLayerContext::getWeightObject(unsigned int idx) {
+  return *weights[idx];
+}
+
+/**
+ * @brief check if the label is available
+ *
+ * @param idx Identifier of the input
+ * @return true if label is available else false
+ */
+bool RunLayerContext::isLabelAvailable(unsigned int idx) const {
+  return !outputs[idx]->getGradientRef().uninitialized();
+}
+
+/**
+ * @brief Get label tensor
+ *
+ * @param idx Identifier of the input
+ * @return Tensor& Reference to the label tensor
+ */
+Tensor 
&RunLayerContext::getLabel(unsigned int idx) {
+  if (isLabelAvailable(idx))
+    return outputs[idx]->getGradientRef();
+  else
+    throw std::invalid_argument("Requested tensor does not exist");
+}
+
+/**
+ * @brief check if run context is set and is ready to use
+ *
+ * @return true if ready, else false
+ */
+bool RunLayerContext::readyToUse() const {
+  /**
+   * assumption:
+   * 1. there must be at least 1 input
+   * 2. the setter sets everything at once
+   */
+  if (inputs.empty())
+    return false;
+  return !inputs[0]->getVariable().uninitialized();
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 56249b2..a78f684 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -19,10 +19,13 @@
 #include 
 #include 
 #include 
-#include <var_grad.h>
-#include <weight.h>
+#include <tensor_wrap_specs.h>
 
 namespace nntrainer {
+
+class Weight;
+class Var_Grad;
+
 /**
  * @brief define the lifespan of the given tensor to reduce peak memory
  *
@@ -162,7 +165,7 @@ public:
    * @todo Consider providing a guarantee that the returned indices will always
    * start from 0 and will always be incremental.
    */
-  unsigned int requestWeight(const Weight::Spec &spec) {
+  unsigned int requestWeight(const WeightSpec &spec) {
     weights_spec.emplace_back(spec);
     return weights_spec.size() - 1;
   }
@@ -190,7 +193,7 @@ public:
   * @brief Specification of the tensors
   *
   */
-  typedef Var_Grad::Spec TensorSpec;
+  typedef VarGradSpec TensorSpec;
 
  /**
   * @brief Request a new tensor for the layer
@@ -211,9 +214,7 @@ public:
   *
   * @return The current weights spec
   */
-  const std::vector<Weight::Spec> &getWeightsSpec() const {
-    return weights_spec;
-  }
+  const std::vector<WeightSpec> &getWeightsSpec() const { return weights_spec; }
 
  /**
   * @brief Get the number of requested weights
@@ -282,7 +283,7 @@ private:
   std::vector<TensorDim> input_dim;  /**< Input dimensions for the layer */
   std::vector<TensorDim> output_dim; /**< Output dimensions for the layer */
 
-  std::vector<Weight::Spec> weights_spec; /**< Specification for the weights */
+  std::vector<WeightSpec> weights_spec; /**< Specification for the weights */
 
   std::vector<TensorSpec> tensors_spec; /**< Specification for the var_grad
     (trainable/non-trainable variables) */
@@ -343,9 +344,7 @@ public:
   * @param idx Identifier of the weight
   * @return Tensor& Reference to the weight tensor
   */
-  Tensor &getWeight(unsigned int idx) const {
-    return weights[idx]->getVariableRef();
-  }
+  Tensor &getWeight(unsigned int idx) const;
 
  /**
   * @brief Get the Weight Gradient tensor object
   *
   * @param idx Identifier of the weight
   * @return Tensor& Reference to the weight grad tensor
   */
-  Tensor &getWeightGrad(unsigned int idx) const {
-    if (!weights[idx]->hasGradient())
-      throw std::invalid_argument(
-        "Requesting gradient for a non-trainable weight.");
-    return weights[idx]->getGradientRef();
-  }
+  Tensor &getWeightGrad(unsigned int idx) const;
 
  /**
   * @brief Get the Weight name
   *
   * @param idx Identifier of the weight
   * @return name of the weight
   */
-  const std::string &getWeightName(unsigned int idx) const {
-    return weights[idx]->getName();
-  }
+  const std::string &getWeightName(unsigned int idx) const;
 
  /**
   * @brief check if the weight has gradient
   *
   * @param idx Identifier of the weight
   * @return true if weight has gradient, else false
   */
-  bool weightHasGradient(unsigned int idx) const {
-    return weights[idx]->hasGradient();
-  }
+  bool weightHasGradient(unsigned int idx) const;
 
  /**
   * @brief Get the Output tensor object
   *
   * @param idx Identifier of the output
   * 
@return Tensor& Reference to the output tensor */ - Tensor &getOutput(unsigned int idx) { return outputs[idx]->getVariableRef(); } + Tensor &getOutput(unsigned int idx); /** * @brief Get the Output Grad tensor object @@ -394,12 +384,7 @@ public: * @param idx Identifier of the output * @return Tensor& Reference to the output grad tensor */ - Tensor &getOutputGrad(unsigned int idx) { - if (!outputs[idx]->hasGradient()) - throw std::invalid_argument( - "Requesting gradient for a non-trainable tensor."); - return getOutputGradUnsafe(idx); - } + Tensor &getOutputGrad(unsigned int idx); /** * @brief Get the Output Grad tensor object @@ -410,9 +395,7 @@ public: * @note recommended to NOT use this function as a layer developer but rather * use getOutputGrad(). */ - Tensor &getOutputGradUnsafe(unsigned int idx) { - return outputs[idx]->getGradientRef(); - } + Tensor &getOutputGradUnsafe(unsigned int idx); /** * @brief Get the incoming Derivative tensor object @@ -420,7 +403,7 @@ public: * @param idx Identifier of the output * @return Tensor& Reference to the output derivative tensor */ - Tensor &getIncomingDerivative(unsigned int idx) { return getOutputGrad(idx); } + Tensor &getIncomingDerivative(unsigned int idx); /** * @brief Get the Input tensor object @@ -428,7 +411,7 @@ public: * @param idx Identifier of the input * @return Tensor& Reference to the input grad tensor */ - Tensor &getInput(unsigned int idx) { return inputs[idx]->getVariableRef(); } + Tensor &getInput(unsigned int idx); /** * @brief Get the Input Grad tensor object @@ -436,12 +419,7 @@ public: * @param idx Identifier of the input * @return Tensor& Reference to the input grad tensor */ - Tensor &getInputGrad(unsigned int idx) { - if (!inputs[idx]->hasGradient()) - throw std::invalid_argument( - "Requesting gradient for a non-trainable tensor."); - return inputs[idx]->getGradientRef(); - } + Tensor &getInputGrad(unsigned int idx); /** * @brief Get the outgoing Derivative tensor object @@ -449,7 +427,7 @@ public: * @param idx Identifier of the input * @return Tensor& Reference to the input derivative tensor */ - Tensor &getOutgoingDerivative(unsigned int idx) { return getInputGrad(idx); } + Tensor &getOutgoingDerivative(unsigned int idx); /** * @brief Get the Tensor object @@ -457,7 +435,7 @@ public: * @param idx Identifier of the tensor * @return Tensor& Reference to the tensor */ - Tensor &getTensor(unsigned int idx) { return tensors[idx]->getVariableRef(); } + Tensor &getTensor(unsigned int idx); /** * @brief Get the Tensor Grad object @@ -465,12 +443,7 @@ public: * @param idx Identifier of the tensor * @return Tensor& Reference to the tensor grad tensor */ - Tensor &getTensorGrad(unsigned int idx) { - if (!tensors[idx]->hasGradient()) - throw std::invalid_argument( - "Requesting gradient for a non-trainable tensor."); - return tensors[idx]->getGradientRef(); - } + Tensor &getTensorGrad(unsigned int idx); /** * @brief check if the tensor has gradient @@ -478,9 +451,7 @@ public: * @param idx Identifier of the tensor * @return true if tensor has gradient, else false */ - bool tensorHasGradient(unsigned int idx) const { - return tensors[idx]->hasGradient(); - } + bool tensorHasGradient(unsigned int idx) const; /** * @brief Get the tensor name @@ -488,9 +459,7 @@ public: * @param idx Identifier of the tensor * @return name of the tensor */ - const std::string &getTensorName(unsigned int idx) const { - return tensors[idx]->getName(); - } + const std::string &getTensorName(unsigned int idx) const; /** * @brief Get the number of Outputs 
tensor objects @@ -525,12 +494,7 @@ public: * * @param batch Update batch size */ - void setBatch(unsigned int batch) { - for (auto &vg : inputs) - vg->setBatchSize(batch); - for (auto &vg : outputs) - vg->setBatchSize(batch); - } + void setBatch(unsigned int batch); /** * @brief Update the dimensions for a requested tensor @@ -538,9 +502,7 @@ public: * @param idx index of the tensor (identifier) * @param batch Updated batch size */ - void updateTensor(unsigned int idx, unsigned int batch) { - tensors[idx]->setBatchSize(batch); - } + void updateTensor(unsigned int idx, unsigned int batch); /** * @brief Get weight object for the weights @@ -548,7 +510,7 @@ public: * @param idx index of the weight (identifier) * @return weight object */ - Weight &getWeightObject(unsigned int idx) { return *weights[idx]; } + Weight &getWeightObject(unsigned int idx); /** * @brief check if the label is available @@ -556,9 +518,7 @@ public: * @param idx Identifier of the input * @return true if label is available else false */ - bool isLabelAvailable(unsigned int idx) const { - return !outputs[idx]->getGradientRef().uninitialized(); - } + bool isLabelAvailable(unsigned int idx) const; /** * @brief Get label tensor @@ -566,12 +526,7 @@ public: * @param idx Identifier of the input * @return Tensor& Reference to the label tensor */ - Tensor &getLabel(unsigned int idx) { - if (isLabelAvailable(idx)) - return outputs[idx]->getGradientRef(); - else - throw std::invalid_argument("Request tensor which does not exist"); - } + Tensor &getLabel(unsigned int idx); /** * @brief update loss by the layer @@ -615,16 +570,7 @@ public: * * @return true if ready, else false */ - bool readyToUse() const { - /** - * assumption: - * 1. there must be atleast 1 input - * 2. the setter set everything at once - */ - if (inputs.empty()) - return false; - return !inputs[0]->getVariable().uninitialized(); - } + bool readyToUse() const; private: std::tuple props; /**< props of the layer */ @@ -641,12 +587,7 @@ private: * @param idx Identifier of the weight * @return float Value of the loss */ - float getWeightRegularizationLoss(unsigned int idx) const { - if (weights[idx]->hasGradient()) - return weights[idx]->getRegularizationLoss(); - - return 0; - } + float getWeightRegularizationLoss(unsigned int idx) const; }; } // namespace nntrainer diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index 96e8c86..b243e23 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -33,6 +33,7 @@ #include #include #include +#include namespace nntrainer { diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build index 9624b96..08cb94a 100644 --- a/nntrainer/layers/meson.build +++ b/nntrainer/layers/meson.build @@ -26,6 +26,7 @@ layer_sources = [ 'layer_impl.cpp', 'gru.cpp', 'dropout.cpp', + 'layer_context.cpp' ] layer_headers = [ diff --git a/nntrainer/layers/time_dist.cpp b/nntrainer/layers/time_dist.cpp index 9d29aac..9e560e5 100644 --- a/nntrainer/layers/time_dist.cpp +++ b/nntrainer/layers/time_dist.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace nntrainer { diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build index 9677ad6..d86987e 100644 --- a/nntrainer/tensor/meson.build +++ b/nntrainer/tensor/meson.build @@ -12,7 +12,8 @@ tensor_headers = [ 'manager.h', 'tensor.h', 'weight.h', - 'var_grad.h' + 'var_grad.h', + 'tensor_wrap_specs.h' ] foreach s : tensor_sources diff --git a/nntrainer/tensor/tensor_wrap_specs.h 
b/nntrainer/tensor/tensor_wrap_specs.h
new file mode 100644
index 0000000..fe9fa69
--- /dev/null
+++ b/nntrainer/tensor/tensor_wrap_specs.h
@@ -0,0 +1,69 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file tensor_wrap_specs.h
+ * @date 26 July 2021
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This file contains the specs for the various tensor wrappers
+ *
+ */
+
+#ifndef __TENSOR_WRAP_SPECS_H__
+#define __TENSOR_WRAP_SPECS_H__
+
+#include <tuple>
+
+#include <tensor_dim.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Enumeration of Weight Regularizer
+ * @todo Update to TensorRegularizer
+ */
+enum class WeightRegularizer {
+  L2NORM, /**< L2 norm regularization */
+  NONE,   /**< no regularization */
+  UNKNOWN /**< Unknown */
+};
+
+/**
+ * @brief Enumeration of Weight Initialization Type
+ * @todo Update to TensorInitializer
+ */
+enum class WeightInitializer {
+  WEIGHT_ZEROS,          /** Zero initialization */
+  WEIGHT_ONES,           /** One initialization */
+  WEIGHT_LECUN_NORMAL,   /** LeCun normal initialization */
+  WEIGHT_LECUN_UNIFORM,  /** LeCun uniform initialization */
+  WEIGHT_XAVIER_NORMAL,  /** Xavier normal initialization */
+  WEIGHT_XAVIER_UNIFORM, /** Xavier uniform initialization */
+  WEIGHT_HE_NORMAL,      /** He normal initialization */
+  WEIGHT_HE_UNIFORM,     /** He uniform initialization */
+  WEIGHT_UNKNOWN         /** Unknown */
+};
+
+/**
+ * @brief Specification of the Weight as a tensor wrapper
+ *
+ * @details The tuple values are dimension, initializer, regularizer,
+ * regularizer_constant, need_gradient property and name of the Weight object.
+ */
+typedef std::tuple<TensorDim, WeightInitializer, WeightRegularizer, float,
+                   bool, const std::string>
+  WeightSpec;
+
+/**
+ * @brief Specification of the Var_Grad (trainable tensor) as a tensor wrapper
+ *
+ * @details The tuple values are dimension, need_gradient property, and the
+ * name of the Var_Grad object.
+ */
+typedef std::tuple<TensorDim, bool, const std::string> VarGradSpec;
+
+} // namespace nntrainer
+
+#endif /** __TENSOR_WRAP_SPECS_H__ */
diff --git a/nntrainer/tensor/var_grad.h b/nntrainer/tensor/var_grad.h
index 7f98f4c..975b2b8 100644
--- a/nntrainer/tensor/var_grad.h
+++ b/nntrainer/tensor/var_grad.h
@@ -17,6 +17,7 @@
 #include 
 #include 
+#include <tensor_wrap_specs.h>
 
 namespace nntrainer {
 
@@ -32,7 +33,7 @@ public:
   * @details The tuple values are dimension, need_gradient property, and the
   * name of the Var_Grad object. 
*/
-  typedef std::tuple<TensorDim, bool, const std::string> Spec;
+  typedef VarGradSpec Spec;
 
  /**
   * @brief Var_Grad default constructor
diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h
index 72c9959..dd3558d 100644
--- a/nntrainer/tensor/weight.h
+++ b/nntrainer/tensor/weight.h
@@ -17,35 +17,12 @@
 #include 
 #include 
+#include <tensor_wrap_specs.h>
 #include 
 
 namespace nntrainer {
 
 /**
- * @brief Enumeration of Weight Regularizer
- */
-enum class WeightRegularizer {
-  L2NORM, /**< L2 norm regularization */
-  NONE, /**< no regularization */
-  UNKNOWN /**< Unknown */
-};
-
-/**
- * @brief Enumeration of Weight Initialization Type
- */
-enum class WeightInitializer {
-  WEIGHT_ZEROS, /** Zero initialization */
-  WEIGHT_ONES, /** One initialization */
-  WEIGHT_LECUN_NORMAL, /** LeCun normal initialization */
-  WEIGHT_LECUN_UNIFORM, /** uniform initialization */
-  WEIGHT_XAVIER_NORMAL, /** Xavier normal initialization */
-  WEIGHT_XAVIER_UNIFORM, /** Xavier uniform initialization */
-  WEIGHT_HE_NORMAL, /** He normal initialization */
-  WEIGHT_HE_UNIFORM, /** He uniform initialization */
-  WEIGHT_UNKNOWN /** Unknown */
-};
-
-/**
  * @class   Weight
  * @brief   Weight with gradient, and its corresponding need_gradient property
  */
@@ -57,9 +34,7 @@ public:
   * @details The tuple values are dimension, initializer, regularizer,
   * regularizer_constant, need_gradient property amd name of the Weight object.
   */
-  typedef std::tuple<TensorDim, WeightInitializer, WeightRegularizer, float,
-                     bool, const std::string>
-    Spec;
+  typedef WeightSpec Spec;
 
  /**
   * @brief Weight default constructor
@@ -264,7 +239,7 @@ public:
   * @brief Get loss from the regularization of the weight
   */
  float getRegularizationLoss() {
-    if (isWeightRegularizerL2Norm())
+    if (hasGradient() && isWeightRegularizerL2Norm())
      return regularizer_constant * 0.5f * var->l2norm();
 
    return 0;
diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec
index 5208102..977709b 100644
--- a/packaging/nntrainer.spec
+++ b/packaging/nntrainer.spec
@@ -450,6 +450,7 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/
 %{_includedir}/nntrainer/layer_devel.h
 %{_includedir}/nntrainer/neuralnet.h
 %{_includedir}/nntrainer/tensor.h
+%{_includedir}/nntrainer/tensor_wrap_specs.h
 %{_includedir}/nntrainer/optimizer_devel.h
 %{_includedir}/nntrainer/optimizer_impl.h
 %{_includedir}/nntrainer/profiler.h
-- 
2.7.4
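
Illustration (not part of the patch): with weight.h and var_grad.h hidden, a layer
built against the devel headers programs only in terms of the spec tuples and the
integer indices returned by the context. The sketch below shows that flow, assuming
the InitLayerContext/RunLayerContext interfaces visible in this diff; the layer
class, weight name, and dimensions are hypothetical, and the real Layer interface
may carry extra parameters (e.g. a training flag on forwarding).

    // Builds with only the devel headers: tensor_wrap_specs.h provides
    // WeightSpec, so neither weight.h nor var_grad.h is needed.
    #include <layer_context.h>
    #include <tensor_wrap_specs.h>

    namespace nntrainer {

    class ExampleLayer {
      unsigned int weight_idx; /**< index handed back by requestWeight() */

    public:
      /** Describe the weight as a spec tuple; no Weight object is built here */
      void finalize(InitLayerContext &context) {
        WeightSpec spec(TensorDim(1, 1, 4, 4), /* dimension */
                        WeightInitializer::WEIGHT_XAVIER_UNIFORM,
                        WeightRegularizer::NONE, 0.0f, /* regularizer constant */
                        true,                          /* need_gradient */
                        "example:weight");
        weight_idx = context.requestWeight(spec);
      }

      /** The actual tensor is resolved only at runtime, inside the library */
      void forwarding(RunLayerContext &context) {
        Tensor &w = context.getWeight(weight_idx);
        /* ... compute with w ... */
      }
    };

    } // namespace nntrainer

This is also why layer_context.h can get away with mere forward declarations of
Weight and Var_Grad: the inline spec-handling code only stores tuples, and every
member that dereferences a Weight or Var_Grad now lives in layer_context.cpp.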