From 752f640a4764877c93128fab4bc8086f2a22e1a0 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Fri, 10 Dec 2021 12:48:40 +0900 Subject: [PATCH] [lr-scheduler] Add finalize to interface Add finalize to the interface of the learning rate scheduler with the purpose to verify that the required properties have been set. Signed-off-by: Parichay Kapoor --- nntrainer/optimizers/lr_scheduler.h | 10 ++++++++++ nntrainer/optimizers/lr_scheduler_constant.cpp | 6 ++++++ nntrainer/optimizers/lr_scheduler_constant.h | 6 ++++++ nntrainer/optimizers/lr_scheduler_exponential.cpp | 10 ++++++++++ nntrainer/optimizers/lr_scheduler_exponential.h | 11 +++++++++-- 5 files changed, 41 insertions(+), 2 deletions(-) diff --git a/nntrainer/optimizers/lr_scheduler.h b/nntrainer/optimizers/lr_scheduler.h index 6395271..2663929 100644 --- a/nntrainer/optimizers/lr_scheduler.h +++ b/nntrainer/optimizers/lr_scheduler.h @@ -35,6 +35,16 @@ public: virtual ~LearningRateScheduler() = default; /** + * @brief Finalize creating the learning rate scheduler + * + * @details Verify that all the needed properties have been and within the + * valid range. + * @note After calling this it is not allowed to + * change properties. + */ + virtual void finalize() = 0; + + /** * @brief get Learning Rate for the given iteration * @param[in] iteration Iteration for the learning rate * @retval Learning rate in double diff --git a/nntrainer/optimizers/lr_scheduler_constant.cpp b/nntrainer/optimizers/lr_scheduler_constant.cpp index aa09481..2019cbc 100644 --- a/nntrainer/optimizers/lr_scheduler_constant.cpp +++ b/nntrainer/optimizers/lr_scheduler_constant.cpp @@ -24,6 +24,12 @@ namespace nntrainer { ConstantLearningRateScheduler::ConstantLearningRateScheduler() : lr_props(props::LearningRate()) {} +void ConstantLearningRateScheduler::finalize() { + NNTR_THROW_IF(std::get(lr_props).empty(), + std::invalid_argument) + << "[ConstantLearningRateScheduler] Learning Rate is not set"; +} + void ConstantLearningRateScheduler::setProperty( const std::vector &values) { auto left = loadProperties(values, lr_props); diff --git a/nntrainer/optimizers/lr_scheduler_constant.h b/nntrainer/optimizers/lr_scheduler_constant.h index 2a91d39..405967f 100644 --- a/nntrainer/optimizers/lr_scheduler_constant.h +++ b/nntrainer/optimizers/lr_scheduler_constant.h @@ -41,6 +41,12 @@ public: virtual double getLearningRate(size_t iteration) override; /** + * @copydoc LearningRateScheduler::finalize() + * + */ + virtual void finalize() override; + + /** * @copydoc LearningRateScheduler::exportTo(Exporter &exporter, const * ExportMethods& method) * diff --git a/nntrainer/optimizers/lr_scheduler_exponential.cpp b/nntrainer/optimizers/lr_scheduler_exponential.cpp index f1b5c59..4b6a25f 100644 --- a/nntrainer/optimizers/lr_scheduler_exponential.cpp +++ b/nntrainer/optimizers/lr_scheduler_exponential.cpp @@ -24,6 +24,16 @@ namespace nntrainer { ExponentialLearningRateScheduler::ExponentialLearningRateScheduler() : lr_props(props::DecayRate(), props::DecaySteps()) {} +void ExponentialLearningRateScheduler::finalize() { + NNTR_THROW_IF(std::get(lr_props).empty(), + std::invalid_argument) + << "[ConstantLearningRateScheduler] Decay Rate is not set"; + NNTR_THROW_IF(std::get(lr_props).empty(), + std::invalid_argument) + << "[ConstantLearningRateScheduler] Decay Steps is not set"; + ConstantLearningRateScheduler::finalize(); +} + void ExponentialLearningRateScheduler::setProperty( const std::vector &values) { auto left = loadProperties(values, lr_props); diff --git a/nntrainer/optimizers/lr_scheduler_exponential.h b/nntrainer/optimizers/lr_scheduler_exponential.h index 0eed038..2aa0491 100644 --- a/nntrainer/optimizers/lr_scheduler_exponential.h +++ b/nntrainer/optimizers/lr_scheduler_exponential.h @@ -25,7 +25,8 @@ namespace nntrainer { * @class Constant Learning Rate Scheduler class * @brief class for constant Learning Rate Schedulers */ -class ExponentialLearningRateScheduler : public ConstantLearningRateScheduler { +class ExponentialLearningRateScheduler final + : public ConstantLearningRateScheduler { public: /** @@ -38,7 +39,13 @@ public: * @copydoc LearningRateScheduler::getLearningRate(size_t iteration) const * */ - virtual double getLearningRate(size_t iteration) override; + double getLearningRate(size_t iteration) override; + + /** + * @copydoc LearningRateScheduler::finalize() + * + */ + void finalize() override; /** * @copydoc LearningRateScheduler::exportTo(Exporter &exporter, const -- 2.7.4