} // namespace optimizer
/**
- * @brief Enumeration of learning type
+ * @brief Enumeration of learning rate scheduler type
*/
-enum LearningRateType {
+enum LearningRateSchedulerType {
CONSTANT = 0, /**< constant */
EXPONENTIAL, /**< exponentially decay */
STEP /**< step wise decay */
* @brief Factory creator with constructor for learning rate scheduler type
*/
std::unique_ptr<ml::train::LearningRateScheduler>
-createLearningRateScheduler(const LearningRateType &type,
+createLearningRateScheduler(const LearningRateSchedulerType &type,
const std::vector<std::string> &properties = {});
/**
*/
inline std::unique_ptr<LearningRateScheduler>
Constant(const std::vector<std::string> &properties = {}) {
- return createLearningRateScheduler(LearningRateType::CONSTANT, properties);
+ return createLearningRateScheduler(LearningRateSchedulerType::CONSTANT,
+ properties);
}
/**
*/
inline std::unique_ptr<LearningRateScheduler>
Exponential(const std::vector<std::string> &properties = {}) {
- return createLearningRateScheduler(LearningRateType::EXPONENTIAL, properties);
+ return createLearningRateScheduler(LearningRateSchedulerType::EXPONENTIAL,
+ properties);
}
/**
*/
inline std::unique_ptr<LearningRateScheduler>
Step(const std::vector<std::string> &properties = {}) {
- return createLearningRateScheduler(LearningRateType::STEP, properties);
+ return createLearningRateScheduler(LearningRateSchedulerType::STEP,
+ properties);
}
} // namespace learning_rate
* @brief Factory creator with constructor for learning rate scheduler type
*/
std::unique_ptr<ml::train::LearningRateScheduler>
-createLearningRateScheduler(const LearningRateType &type,
+createLearningRateScheduler(const LearningRateSchedulerType &type,
const std::vector<std::string> &properties) {
auto &ac = nntrainer::AppContext::Global();
return ac.createObject<ml::train::LearningRateScheduler>(type, properties);
ac.registerFactory(AppContext::unknownFactory<nntrainer::Optimizer>,
"unknown", OptType::UNKNOWN);
- using LRType = LearningRateType;
+ using LRType = LearningRateSchedulerType;
ac.registerFactory(
ml::train::createLearningRateScheduler<ConstantLearningRateScheduler>,
ConstantLearningRateScheduler::type, LRType::CONSTANT);