[appcontext] Register learning rate scheduler with app context
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 9 Dec 2021 13:53:08 +0000 (22:53 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Thu, 24 Feb 2022 06:58:55 +0000 (15:58 +0900)
register learning rate scheduler with the app context and regiter its
type factories.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/app_context.cpp
nntrainer/app_context.h
nntrainer/optimizers/lr_scheduler.h
nntrainer/optimizers/lr_scheduler_constant.h

index ad15846..22da6d3 100644 (file)
@@ -49,6 +49,8 @@
 #include <grucell.h>
 #include <identity_layer.h>
 #include <input_layer.h>
+#include <lr_scheduler_constant.h>
+#include <lr_scheduler_exponential.h>
 #include <lstm.h>
 #include <lstmcell.h>
 #include <mol_attention_layer.h>
@@ -218,6 +220,14 @@ static void add_default_object(AppContext &ac) {
   ac.registerFactory(AppContext::unknownFactory<ml::train::Optimizer>,
                      "unknown", OptType::UNKNOWN);
 
+  using LRType = LearningRateType;
+  ac.registerFactory(
+    nntrainer::createLearningRateScheduler<ConstantLearningRateScheduler>,
+    ConstantLearningRateScheduler::type, LRType::CONSTANT);
+  ac.registerFactory(
+    nntrainer::createLearningRateScheduler<ExponentialLearningRateScheduler>,
+    ExponentialLearningRateScheduler::type, LRType::EXPONENTIAL);
+
   using LayerType = ml::train::LayerType;
   ac.registerFactory(nntrainer::createLayer<InputLayer>, InputLayer::type,
                      LayerType::LAYER_IN);
@@ -547,4 +557,12 @@ template const int AppContext::registerFactory<nntrainer::Layer>(
   const FactoryType<nntrainer::Layer> factory, const std::string &key,
   const int int_key);
 
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+template const int
+AppContext::registerFactory<nntrainer::LearningRateScheduler>(
+  const FactoryType<nntrainer::LearningRateScheduler> factory,
+  const std::string &key, const int int_key);
+
 } // namespace nntrainer
index 9b14b5d..36d0f29 100644 (file)
@@ -28,6 +28,7 @@
 
 #include <layer.h>
 #include <layer_devel.h>
+#include <lr_scheduler.h>
 #include <optimizer.h>
 
 #include <nntrainer_error.h>
@@ -263,7 +264,9 @@ public:
   }
 
 private:
-  FactoryMap<ml::train::Optimizer, nntrainer::Layer> factory_map;
+  FactoryMap<ml::train::Optimizer, nntrainer::Layer,
+             nntrainer::LearningRateScheduler>
+    factory_map;
   std::string working_path_base;
 
   template <typename Args, typename T> struct isSupportedHelper;
@@ -298,6 +301,14 @@ extern template const int AppContext::registerFactory<nntrainer::Layer>(
   const FactoryType<nntrainer::Layer> factory, const std::string &key,
   const int int_key);
 
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+extern template const int
+AppContext::registerFactory<nntrainer::LearningRateScheduler>(
+  const FactoryType<nntrainer::LearningRateScheduler> factory,
+  const std::string &key, const int int_key);
+
 namespace plugin {}
 
 } // namespace nntrainer
index 2663929..5711570 100644 (file)
@@ -23,6 +23,14 @@ class Exporter;
 enum class ExportMethods;
 
 /**
+ * @brief     Enumeration of optimizer type
+ */
+enum LearningRateType {
+  CONSTANT = 0, /** constant */
+  EXPONENTIAL   /** exponentially decay */
+};
+
+/**
  * @class   Learning Rate Schedulers Base class
  * @brief   Base class for all Learning Rate Schedulers
  */
@@ -94,6 +102,22 @@ public:
   virtual const std::string getType() const = 0;
 };
 
+/**
+ * @brief General LR Scheduler Factory function to create LR Scheduler
+ *
+ * @param props property representation
+ * @return std::unique_ptr<nntrainer::LearningRateScheduler> created object
+ */
+template <typename T,
+          std::enable_if_t<std::is_base_of<LearningRateScheduler, T>::value, T>
+            * = nullptr>
+std::unique_ptr<LearningRateScheduler>
+createLearningRateScheduler(const std::vector<std::string> &props = {}) {
+  std::unique_ptr<LearningRateScheduler> ptr = std::make_unique<T>();
+  ptr->setProperty(props);
+  return ptr;
+}
+
 } /* namespace nntrainer */
 
 #endif /* __cplusplus */
index 405967f..bbb8c92 100644 (file)
@@ -17,6 +17,7 @@
 
 #include <string>
 
+#include <common_properties.h>
 #include <lr_scheduler.h>
 
 namespace nntrainer {