From e9a3c65d1f9db4aa2e19b447fd93cb5c57f54aaa Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Thu, 9 Dec 2021 21:17:07 +0900 Subject: [PATCH] [lr] Support constant learning rate scheduler Support constant learning rate scheduler. Signed-off-by: Parichay Kapoor --- nntrainer/optimizers/lr_scheduler_constant.cpp | 43 +++++++++++++++ nntrainer/optimizers/lr_scheduler_constant.h | 73 ++++++++++++++++++++++++++ nntrainer/optimizers/meson.build | 5 +- 3 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 nntrainer/optimizers/lr_scheduler_constant.cpp create mode 100644 nntrainer/optimizers/lr_scheduler_constant.h diff --git a/nntrainer/optimizers/lr_scheduler_constant.cpp b/nntrainer/optimizers/lr_scheduler_constant.cpp new file mode 100644 index 0000000..aa09481 --- /dev/null +++ b/nntrainer/optimizers/lr_scheduler_constant.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2021 Parichay Kapoor + * + * @file lr_scheduler_constant.cpp + * @date 09 December 2021 + * @brief This is Constant Learning Rate Scheduler class + * @see https://github.com/nnstreamer/nntrainer + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * + */ + +#include + +#include +#include +#include +#include +#include + +namespace nntrainer { + +ConstantLearningRateScheduler::ConstantLearningRateScheduler() : + lr_props(props::LearningRate()) {} + +void ConstantLearningRateScheduler::setProperty( + const std::vector &values) { + auto left = loadProperties(values, lr_props); + NNTR_THROW_IF(left.size(), std::invalid_argument) + << "[ConstantLearningRateScheduler] There are unparsed properties"; +} + +void ConstantLearningRateScheduler::exportTo( + Exporter &exporter, const ExportMethods &method) const { + exporter.saveResult(lr_props, method, this); +} + +double ConstantLearningRateScheduler::getLearningRate(size_t iteration) { + return std::get(lr_props); +} + +} // namespace nntrainer diff --git a/nntrainer/optimizers/lr_scheduler_constant.h b/nntrainer/optimizers/lr_scheduler_constant.h new file mode 100644 index 0000000..2a91d39 --- /dev/null +++ b/nntrainer/optimizers/lr_scheduler_constant.h @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2021 Parichay Kapoor + * + * @file lr_scheduler_constant.h + * @date 09 December 2021 + * @brief This is Constant Learning Rate Scheduler class + * @see https://github.com/nnstreamer/nntrainer + * @author Parichay Kapoor + * @bug No known bugs except for NYI items + * + */ + +#ifndef __LEARNING_RATE_SCHEDULER_CONSTANT__ +#define __LEARNING_RATE_SCHEDULER_CONSTANT__ +#ifdef __cplusplus + +#include + +#include + +namespace nntrainer { + +/** + * @class Constant Learning Rate Scheduler class + * @brief class for constant Learning Rate Schedulers + */ +class ConstantLearningRateScheduler : public LearningRateScheduler { + +public: + /** + * @brief Construct a new constant learning rate scheduler object + * + */ + ConstantLearningRateScheduler(); + + /** + * @copydoc LearningRateScheduler::getLearningRate(size_t iteration) const + * + */ + virtual double getLearningRate(size_t iteration) override; + + /** + * @copydoc LearningRateScheduler::exportTo(Exporter &exporter, const + * ExportMethods& method) + * + */ + void exportTo(Exporter &exporter, const ExportMethods &method) const override; + + /** + * @copydoc LearningRateScheduler::setProperty(const std::vector + * &values) + */ + void setProperty(const std::vector &values) override; + + /** + * @copydoc LearningRateScheduler::getType() const + * + */ + const std::string getType() const override { + return ConstantLearningRateScheduler::type; + } + + inline static const std::string type = "constant"; + +private: + std::tuple lr_props; +}; + +} /* namespace nntrainer */ + +#endif /* __cplusplus */ +#endif /* __LEARNING_RATE_SCHEDULER_CONSTANT__ */ diff --git a/nntrainer/optimizers/meson.build b/nntrainer/optimizers/meson.build index 5269560..1d0de25 100644 --- a/nntrainer/optimizers/meson.build +++ b/nntrainer/optimizers/meson.build @@ -3,13 +3,14 @@ optimizer_sources = [ 'optimizer_devel.cpp', 'optimizer_impl.cpp', 'sgd.cpp', - 'optimizer_context.cpp' + 'optimizer_context.cpp', + 'lr_scheduler_constant.cpp' ] optimizer_headers = [ 'optimizer_devel.h', 'optimizer_impl.h', - 'optimizer_context.h' + 'optimizer_context.h', ] foreach s : optimizer_sources -- 2.7.4