#include <grucell.h>
#include <identity_layer.h>
#include <input_layer.h>
+#include <lr_scheduler_constant.h>
+#include <lr_scheduler_exponential.h>
#include <lstm.h>
#include <lstmcell.h>
#include <mol_attention_layer.h>
ac.registerFactory(AppContext::unknownFactory<ml::train::Optimizer>,
"unknown", OptType::UNKNOWN);
+ using LRType = LearningRateType;
+ ac.registerFactory(
+ nntrainer::createLearningRateScheduler<ConstantLearningRateScheduler>,
+ ConstantLearningRateScheduler::type, LRType::CONSTANT);
+ ac.registerFactory(
+ nntrainer::createLearningRateScheduler<ExponentialLearningRateScheduler>,
+ ExponentialLearningRateScheduler::type, LRType::EXPONENTIAL);
+
using LayerType = ml::train::LayerType;
ac.registerFactory(nntrainer::createLayer<InputLayer>, InputLayer::type,
LayerType::LAYER_IN);
const FactoryType<nntrainer::Layer> factory, const std::string &key,
const int int_key);
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+template const int
+AppContext::registerFactory<nntrainer::LearningRateScheduler>(
+ const FactoryType<nntrainer::LearningRateScheduler> factory,
+ const std::string &key, const int int_key);
+
} // namespace nntrainer
#include <layer.h>
#include <layer_devel.h>
+#include <lr_scheduler.h>
#include <optimizer.h>
#include <nntrainer_error.h>
}
private:
- FactoryMap<ml::train::Optimizer, nntrainer::Layer> factory_map;
+ FactoryMap<ml::train::Optimizer, nntrainer::Layer,
+ nntrainer::LearningRateScheduler>
+ factory_map;
std::string working_path_base;
template <typename Args, typename T> struct isSupportedHelper;
const FactoryType<nntrainer::Layer> factory, const std::string &key,
const int int_key);
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+extern template const int
+AppContext::registerFactory<nntrainer::LearningRateScheduler>(
+ const FactoryType<nntrainer::LearningRateScheduler> factory,
+ const std::string &key, const int int_key);
+
namespace plugin {}
} // namespace nntrainer
enum class ExportMethods;
/**
+ * @brief Enumeration of optimizer type
+ */
+enum LearningRateType {
+ CONSTANT = 0, /** constant */
+ EXPONENTIAL /** exponentially decay */
+};
+
+/**
* @class Learning Rate Schedulers Base class
* @brief Base class for all Learning Rate Schedulers
*/
virtual const std::string getType() const = 0;
};
+/**
+ * @brief General LR Scheduler Factory function to create LR Scheduler
+ *
+ * @param props property representation
+ * @return std::unique_ptr<nntrainer::LearningRateScheduler> created object
+ */
+template <typename T,
+ std::enable_if_t<std::is_base_of<LearningRateScheduler, T>::value, T>
+ * = nullptr>
+std::unique_ptr<LearningRateScheduler>
+createLearningRateScheduler(const std::vector<std::string> &props = {}) {
+ std::unique_ptr<LearningRateScheduler> ptr = std::make_unique<T>();
+ ptr->setProperty(props);
+ return ptr;
+}
+
} /* namespace nntrainer */
#endif /* __cplusplus */
#include <string>
+#include <common_properties.h>
#include <lr_scheduler.h>
namespace nntrainer {