From f01a3b2a7f39c03c6532ae7db4b5dc2daba07744 Mon Sep 17 00:00:00 2001
From: Parichay Kapoor
Date: Thu, 9 Dec 2021 20:54:32 +0900
Subject: [PATCH] [lr] Add interface for learning rate scheduler

Add interface for the learning rate scheduler which all learning rate
schedulers must abide by.

Signed-off-by: Parichay Kapoor
---
 nntrainer/optimizers/lr_scheduler.h | 90 +++++++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)
 create mode 100644 nntrainer/optimizers/lr_scheduler.h

diff --git a/nntrainer/optimizers/lr_scheduler.h b/nntrainer/optimizers/lr_scheduler.h
new file mode 100644
index 0000000..6395271
--- /dev/null
+++ b/nntrainer/optimizers/lr_scheduler.h
@@ -0,0 +1,90 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor
+ *
+ * @file   lr_scheduler.h
+ * @date   09 December 2021
+ * @brief  This is Learning Rate Scheduler interface class
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __LEARNING_RATE_SCHEDULER__
+#define __LEARNING_RATE_SCHEDULER__
+#ifdef __cplusplus
+
+#include <string>
+#include <vector>
+
+namespace nntrainer {
+
+class Exporter;
+enum class ExportMethods;
+
+/**
+ * @class   Learning Rate Schedulers Base class
+ * @brief   Base class for all Learning Rate Schedulers
+ */
+class LearningRateScheduler {
+
+public:
+  /**
+   * @brief Destructor of learning rate scheduler class
+   */
+  virtual ~LearningRateScheduler() = default;
+
+  /**
+   * @brief     get Learning Rate for the given iteration
+   * @param[in] iteration Iteration for the learning rate
+   * @retval    Learning rate in double
+   * @details   the return value of this function and getInitialLearningRate()
+   * may not match for iteration == 0 (warmup can lead to different initial
+   * learning rates).
+   *
+   * @note this is a non-const function intentionally.
+   */
+  virtual double getLearningRate(size_t iteration) = 0;
+
+  /**
+   * @brief this function helps exporting the learning rate in a predefined
+   * format, while working around the issue caused by templated function type
+   * erasure
+   *
+   * @param exporter exporter that contains exporting logic
+   * @param method enum value to identify how it should be exported
+   */
+  virtual void exportTo(Exporter &exporter, const ExportMethods &method) const {
+  }
+
+  /**
+   * @brief Default allowed properties
+   * Constant Learning rate scheduler
+   * - learning_rate : float
+   *
+   * Exponential Learning rate scheduler
+   * - learning_rate : float
+   * - decay_rate : float
+   * - decay_steps : float
+   *
+   * more to be added
+   */
+
+  /**
+   * @brief     set learning rate scheduler properties
+   * @param[in] values learning rate scheduler properties list
+   * @details   This function accepts a vector of properties in the format -
+   *  { std::string property_name = std::string property_val, ...}
+   */
+  virtual void setProperty(const std::vector<std::string> &values) = 0;
+
+  /**
+   * @brief  get learning rate scheduler type
+   * @retval learning rate scheduler type
+   */
+  virtual const std::string getType() const = 0;
+};
+
+} /* namespace nntrainer */
+
+#endif /* __cplusplus */
+#endif /* __LEARNING_RATE_SCHEDULER__ */
--
2.7.4
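
Note (not part of the patch): below is a minimal sketch of how a concrete
scheduler could derive from this interface. The class name
ConstantLearningRateScheduler, the default rate, the "key=value" property
syntax, and the "constant" type string are assumptions for illustration;
they are not taken from nntrainer itself.

/*
 * Illustrative sketch only: a hypothetical constant learning rate scheduler
 * built on top of the interface introduced by this patch.
 */
#include <stdexcept>
#include <string>
#include <vector>

#include "lr_scheduler.h" /* the interface added above */

namespace nntrainer {

class ConstantLearningRateScheduler final : public LearningRateScheduler {
public:
  /** @brief return the same learning rate regardless of the iteration */
  double getLearningRate(size_t /* iteration */) override {
    return learning_rate;
  }

  /** @brief parse properties of the assumed form "learning_rate=0.1" */
  void setProperty(const std::vector<std::string> &values) override {
    for (const auto &value : values) {
      auto pos = value.find('=');
      if (pos == std::string::npos)
        throw std::invalid_argument("expected key=value, got: " + value);
      if (value.substr(0, pos) == "learning_rate")
        learning_rate = std::stod(value.substr(pos + 1));
    }
  }

  /** @brief assumed type identifier */
  const std::string getType() const override { return "constant"; }

private:
  double learning_rate = 0.1; /**< assumed default */
};

} /* namespace nntrainer */

exportTo() is left at its default no-op here since the interface provides an
empty body. An optimizer would presumably query getLearningRate(iteration)
once per training step; how nntrainer wires that up is outside this patch.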