From 80834f3aed1f93c1d7f9960be3159c61acadb299 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Thu, 8 Oct 2020 15:42:32 +0900 Subject: [PATCH] [optimizer] Refactor optimizer This patch refactors optimizer in the following fashion: - split optimizer implementations of different types to derived classes of adam and sgd - create a optimizer_factory to create the optimizer class objs based on its type - this can be directly used with ccapi - OptParam struct has been removed - applyGradients has been broken down into different methods - updated associated unittests **Self evaluation:** 1. Build test: [x]Passed [ ]Failed [ ]Skipped 2. Run test: [x]Passed [ ]Failed [ ]Skipped Signed-off-by: Parichay Kapoor --- api/capi/src/nntrainer.cpp | 20 +-- jni/Android.mk | 5 +- nntrainer/include/activation_layer.h | 6 - nntrainer/include/adam.h | 107 ++++++++++++ nntrainer/include/addition_layer.h | 6 - nntrainer/include/flatten_layer.h | 6 - nntrainer/include/input_layer.h | 6 - nntrainer/include/layer.h | 5 +- nntrainer/include/neuralnet.h | 6 +- nntrainer/include/optimizer.h | 151 +++++++++-------- nntrainer/include/optimizer_factory.h | 48 ++++++ nntrainer/include/sgd.h | 50 ++++++ nntrainer/include/weight.h | 2 + nntrainer/meson.build | 6 + nntrainer/src/activation_layer.cpp | 12 -- nntrainer/src/adam.cpp | 155 ++++++++++++++++++ nntrainer/src/addition_layer.cpp | 9 -- nntrainer/src/bn_layer.cpp | 7 +- nntrainer/src/conv2d_layer.cpp | 8 +- nntrainer/src/fc_layer.cpp | 14 +- nntrainer/src/flatten_layer.cpp | 9 -- nntrainer/src/input_layer.cpp | 9 -- nntrainer/src/layer.cpp | 26 ++- nntrainer/src/loss_layer.cpp | 4 +- nntrainer/src/model_loader.cpp | 62 +++++-- nntrainer/src/neuralnet.cpp | 4 +- nntrainer/src/optimizer.cpp | 225 ++++++++------------------ nntrainer/src/optimizer_factory.cpp | 35 ++++ nntrainer/src/pooling2d_layer.cpp | 7 +- nntrainer/src/sgd.cpp | 25 +++ packaging/nntrainer.spec | 3 + test/unittest/unittest_nntrainer_internal.cpp | 61 +++---- test/unittest/unittest_nntrainer_layers.cpp | 24 ++- 33 files changed, 697 insertions(+), 426 deletions(-) create mode 100644 nntrainer/include/adam.h create mode 100644 nntrainer/include/optimizer_factory.h create mode 100644 nntrainer/include/sgd.h create mode 100644 nntrainer/src/adam.cpp create mode 100644 nntrainer/src/optimizer_factory.cpp create mode 100644 nntrainer/src/sgd.cpp diff --git a/api/capi/src/nntrainer.cpp b/api/capi/src/nntrainer.cpp index 15a5770..05d0287 100644 --- a/api/capi/src/nntrainer.cpp +++ b/api/capi/src/nntrainer.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -610,26 +611,19 @@ int ml_train_optimizer_create(ml_train_optimizer_h *optimizer, ml_train_optimizer *nnopt = new ml_train_optimizer; nnopt->magic = ML_NNTRAINER_MAGIC; - - status = - exception_bounded_make_shared(nnopt->optimizer); - if (status != ML_ERROR_NONE) { - delete nnopt; - ml_loge("creating optimizer failed"); - return status; - } - nnopt->in_use = false; - *optimizer = nnopt; - returnable f = [&]() { - return nnopt->optimizer->setType(ml_optimizer_to_nntrainer_type(type)); + nnopt->optimizer = createOptimizer(ml_optimizer_to_nntrainer_type(type)); + return ML_ERROR_NONE; }; - status = nntrainer_exception_boundary(f); + status = nntrainer_exception_boundary(f); if (status != ML_ERROR_NONE) { delete nnopt; + ml_loge("creating optimizer failed"); + } else { + *optimizer = nnopt; } return status; diff --git a/jni/Android.mk b/jni/Android.mk index 5923411..ae424d8 100644 --- a/jni/Android.mk +++ b/jni/Android.mk 
@@ -42,7 +42,10 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/src/neuralnet.cpp \ $(NNTRAINER_ROOT)/nntrainer/src/model_loader.cpp \ $(NNTRAINER_ROOT)/nntrainer/src/addition_layer.cpp \ $(NNTRAINER_ROOT)/nntrainer/src/blas_interface.cpp \ - $(NNTRAINER_ROOT)/nntrainer/src/weight.cpp + $(NNTRAINER_ROOT)/nntrainer/src/weight.cpp \ + $(NNTRAINER_ROOT)/nntrainer/src/adam.cpp \ + $(NNTRAINER_ROOT)/nntrainer/src/sgd.cpp \ + $(NNTRAINER_ROOT)/nntrainer/src/optimizer_factory.cpp NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer/include \ $(NNTRAINER_ROOT)/api \ diff --git a/nntrainer/include/activation_layer.h b/nntrainer/include/activation_layer.h index e9df137..26118ea 100644 --- a/nntrainer/include/activation_layer.h +++ b/nntrainer/include/activation_layer.h @@ -68,12 +68,6 @@ public: sharedConstTensor backwarding(sharedConstTensor in, int iteration); /** - * @brief copy layer - * @param[in] l layer to copy - */ - void copy(std::shared_ptr l); - - /** * @brief setActivation by preset ActivationType * * @param[in] ActivationTypeeActivationTypeeActivationTypeet diff --git a/nntrainer/include/adam.h b/nntrainer/include/adam.h new file mode 100644 index 0000000..32c55de --- /dev/null +++ b/nntrainer/include/adam.h @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2020 Parichay Kapoor + * + * @file adam.h + * @date 6 October 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Jijoong Moon + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * @brief This is the Adam optimizer. + */ +#ifndef __ADAM_H__ +#define __ADAM_H__ +#ifdef __cplusplus + +#include + +namespace nntrainer { + +/** + * @class Adam optimizer class + * @brief Adam optimizer + */ +class Adam : public Optimizer { +public: + /** + * @brief Constructor of Optimizer Class + */ + template + Adam(float lr = 0.001f, double b1 = 0.9f, double b2 = 0.999f, + double ep = 1.0e-7f, Args... 
args) : + Optimizer(OptType::adam, lr, args...), + beta1(b1), + beta2(b2), + epsilon(ep) {} + + /** + * @copydoc apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + * int iteration) + */ + void apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + int iteration); + + /** + * @brief get the base name for the optimizer + * @retval base name of the optimizer + */ + std::string getBaseName() { return "Adam"; }; + + /** + * @copydoc getLearningRate(int iteration) + */ + double getLearningRate(int iteration); + + /** + * @copydoc setProperty(const PropertyType type, + const std::string &value = "") + */ + void setProperty(const PropertyType type, const std::string &value = ""); + + /** + * @copydoc Optimizer::initialize(std::shared_ptr params, unsigned int + num_weights, bool setTensor) + */ + int initialize(std::shared_ptr params, unsigned int num_weights, + bool setTensor); + + /** + * @copydoc read(std::ifstream &file) + */ + void read(std::ifstream &file); + + /** + * @copydoc save(std::ofstream &file) + */ + void save(std::ofstream &file); + + /** + * @brief get beta1 + */ + double getBeta1() { return beta1; }; + + /** + * @brief get beta2 + */ + double getBeta2() { return beta2; }; + + /** + * @brief get epsilon + */ + double getEpsilon() { return epsilon; } + +private: + /** + * @brief Internal Tensors for adam Optimizer + */ + std::vector> weight_mv; + + double beta1; /** momentum for grad */ + double beta2; /** momentum for grad**2 */ + double epsilon; /** epsilon to protect overflow */ +}; +} /* namespace nntrainer */ + +#endif /* __cplusplus */ +#endif /* __ADAM_H__ */ diff --git a/nntrainer/include/addition_layer.h b/nntrainer/include/addition_layer.h index 356efd4..e5407ce 100644 --- a/nntrainer/include/addition_layer.h +++ b/nntrainer/include/addition_layer.h @@ -82,12 +82,6 @@ public: sharedConstTensor backwarding(sharedConstTensor in, int iteration); /** - * @brief copy layer - * @param[in] l layer to copy - */ - void copy(std::shared_ptr l); - - /** * @brief get the base name for the layer * @retval base name of the layer */ diff --git a/nntrainer/include/flatten_layer.h b/nntrainer/include/flatten_layer.h index 0c085c8..54501c2 100644 --- a/nntrainer/include/flatten_layer.h +++ b/nntrainer/include/flatten_layer.h @@ -78,12 +78,6 @@ public: sharedConstTensor backwarding(sharedConstTensor in, int iteration); /** - * @brief copy layer - * @param[in] l layer to copy - */ - void copy(std::shared_ptr l); - - /** * @brief get the base name for the layer * @retval base name of the layer */ diff --git a/nntrainer/include/input_layer.h b/nntrainer/include/input_layer.h index 1693357..ac1ae2d 100644 --- a/nntrainer/include/input_layer.h +++ b/nntrainer/include/input_layer.h @@ -87,12 +87,6 @@ public: int initialize(); /** - * @brief Copy Layer - * @param[in] l layer to copy - */ - void copy(std::shared_ptr l); - - /** * @brief get the base name for the layer * @retval base name of the layer */ diff --git a/nntrainer/include/layer.h b/nntrainer/include/layer.h index 2fb6cee..f3caaf1 100644 --- a/nntrainer/include/layer.h +++ b/nntrainer/include/layer.h @@ -230,7 +230,7 @@ public: * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. 
*/ - int setOptimizer(Optimizer &opt); + int setOptimizer(std::shared_ptr opt); /** * @brief Activation Type Getter @@ -394,7 +394,8 @@ protected: /** * @brief Optimizer for this layer */ - Optimizer opt; + // TODO: fix with #630 + std::shared_ptr opt; /** * @brief Layer type diff --git a/nntrainer/include/neuralnet.h b/nntrainer/include/neuralnet.h index 16b5db7..eb0c894 100644 --- a/nntrainer/include/neuralnet.h +++ b/nntrainer/include/neuralnet.h @@ -115,7 +115,7 @@ public: * @brief Get Learning rate * @retval Learning rate */ - float getLearningRate() { return opt.getLearningRate(); }; + float getLearningRate() { return opt->getLearningRate(); }; /** * @brief Create and load the Network with ini configuration file. @@ -305,8 +305,8 @@ private: std::string save_path; /**< Model path to save / read */ - Optimizer opt; /**< Optimizer, This gets copied into each layer, do not use - this directly */ + std::shared_ptr opt; /**< Optimizer; this gets copied into each + layer, do not use this directly */ NetType net_type; /**< Network Type */ diff --git a/nntrainer/include/optimizer.h b/nntrainer/include/optimizer.h index 0bb3ec0..f85e27b 100644 --- a/nntrainer/include/optimizer.h +++ b/nntrainer/include/optimizer.h @@ -37,45 +37,29 @@ namespace nntrainer { */ enum class OptType { sgd = 0, adam = 1, unknown = 2 }; -/** - * @brief type for the Optimizor to save hyper-parameter - */ -typedef struct _OptParam { - float learning_rate; - double beta1; - double beta2; - double epsilon; - float decay_rate; - float decay_steps; - bool continue_train; /** Continue training with previous tensors for adam */ +class Optimizer { - _OptParam(OptType type = OptType::adam) : - learning_rate(0.001f), - beta1(0.9f), - beta2(0.999f), - epsilon(1.0e-7f), - decay_rate(1.0f), - decay_steps(-1.0f), - continue_train(false) { - if (type == OptType::sgd) { - learning_rate = 0.01f; - } - } -} OptParam; + /** Allow layer to initialize optimizer with itself */ + friend class Layer; -class Optimizer { public: /** - * @brief Constructor of Optimizer Class + * @brief Default Constructor of Optimizer Class */ - Optimizer() : type(OptType::unknown), popt() {} - - Optimizer(const OptType type, OptParam popt); + Optimizer(const OptType t, float lr, float decay_rate = 1.0f, + float decay_steps = -1.0f, float continue_train = false) : + type(t), + learning_rate(lr), + decay_rate(decay_rate), + decay_steps(decay_steps), + continue_train(continue_train) { + checkValidation(); + } /** * @brief Destructor of Optimizer Class */ - ~Optimizer() {} + virtual ~Optimizer() {} /** * @brief copy constructor @@ -102,14 +86,6 @@ public: Optimizer &operator=(Optimizer &&rhs) = default; /** - * @brief set Optimizer Type - * @param[in] t Optimizer type - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. 
- */ - int setType(OptType t); - - /** * @brief get Optimizer Type * @retval Optimizer type */ @@ -119,27 +95,19 @@ public: * @brief get Learning Rate * @retval Learning rate */ - float getLearningRate() { return popt.learning_rate; }; + float getLearningRate() { return learning_rate; }; /** * @brief get Decay Rate for learning rate decay * @retval decay rate */ - float getDecayRate() { return popt.decay_rate; }; + float getDecayRate() { return decay_rate; }; /** * @brief get Decay Steps for learning rate decay * @retval decay steps */ - float getDecaySteps() { return popt.decay_steps; }; - - /** - * @brief set Optimizer Parameters - * @param[in] p Optimizer Parameter : OptParam - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. - */ - int setOptParam(OptParam p); + float getDecaySteps() { return decay_steps; }; /** * @brief set Optimizer Parameters @@ -150,25 +118,6 @@ public: int setProperty(std::vector values); /** - * @brief get Optimizer Parameters - * @retval OptParam - */ - OptParam getOptParam() { return popt; }; - - /** - * @brief initialize optimizer. Initialize Weight if it is adam - * @param[in] params Weight list - * @param[in] num_weights size of the array - * @param[in] setTensor true if the layer need weight update. - * Input Layer and Batch Normalization layer won't need it. - * Therefore, it sets false. - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. - */ - int initialize(std::shared_ptr params, unsigned int num_weights, - bool setTensor); - - /** * @brief apply gradient to weight_list * @param[in] params Weight list * @param[in] num_weights size of the array @@ -201,36 +150,82 @@ public: * @brief Read Training optimizer paramters from file * @param[in] file input stream file */ - void read(std::ifstream &file); + virtual void read(std::ifstream &file); /** * @brief Save Training optimizer paramters from file * @param[in] file output stream file */ - void save(std::ofstream &file); + virtual void save(std::ofstream &file); /** - * @brief get the base name for the layer - * @retval base name of the layer + * @brief setProperty by PropertyType + * @note By passing empty string, this can validate if @a type is valid + * @param[in] type property type to be passed + * @param[in] value value to be passed, if empty string is passed, do nothing + * but throws error when @a type is invalid + * @exception exception::not_supported when property type is not valid for + * the particular layer + * @exception std::invalid_argument invalid argument */ - std::string getBaseName() { return "Optimizer"; }; + virtual void setProperty(const PropertyType type, + const std::string &value = ""); + + /** + * @brief get the base name for the optimizer + * @retval base name of the optimizer + */ + virtual std::string getBaseName() = 0; -private: + /** + * @brief validate the optimizer + */ + virtual void checkValidation(); + +protected: /** * @brief Optimizer Type */ OptType type; /** - * @brief Optimizer Hyper Parmeters + * @brief get Learning Rate for the given iteration + * @param[in] iteration Iteration for the learning rate + * @retval Learning rate */ - OptParam popt; + virtual double getLearningRate(int iteration); + float learning_rate; /** learning rate */ + float decay_rate; /** decay rate for learning rate */ + float decay_steps; /** decay steps for learning rate */ + bool continue_train; /** Continue training with previous tensors for adam */ + +private: /** - * @brief Internal Tensors for 
adam Optimizer + * @brief initialize optimizer. Initialize Weight if it is adam + * @param[in] params Weight list + * @param[in] num_weights size of the array + * @param[in] setTensor true if the layer need weight update. + * Input Layer and Batch Normalization layer won't need it. + * Therefore, it sets false. + * @retval #ML_ERROR_NONE Successful. + * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - std::vector> weight_mv; + virtual int initialize(std::shared_ptr params, + unsigned int num_weights, bool setTensor); + + /** + * @brief apply gradient to the given weight + * @param[in] weight Weight and gradient set to be updated + * @param[in] tensor_idx Idx of this tensor in the tensors list + * @param[in] num_weights size of the array + * @param[in] iteration nth epoch number + * @note weight which is called upon can be assumed to be trainable + */ + virtual void apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + int iteration) = 0; }; + } /* namespace nntrainer */ #endif /* __cplusplus */ diff --git a/nntrainer/include/optimizer_factory.h b/nntrainer/include/optimizer_factory.h new file mode 100644 index 0000000..f25abcf --- /dev/null +++ b/nntrainer/include/optimizer_factory.h @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2020 Parichay Kapoor + * + * @file optimizer_factory.h + * @date 7 October 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * @brief This is the optimizer factory. + */ + +#ifndef __OPTIMIZER_FACTORY_H__ +#define __OPTIMIZER_FACTORY_H__ +#ifdef __cplusplus + +#include +#include +#include + +namespace nntrainer { + +/** + * @brief Factory creator with copy constructor + */ +std::unique_ptr createOptimizer(OptType type, const Optimizer &opt); + +/** + * @brief Factory creator with constructor + */ +template +std::unique_ptr createOptimizer(OptType type, Args... args) { + switch (type) { + case OptType::sgd: + return std::make_unique(args...); + case OptType::adam: + return std::make_unique(args...); + case OptType::unknown: + /** fallthrough intended */ + default: + throw std::invalid_argument("Unknown type for the optimizer"); + } +} + +} // namespace nntrainer + +#endif // __cplusplus +#endif // __OPTIMIZER_FACTORY_H__ diff --git a/nntrainer/include/sgd.h b/nntrainer/include/sgd.h new file mode 100644 index 0000000..30834a8 --- /dev/null +++ b/nntrainer/include/sgd.h @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2020 Parichay Kapoor + * + * @file sgd.h + * @date 6 October 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Jijoong Moon + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * @brief This is the SGD optimizer. + */ +#ifndef __SGD_H__ +#define __SGD_H__ +#ifdef __cplusplus + +#include + +namespace nntrainer { + +/** + * @class SGD optimizer class + * @brief Stochastic Gradient Descent optimizer class + */ +class SGD : public Optimizer { +public: + /** + * @brief Constructor of Optimizer Class + */ + template + SGD(float lr = 0.0001f, Args... args) : + Optimizer(OptType::sgd, lr, args...) 
{} + + /** + * @copydoc apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + * int iteration) + */ + void apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + int iteration); + + /** + * @brief get the base name for the optimizer + * @retval base name of the optimizer + */ + std::string getBaseName() { return "SGD"; }; +}; +} /* namespace nntrainer */ + +#endif /* __cplusplus */ +#endif /* __SGD_H__ */ diff --git a/nntrainer/include/weight.h b/nntrainer/include/weight.h index af0714b..d62d364 100644 --- a/nntrainer/include/weight.h +++ b/nntrainer/include/weight.h @@ -55,6 +55,8 @@ class Weight { /** Declare opitmizer as friend to get variable/gradient reference */ friend class Optimizer; + friend class SGD; + friend class Adam; public: /** diff --git a/nntrainer/meson.build b/nntrainer/meson.build index 9366592..b656cb1 100644 --- a/nntrainer/meson.build +++ b/nntrainer/meson.build @@ -43,8 +43,11 @@ nntrainer_sources = [ 'src/neuralnet.cpp', 'src/nntrainer_logger.cpp', 'src/optimizer.cpp', + 'src/optimizer_factory.cpp', 'src/parse_util.cpp', 'src/pooling2d_layer.cpp', + 'src/sgd.cpp', + 'src/adam.cpp', 'src/tensor.cpp', 'src/tensor_dim.cpp', 'src/util_func.cpp', @@ -73,10 +76,13 @@ nntrainer_headers = [ 'include/optimizer.h', 'include/parse_util.h', 'include/pooling2d_layer.h', + 'include/sgd.h', + 'include/adam.h', 'include/tensor.h', 'include/tensor_dim.h', 'include/util_func.h', 'include/weight.h', + 'include/optimizer_factory.h', '../api/nntrainer-api-common.h' ] diff --git a/nntrainer/src/activation_layer.cpp b/nntrainer/src/activation_layer.cpp index 2604d1d..f0c798d 100644 --- a/nntrainer/src/activation_layer.cpp +++ b/nntrainer/src/activation_layer.cpp @@ -72,18 +72,6 @@ sharedConstTensor ActivationLayer::backwarding(sharedConstTensor derivative, return MAKE_SHARED_TENSOR(std::move(ret)); } -/** - * @brief copy layer - * @param[in] l layer to copy - */ -void ActivationLayer::copy(std::shared_ptr l) { - std::shared_ptr from = - std::static_pointer_cast(l); - this->input.copy(from->input); - this->hidden.copy(from->hidden); - this->activation_type = from->activation_type; -}; - int ActivationLayer::setActivation( std::function const &activation_fn, std::function const diff --git a/nntrainer/src/adam.cpp b/nntrainer/src/adam.cpp new file mode 100644 index 0000000..a95252e --- /dev/null +++ b/nntrainer/src/adam.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2020 Parichay Kapoor + * + * @file adam.cpp + * @date 6 October 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Jijoong Moon + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * @brief This is the Adam optimizer. 
+ */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace nntrainer { + +int Adam::initialize(std::shared_ptr weight_list, + unsigned int num_weights, bool set_tensor) { + int status = ML_ERROR_NONE; + weight_mv.clear(); + + if (set_tensor) { + for (unsigned int i = 0; i < num_weights; ++i) { + Weight &w = weight_list.get()[i]; + + // TODO: only trainable weights must be sent to optimizer + if (!w.getTrainable()) + continue; + + Tensor m = Tensor(w.getDim()); + m.setZero(); + Tensor v = Tensor(w.getDim()); + v.setZero(); + std::pair p = + std::pair(std::move(m), std::move(v)); + weight_mv.push_back(std::move(p)); + } + } + return status; +} + +double Adam::getLearningRate(int iteration) { + double ll = Optimizer::getLearningRate(iteration); + + std::function biasCorrection = [&](float f) { + return 1.0f - pow(f, iteration + 1); + }; + + ll *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1); + + return ll; +} + +void Adam::apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + int iteration) { + + Tensor &x = weight.getVariableRef(); + const Tensor &x_grad = weight.getGradientRef(); + + // This is implementation of adam from original paper. + // This is not deleted intentionally. + // float biasCorrection1 = 1 - pow(beta1, iteration + 1); + // float biasCorrection2 = 1 - pow(beta2, iteration + 1); + // Tensor &wm = weight_mv[idx].first; + // Tensor &wv = weight_mv[idx].second; + + // wm.multiply_i(beta1); + // wm.add_i(x_grad, 1.0f - beta1); + + // wv.multiply_i(beta2); + // wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2); + + // Tensor denom = wv.apply(sqrtFloat) + // .divide(sqrtFloat(biasCorrection2)) + // .add(epsilon); + // x.add_i(wm.divide(denom), -ll / biasCorrection1); + + std::function sqrtEps = [&](double f) { + return sqrtDouble(f) + this->epsilon; + }; + + Tensor &wm = weight_mv[tensor_idx].first; + Tensor &wv = weight_mv[tensor_idx].second; + + wm.multiply_i(beta1); + wm.add_i(x_grad, 1.0f - beta1); + + wv.multiply_i(beta2); + wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2); + + x.add_i(wm.divide(wv.apply(sqrtEps)), -updated_lr); +} + +void Adam::setProperty(const PropertyType type, const std::string &value) { + int status = ML_ERROR_NONE; + + switch (type) { + case PropertyType::beta1: + status = setDouble(beta1, value); + break; + case PropertyType::beta2: + status = setDouble(beta2, value); + break; + case PropertyType::epsilon: + status = setDouble(epsilon, value); + break; + default: + Optimizer::setProperty(type, value); + status = ML_ERROR_NONE; + break; + } + + throw_status(status); +} + +void Adam::read(std::ifstream &file) { + OptType loaded_type; + file.read((char *)&loaded_type, sizeof(OptType)); + + if (loaded_type == type) { + if (continue_train) { + for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) { + (*iter).first.read(file); + (*iter).second.read(file); + } + } else { + size_t total_size = 0; + for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) + total_size += (*iter).first.getSize() + (*iter).second.getSize(); + + file.seekg(total_size, std::ifstream::cur); + } + } else { + ml_logw("Not loading saved optimizer parameters due to mismatched type"); + } +} + +void Adam::save(std::ofstream &file) { + Optimizer::save(file); + + for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) { + (*iter).first.save(file); + (*iter).second.save(file); + } +} + +} // namespace nntrainer diff --git a/nntrainer/src/addition_layer.cpp b/nntrainer/src/addition_layer.cpp index 
bb73ea8..ea9de6d 100644 --- a/nntrainer/src/addition_layer.cpp +++ b/nntrainer/src/addition_layer.cpp @@ -83,13 +83,4 @@ void AdditionLayer::setProperty(const PropertyType type, } } -void AdditionLayer::copy(std::shared_ptr l) { - std::shared_ptr from = - std::static_pointer_cast(l); - this->input.copy(from->input); - this->hidden.copy(from->hidden); - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; -} - } /* namespace nntrainer */ diff --git a/nntrainer/src/bn_layer.cpp b/nntrainer/src/bn_layer.cpp index 828a3cb..3c65840 100644 --- a/nntrainer/src/bn_layer.cpp +++ b/nntrainer/src/bn_layer.cpp @@ -176,7 +176,7 @@ BatchNormalizationLayer::backwarding(sharedConstTensor derivative, Tensor dx = dx_2.multiply(dx_1); dx.divide_i(N); - opt.apply_gradients(weight_list, num_weights, iteration); + opt->apply_gradients(weight_list, num_weights, iteration); return MAKE_SHARED_TENSOR(std::move(dx)); } @@ -186,11 +186,6 @@ void BatchNormalizationLayer::copy(std::shared_ptr l) { std::shared_ptr from = std::static_pointer_cast(l); - this->opt = from->opt; - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; - this->input.copy(from->input); - this->hidden.copy(from->hidden); this->cvar.copy(from->cvar); } diff --git a/nntrainer/src/conv2d_layer.cpp b/nntrainer/src/conv2d_layer.cpp index 6cecc33..17808f6 100644 --- a/nntrainer/src/conv2d_layer.cpp +++ b/nntrainer/src/conv2d_layer.cpp @@ -348,7 +348,7 @@ sharedConstTensor Conv2DLayer::backwarding(sharedConstTensor derivative, } } - opt.apply_gradients(weight_list, num_weights, iteration); + opt->apply_gradients(weight_list, num_weights, iteration); } return MAKE_SHARED_TENSOR(std::move(strip_pad(ret, padding))); @@ -356,6 +356,7 @@ sharedConstTensor Conv2DLayer::backwarding(sharedConstTensor derivative, void Conv2DLayer::copy(std::shared_ptr l) { Layer::copy(l); + std::shared_ptr from = std::static_pointer_cast(l); this->filter_size = from->filter_size; for (unsigned int i = 0; i < CONV2D_DIM; ++i) { @@ -363,11 +364,6 @@ void Conv2DLayer::copy(std::shared_ptr l) { this->stride[i] = from->stride[i]; this->padding[i] = from->padding[i]; } - - this->input.copy(from->input); - this->hidden.copy(from->hidden); - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; } int Conv2DLayer::setSize(int *size, PropertyType type) { diff --git a/nntrainer/src/fc_layer.cpp b/nntrainer/src/fc_layer.cpp index 7db4f7a..ff9d451 100644 --- a/nntrainer/src/fc_layer.cpp +++ b/nntrainer/src/fc_layer.cpp @@ -88,12 +88,14 @@ sharedConstTensor FullyConnectedLayer::forwarding(sharedConstTensor in) { void FullyConnectedLayer::read(std::ifstream &file) { Layer::read(file); - opt.read(file); + if (opt) + opt->read(file); } void FullyConnectedLayer::save(std::ofstream &file) { Layer::save(file); - opt.save(file); + if (opt) + opt->save(file); } void FullyConnectedLayer::copy(std::shared_ptr l) { @@ -101,13 +103,7 @@ void FullyConnectedLayer::copy(std::shared_ptr l) { std::shared_ptr from = std::static_pointer_cast(l); - this->opt = from->opt; this->unit = from->unit; - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; - this->input.copy(from->input); - this->hidden.copy(from->hidden); - this->loss = from->loss; } sharedConstTensor FullyConnectedLayer::backwarding(sharedConstTensor derivative, @@ -127,7 +123,7 @@ sharedConstTensor FullyConnectedLayer::backwarding(sharedConstTensor derivative, djdw = djdw.sum(0); if (trainable) { - opt.apply_gradients(weight_list, num_weights, iteration); + 
opt->apply_gradients(weight_list, num_weights, iteration); } return MAKE_SHARED_TENSOR(std::move(ret)); diff --git a/nntrainer/src/flatten_layer.cpp b/nntrainer/src/flatten_layer.cpp index 66dc2e1..7808db0 100644 --- a/nntrainer/src/flatten_layer.cpp +++ b/nntrainer/src/flatten_layer.cpp @@ -51,13 +51,4 @@ sharedConstTensor FlattenLayer::backwarding(sharedConstTensor in, return MAKE_SHARED_TENSOR(std::move(temp)); } -void FlattenLayer::copy(std::shared_ptr l) { - std::shared_ptr from = - std::static_pointer_cast(l); - this->input.copy(from->input); - this->hidden.copy(from->hidden); - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; -} - } /* namespace nntrainer */ diff --git a/nntrainer/src/input_layer.cpp b/nntrainer/src/input_layer.cpp index 89a1843..d1cfec7 100644 --- a/nntrainer/src/input_layer.cpp +++ b/nntrainer/src/input_layer.cpp @@ -51,15 +51,6 @@ void InputLayer::setProperty(const PropertyType type, } } -void InputLayer::copy(std::shared_ptr l) { - std::shared_ptr from = std::static_pointer_cast(l); - this->opt = from->opt; - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; - this->input.copy(from->input); - this->hidden.copy(from->hidden); -} - sharedConstTensor InputLayer::forwarding(sharedConstTensor in) { input = *in; diff --git a/nntrainer/src/layer.cpp b/nntrainer/src/layer.cpp index 5c20bba..daccd27 100644 --- a/nntrainer/src/layer.cpp +++ b/nntrainer/src/layer.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -40,11 +41,9 @@ int Layer::setActivation(ActivationType acti) { return status; } -int Layer::setOptimizer(Optimizer &opt) { - this->opt.setType(opt.getType()); - this->opt.setOptParam(opt.getOptParam()); - - return this->opt.initialize(weight_list, num_weights, true); +int Layer::setOptimizer(std::shared_ptr opt) { + this->opt = createOptimizer(opt->getType(), *opt.get()); + return this->opt->initialize(weight_list, num_weights, true); } int Layer::checkValidation() { @@ -72,6 +71,23 @@ void Layer::copy(std::shared_ptr l) { for (unsigned int i = 0; i < num_weights; ++i) { weightAt(i) = l->weightAt(i); } + + // TODO: fix this #630 + this->opt = l->opt; + this->input_dim = l->input_dim; + this->output_dim = l->output_dim; + this->input.copy(l->input); + this->hidden.copy(l->hidden); + this->activation_type = l->activation_type; + this->loss = l->loss; + this->type = l->type; + this->weight_regularizer = l->weight_regularizer; + this->weight_regularizer_constant = l->weight_regularizer_constant; + this->weight_initializer = l->weight_initializer; + this->flatten = l->flatten; + this->trainable = l->trainable; + this->num_inputs = l->num_inputs; + this->num_outputs = l->num_outputs; } void Layer::read(std::ifstream &file) { diff --git a/nntrainer/src/loss_layer.cpp b/nntrainer/src/loss_layer.cpp index 3cb05a0..692a55c 100644 --- a/nntrainer/src/loss_layer.cpp +++ b/nntrainer/src/loss_layer.cpp @@ -126,10 +126,10 @@ void LossLayer::updateLoss(const Tensor &l) { } void LossLayer::copy(std::shared_ptr l) { + Layer::copy(l); + std::shared_ptr from = std::static_pointer_cast(l); - this->input.copy(from->input); this->loss_type = from->loss_type; - this->loss = from->loss; } sharedConstTensor LossLayer::backwarding(sharedConstTensor derivative, diff --git a/nntrainer/src/model_loader.cpp b/nntrainer/src/model_loader.cpp index 4014333..36c3b21 100644 --- a/nntrainer/src/model_loader.cpp +++ b/nntrainer/src/model_loader.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include 
#include @@ -54,22 +55,53 @@ int ModelLoader::loadModelConfigIni(dictionary *ini, NeuralNetwork &model) { iniparser_getint(ini, "Model:Batch_Size", model.batch_size); /** Default to adam optimizer */ - status = model.opt.setType((OptType)parseType( - iniparser_getstring(ini, "Model:Optimizer", "adam"), TOKEN_OPT)); - NN_RETURN_STATUS(); + OptType opt_type = (OptType)parseType( + iniparser_getstring(ini, "Model:Optimizer", "adam"), TOKEN_OPT); + + try { + model.opt = createOptimizer(opt_type); + } catch (std::exception &e) { + ml_loge("%s %s", typeid(e).name(), e.what()); + return ML_ERROR_INVALID_PARAMETER; + } catch (...) { + ml_loge("Creating the optimizer failed"); + return ML_ERROR_INVALID_PARAMETER; + } + + std::vector optimizer_prop = {}; + optimizer_prop.push_back( + {"learning_rate=" + + std::string(iniparser_getstring( + ini, "Model:Learning_rate", + std::to_string(model.opt->getLearningRate()).c_str()))}); + + optimizer_prop.push_back( + {"decay_steps=" + std::string(iniparser_getstring( + ini, "Model:Decay_steps", + std::to_string(model.opt->getDecaySteps()).c_str()))}); + optimizer_prop.push_back( + {"decay_rate=" + std::string(iniparser_getstring( + ini, "Model:Decay_rate", + std::to_string(model.opt->getDecayRate()).c_str()))}); + + if (model.opt->getType() == OptType::adam) { + std::shared_ptr opt_adam = std::static_pointer_cast(model.opt); + + optimizer_prop.push_back( + {"beta1=" + + std::string(iniparser_getstring( + ini, "Model:Beta1", std::to_string(opt_adam->getBeta1()).c_str()))}); + optimizer_prop.push_back( + {"beta2=" + + std::string(iniparser_getstring( + ini, "Model:Beta2", std::to_string(opt_adam->getBeta2()).c_str()))}); + optimizer_prop.push_back( + {"epsilon=" + std::string(iniparser_getstring( + ini, "Model:Epsilon", + std::to_string(opt_adam->getEpsilon()).c_str()))}); + } - OptParam popt(model.opt.getType()); - popt.learning_rate = - iniparser_getdouble(ini, "Model:Learning_rate", popt.learning_rate); - popt.decay_steps = - iniparser_getint(ini, "Model:Decay_steps", popt.decay_steps); - popt.decay_rate = - iniparser_getdouble(ini, "Model:Decay_rate", popt.decay_rate); - popt.beta1 = iniparser_getdouble(ini, "Model:beta1", popt.beta1); - popt.beta2 = iniparser_getdouble(ini, "Model:beta2", popt.beta2); - popt.epsilon = iniparser_getdouble(ini, "Model:epsilon", popt.epsilon); - - status = model.opt.setOptParam(popt); + status = model.opt->setProperty(optimizer_prop); NN_RETURN_STATUS(); return status; diff --git a/nntrainer/src/neuralnet.cpp b/nntrainer/src/neuralnet.cpp index 3eb1f04..1a4a27a 100644 --- a/nntrainer/src/neuralnet.cpp +++ b/nntrainer/src/neuralnet.cpp @@ -162,7 +162,7 @@ int NeuralNetwork::setTrainConfig(std::vector values) { status = setBoolean(cont_train, value); NN_RETURN_STATUS(); continue_train = cont_train; - opt.setProperty({values[i]}); + opt->setProperty({values[i]}); } break; case PropertyType::batch_size: { status = setUint(batch_size, value); @@ -597,7 +597,7 @@ int NeuralNetwork::setOptimizer(std::shared_ptr optimizer) { return ML_ERROR_NOT_SUPPORTED; } - opt = *optimizer.get(); + opt = optimizer; return ML_ERROR_NONE; } diff --git a/nntrainer/src/optimizer.cpp b/nntrainer/src/optimizer.cpp index 4fe43d7..8065f97 100644 --- a/nntrainer/src/optimizer.cpp +++ b/nntrainer/src/optimizer.cpp @@ -34,73 +34,25 @@ namespace nntrainer { -Optimizer::Optimizer(const OptType t, const OptParam p) { - type = t; - popt = p; +int Optimizer::initialize(std::shared_ptr weight_list, + unsigned int num_weights, bool set_tensor) { + return 
ML_ERROR_NONE; } -int Optimizer::setType(OptType t) { - int status = ML_ERROR_NONE; - if (t == OptType::unknown) { - ml_loge("Error: Optimizer is unknown"); - return ML_ERROR_INVALID_PARAMETER; - } - type = t; - return status; -} +double Optimizer::getLearningRate(int iteration) { + double ll = learning_rate; -int Optimizer::setOptParam(OptParam p) { - int status = ML_ERROR_NONE; - if (p.learning_rate <= 0) { - ml_loge("Error: learning_rate should be grater than 0 (%f)", - p.learning_rate); - return ML_ERROR_INVALID_PARAMETER; + if (decay_steps != -1) { + ll = ll * pow(decay_rate, (iteration / decay_steps)); } - popt = p; - return status; -} - -int Optimizer::initialize(std::shared_ptr weight_list, - unsigned int num_weights, bool set_tensor) { - int status = ML_ERROR_NONE; - - if (type == OptType::adam && set_tensor) { - for (unsigned int i = 0; i < num_weights; ++i) { - Weight &w = weight_list.get()[i]; - - // TODO: only trainable weights must be sent to optimizer - if (!w.getTrainable()) - continue; - - Tensor m = Tensor(w.getDim()); - m.setZero(); - Tensor v = Tensor(w.getDim()); - v.setZero(); - std::pair p = - std::pair(std::move(m), std::move(v)); - weight_mv.push_back(std::move(p)); - } - } - return status; + return ll; } void Optimizer::apply_gradients(std::shared_ptr weight_list, unsigned int num_weights, int iteration) { - double ll = popt.learning_rate; - - if (popt.decay_steps != -1) { - ll = ll * pow(popt.decay_rate, (iteration / popt.decay_steps)); - } - - if (type == OptType::adam) { - std::function biasCorrection = [&](float f) { - return 1.0f - pow(f, iteration + 1); - }; - - ll *= sqrt(biasCorrection(popt.beta2)) / biasCorrection(popt.beta1); - } + double ll = getLearningRate(iteration); int idx = 0; for (unsigned int i = 0; i < num_weights; ++i) { @@ -109,55 +61,7 @@ void Optimizer::apply_gradients(std::shared_ptr weight_list, if (!weight.getTrainable()) continue; - Tensor &x = weight.getVariableRef(); - const Tensor &x_grad = weight.getGradientRef(); - switch (type) { - case OptType::sgd: - x.add_i(x_grad, -ll); - break; - case OptType::adam: { - - // This is implementation of adam from original paper. - // This is not deleted intentionally. 
- // float biasCorrection1 = 1 - pow(popt.beta1, iteration + 1); - // float biasCorrection2 = 1 - pow(popt.beta2, iteration + 1); - // Tensor &wm = weight_mv[idx].first; - // Tensor &wv = weight_mv[idx].second; - - // wm.multiply_i(popt.beta1); - // wm.add_i(x_grad, 1.0f - popt.beta1); - - // wv.multiply_i(popt.beta2); - // wv.add_i(x_grad.multiply(x_grad), 1.0f - popt.beta2); - - // Tensor denom = wv.apply(sqrtFloat) - // .divide(sqrtFloat(biasCorrection2)) - // .add(popt.epsilon); - // x.add_i(wm.divide(denom), -ll / biasCorrection1); - - std::function sqrtEps = [&](double f) { - return sqrtDouble(f) + this->popt.epsilon; - }; - - Tensor &wm = weight_mv[idx].first; - Tensor &wv = weight_mv[idx].second; - - wm.multiply_i(popt.beta1); - wm.add_i(x_grad, 1.0f - popt.beta1); - - wv.multiply_i(popt.beta2); - wv.add_i(x_grad.multiply(x_grad), 1.0f - popt.beta2); - - x.add_i(wm.divide(wv.apply(sqrtEps)), -ll); - - break; - } - case OptType::unknown: - default: - throw std::runtime_error("Unknown optimizer."); - break; - } - + apply_gradient(weight, idx, ll, iteration); idx += 1; } } @@ -168,75 +72,74 @@ int Optimizer::setProperty(std::vector values) { for (unsigned int i = 0; i < values.size(); ++i) { std::string key; std::string value; + status = getKeyValue(values[i], key, value); + NN_RETURN_STATUS(); - unsigned int type = parseOptProperty(key.c_str()); - - switch (static_cast(type)) { - case PropertyType::learning_rate: - status = setFloat(popt.learning_rate, value); - NN_RETURN_STATUS(); - break; - case PropertyType::decay_steps: - status = setFloat(popt.decay_steps, value); - NN_RETURN_STATUS(); - break; - case PropertyType::decay_rate: - status = setFloat(popt.decay_rate, value); - NN_RETURN_STATUS(); - break; - case PropertyType::beta1: - status = setDouble(popt.beta1, value); - NN_RETURN_STATUS(); - break; - case PropertyType::beta2: - status = setDouble(popt.beta2, value); - NN_RETURN_STATUS(); - break; - case PropertyType::epsilon: - status = setDouble(popt.epsilon, value); - NN_RETURN_STATUS(); - break; - case PropertyType::continue_train: - status = setBoolean(popt.continue_train, value); - NN_RETURN_STATUS(); - break; - default: - ml_loge("Error: Unknown Optimizer Property Key"); - status = ML_ERROR_INVALID_PARAMETER; - break; + unsigned int type = parseOptProperty(key); + + if (value.empty()) { + return ML_ERROR_INVALID_PARAMETER; + } + + try { + /// @note this calls derived setProperty if available + setProperty(static_cast(type), value); + } catch (...) { + return ML_ERROR_INVALID_PARAMETER; } } + try { + checkValidation(); + } catch (...) 
{ + return ML_ERROR_INVALID_PARAMETER; + } return status; } +void Optimizer::checkValidation() { + if (learning_rate <= 0.0f) + throw std::invalid_argument("Learning rate must be positive"); +} + +void Optimizer::setProperty(const PropertyType type, const std::string &value) { + int status = ML_ERROR_NONE; + + switch (type) { + case PropertyType::learning_rate: + status = setFloat(learning_rate, value); + break; + case PropertyType::decay_steps: + status = setFloat(decay_steps, value); + break; + case PropertyType::decay_rate: + status = setFloat(decay_rate, value); + break; + case PropertyType::continue_train: + status = setBoolean(continue_train, value); + break; + default: + ml_loge("Error: Unknown Optimizer Property Key"); + status = ML_ERROR_INVALID_PARAMETER; + break; + } + + throw_status(status); +} + void Optimizer::read(std::ifstream &file) { OptType loaded_type; file.read((char *)&loaded_type, sizeof(OptType)); - if (type == OptType::adam and loaded_type == type) { - if (popt.continue_train) { - for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) { - (*iter).first.read(file); - (*iter).second.read(file); - } - } else { - size_t total_size = 0; - for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) - total_size += (*iter).first.getSize() + (*iter).second.getSize(); - - file.seekg(total_size, std::ifstream::cur); - } - } + + if (loaded_type >= OptType::unknown) + throw std::runtime_error("Saved file has unknown optimizer"); } void Optimizer::save(std::ofstream &file) { + if (type >= OptType::unknown) + throw std::runtime_error("Cannot save unknown optimizer"); + file.write((char *)&type, sizeof(OptType)); - if (type == OptType::adam) { - for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) { - (*iter).first.save(file); - (*iter).second.save(file); - } - } } } // namespace nntrainer diff --git a/nntrainer/src/optimizer_factory.cpp b/nntrainer/src/optimizer_factory.cpp new file mode 100644 index 0000000..825b06d --- /dev/null +++ b/nntrainer/src/optimizer_factory.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2020 Parichay Kapoor + * + * @file optimizer_factory.cpp + * @date 7 October 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * @brief This is the optimizer factory. 
+ */ + +#include +#include +#include + +namespace nntrainer { + +/** + * @brief Factory creator with copy constructor + */ +std::unique_ptr createOptimizer(OptType type, const Optimizer &opt) { + switch (type) { + case OptType::sgd: + return std::make_unique(static_cast(opt)); + case OptType::adam: + return std::make_unique(static_cast(opt)); + case OptType::unknown: + /** fallthrough intended */ + default: + throw std::invalid_argument("Unknown type for the optimizer"); + } +} + +} // namespace nntrainer diff --git a/nntrainer/src/pooling2d_layer.cpp b/nntrainer/src/pooling2d_layer.cpp index ba0078a..f2deb74 100644 --- a/nntrainer/src/pooling2d_layer.cpp +++ b/nntrainer/src/pooling2d_layer.cpp @@ -179,6 +179,8 @@ void Pooling2DLayer::setBatch(unsigned int batch) { } void Pooling2DLayer::copy(std::shared_ptr l) { + Layer::copy(l); + std::shared_ptr from = std::static_pointer_cast(l); @@ -189,11 +191,6 @@ void Pooling2DLayer::copy(std::shared_ptr l) { this->stride[i] = from->stride[i]; this->padding[i] = from->padding[i]; } - - this->input.copy(from->input); - this->hidden.copy(from->hidden); - this->input_dim = from->input_dim; - this->output_dim = from->output_dim; } void Pooling2DLayer::setProperty(const PropertyType type, diff --git a/nntrainer/src/sgd.cpp b/nntrainer/src/sgd.cpp new file mode 100644 index 0000000..27594e2 --- /dev/null +++ b/nntrainer/src/sgd.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2020 Parichay Kapoor + * + * @file sgd.cpp + * @date 6 October 2020 + * @see https://github.com/nnstreamer/nntrainer + * @author Jijoong Moon + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * @brief This is the SGD optimizer. + */ + +#include + +namespace nntrainer { + +void SGD::apply_gradient(Weight &weight, int tensor_idx, double updated_lr, + int iteration) { + Tensor &x = weight.getVariableRef(); + const Tensor &x_grad = weight.getGradientRef(); + x.add_i(x_grad, -updated_lr); +} + +} // namespace nntrainer diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec index 831284a..282c085 100644 --- a/packaging/nntrainer.spec +++ b/packaging/nntrainer.spec @@ -338,6 +338,9 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/ %{_includedir}/nntrainer/nntrainer-api-common.h %{_includedir}/nntrainer/blas_interface.h %{_includedir}/nntrainer/weight.h +%{_includedir}/nntrainer/adam.h +%{_includedir}/nntrainer/sgd.h +%{_includedir}/nntrainer/optimizer_factory.h %{_libdir}/pkgconfig/nntrainer.pc %files devel-static diff --git a/test/unittest/unittest_nntrainer_internal.cpp b/test/unittest/unittest_nntrainer_internal.cpp index 33b0141..4a02351 100644 --- a/test/unittest/unittest_nntrainer_internal.cpp +++ b/test/unittest/unittest_nntrainer_internal.cpp @@ -20,13 +20,16 @@ * @author Jijoong Moon * @bug No known bugs */ -#include "databuffer_file.h" -#include "databuffer_func.h" -#include "neuralnet.h" -#include "nntrainer_test_util.h" -#include "util_func.h" #include + +#include +#include +#include #include +#include +#include + +#include /** * @brief Neural Network Model initialization @@ -196,54 +199,28 @@ TEST(nntrainer_NeuralNetwork, init_03_p) { } /** - * @brief Optimizer set type + * @brief Optimizer create */ -TEST(nntrainer_Optimizer, setType_01_p) { - int status = ML_ERROR_NONE; - nntrainer::Optimizer op; - nntrainer::OptType t = nntrainer::OptType::adam; - status = op.setType(t); - EXPECT_EQ(status, ML_ERROR_NONE); +TEST(nntrainer_Optimizer, create_01_p) { + std::shared_ptr op; + EXPECT_NO_THROW(op = 
createOptimizer(nntrainer::OptType::adam)); } /** - * @brief Optimizer set type + * @brief Optimizer create */ TEST(nntrainer_Optimizer, setType_02_p) { - int status = ML_ERROR_NONE; - nntrainer::Optimizer op; - nntrainer::OptType t = nntrainer::OptType::sgd; - status = op.setType(t); - EXPECT_EQ(status, ML_ERROR_NONE); + std::shared_ptr op; + EXPECT_NO_THROW(op = createOptimizer(nntrainer::OptType::sgd)); } /** - * @brief Optimizer set type + * @brief Optimizer create */ TEST(nntrainer_Optimizer, setType_03_n) { - int status = ML_ERROR_NONE; - nntrainer::Optimizer op; - nntrainer::OptType t = nntrainer::OptType::unknown; - status = op.setType(t); - EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER); -} - -/** - * @brief Optimizer set Opt Param - */ -TEST(nntrainer_Optimizer, setOptParam_01_p) { - int status = ML_ERROR_NONE; - nntrainer::Optimizer op; - nntrainer::OptType t = nntrainer::OptType::adam; - nntrainer::OptParam p; - status = op.setType(t); - EXPECT_EQ(status, ML_ERROR_NONE); - p.learning_rate = -0.001; - p.beta1 = 0.9; - p.beta2 = 0.9999; - p.epsilon = 1e-7; - status = op.setOptParam(p); - EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER); + std::shared_ptr op; + EXPECT_THROW(op = createOptimizer(nntrainer::OptType::unknown), + std::invalid_argument); } /** diff --git a/test/unittest/unittest_nntrainer_layers.cpp b/test/unittest/unittest_nntrainer_layers.cpp index 89545b6..56ef5b4 100644 --- a/test/unittest/unittest_nntrainer_layers.cpp +++ b/test/unittest/unittest_nntrainer_layers.cpp @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include @@ -134,10 +134,10 @@ protected: input_str.push_back((*i).str()); } - nntrainer::Optimizer op; - int status = op.setType(type); - EXPECT_EQ(status, ML_ERROR_NONE); - status = op.setProperty(input_str); + std::shared_ptr op; + EXPECT_NO_THROW(op = createOptimizer(type)); + + status = op->setProperty(input_str); EXPECT_EQ(status, ML_ERROR_NONE); status = layer.setOptimizer(op); EXPECT_EQ(status, ML_ERROR_NONE); @@ -1288,18 +1288,16 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_03_p) { status = layer2.initialize(); EXPECT_EQ(status, ML_ERROR_NONE); - nntrainer::Optimizer op; - int status = op.setType(nntrainer::OptType::sgd); - EXPECT_EQ(status, ML_ERROR_NONE); - status = op.setProperty({"learning_rate=1.0"}); + std::shared_ptr op; + EXPECT_NO_THROW(op = createOptimizer(nntrainer::OptType::sgd)); + status = op->setProperty({"learning_rate=1.0"}); EXPECT_EQ(status, ML_ERROR_NONE); status = layer1.setOptimizer(op); EXPECT_EQ(status, ML_ERROR_NONE); - nntrainer::Optimizer op2; - status = op2.setType(nntrainer::OptType::sgd); - EXPECT_EQ(status, ML_ERROR_NONE); - status = op2.setProperty({"learning_rate=1.0"}); + std::shared_ptr op2; + EXPECT_NO_THROW(op2 = createOptimizer(nntrainer::OptType::sgd)); + status = op2->setProperty({"learning_rate=1.0"}); EXPECT_EQ(status, ML_ERROR_NONE); status = layer2.setOptimizer(op2); EXPECT_EQ(status, ML_ERROR_NONE); -- 2.7.4
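A minimal usage sketch of the refactored optimizer API introduced by this patch (factory creation plus string properties instead of setType()/OptParam). This is not part of the patch itself; the include paths and the 0 value for ML_ERROR_NONE are assumptions, while the createOptimizer, setProperty, getBaseName and getLearningRate calls mirror the headers and unit tests above.

// Usage sketch for the refactored optimizer API: create a concrete optimizer
// through the factory and configure it via string properties (OptParam is
// removed). Include paths below are assumed from the packaging spec, which
// installs the headers under <includedir>/nntrainer/.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <nntrainer/optimizer.h>          // assumed installed path
#include <nntrainer/optimizer_factory.h>  // assumed installed path

int main() {
  std::shared_ptr<nntrainer::Optimizer> opt;

  try {
    // setType() is gone; unknown types now throw std::invalid_argument.
    opt = nntrainer::createOptimizer(nntrainer::OptType::adam);
  } catch (const std::invalid_argument &e) {
    std::cerr << e.what() << '\n';
    return 1;
  }

  // Hyper-parameters replace the removed OptParam struct; adam-specific keys
  // (beta1/beta2/epsilon) are routed to Adam::setProperty via the base class.
  std::vector<std::string> props = {"learning_rate=0.001", "decay_rate=0.96",
                                    "decay_steps=1000",    "beta1=0.9",
                                    "beta2=0.999",         "epsilon=1e-7"};
  if (opt->setProperty(props) != 0) // assuming ML_ERROR_NONE == 0
    return 1;

  std::cout << opt->getBaseName() << ", lr = " << opt->getLearningRate()
            << std::endl;
  return 0;
}

The same pattern is what the C API and model loader now use internally: createOptimizer() throws on invalid types, and the caller wraps it in an exception boundary (nntrainer_exception_boundary in capi, try/catch in ModelLoader) to translate the exception back into an error code.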