[lr] Support constant learning rate scheduler
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 9 Dec 2021 12:17:07 +0000 (21:17 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 17 Feb 2022 04:14:41 +0000 (13:14 +0900)
Support constant learning rate scheduler.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/optimizers/lr_scheduler_constant.cpp [new file with mode: 0644]
nntrainer/optimizers/lr_scheduler_constant.h [new file with mode: 0644]
nntrainer/optimizers/meson.build

diff --git a/nntrainer/optimizers/lr_scheduler_constant.cpp b/nntrainer/optimizers/lr_scheduler_constant.cpp
new file mode 100644 (file)
index 0000000..aa09481
--- /dev/null
@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   lr_scheduler_constant.cpp
+ * @date   09 December 2021
+ * @brief  This is Constant Learning Rate Scheduler class
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#include <cmath>
+
+#include <common_properties.h>
+#include <lr_scheduler_constant.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+
+namespace nntrainer {
+
+ConstantLearningRateScheduler::ConstantLearningRateScheduler() :
+  lr_props(props::LearningRate()) {}
+
+void ConstantLearningRateScheduler::setProperty(
+  const std::vector<std::string> &values) {
+  auto left = loadProperties(values, lr_props);
+  NNTR_THROW_IF(left.size(), std::invalid_argument)
+    << "[ConstantLearningRateScheduler] There are unparsed properties";
+}
+
+void ConstantLearningRateScheduler::exportTo(
+  Exporter &exporter, const ExportMethods &method) const {
+  exporter.saveResult(lr_props, method, this);
+}
+
+double ConstantLearningRateScheduler::getLearningRate(size_t iteration) {
+  return std::get<props::LearningRate>(lr_props);
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/optimizers/lr_scheduler_constant.h b/nntrainer/optimizers/lr_scheduler_constant.h
new file mode 100644 (file)
index 0000000..2a91d39
--- /dev/null
@@ -0,0 +1,73 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   lr_scheduler_constant.h
+ * @date   09 December 2021
+ * @brief  This is Constant Learning Rate Scheduler class
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __LEARNING_RATE_SCHEDULER_CONSTANT__
+#define __LEARNING_RATE_SCHEDULER_CONSTANT__
+#ifdef __cplusplus
+
+#include <string>
+
+#include <lr_scheduler.h>
+
+namespace nntrainer {
+
+/**
+ * @class   Constant Learning Rate Scheduler class
+ * @brief   class for constant Learning Rate Schedulers
+ */
+class ConstantLearningRateScheduler : public LearningRateScheduler {
+
+public:
+  /**
+   * @brief Construct a new constant learning rate scheduler object
+   *
+   */
+  ConstantLearningRateScheduler();
+
+  /**
+   * @copydoc LearningRateScheduler::getLearningRate(size_t iteration) const
+   *
+   */
+  virtual double getLearningRate(size_t iteration) override;
+
+  /**
+   * @copydoc LearningRateScheduler::exportTo(Exporter &exporter, const
+   * ExportMethods& method)
+   *
+   */
+  void exportTo(Exporter &exporter, const ExportMethods &method) const override;
+
+  /**
+   * @copydoc LearningRateScheduler::setProperty(const std::vector<std::string>
+   * &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc LearningRateScheduler::getType() const
+   *
+   */
+  const std::string getType() const override {
+    return ConstantLearningRateScheduler::type;
+  }
+
+  inline static const std::string type = "constant";
+
+private:
+  std::tuple<props::LearningRate> lr_props;
+};
+
+} /* namespace nntrainer */
+
+#endif /* __cplusplus */
+#endif /* __LEARNING_RATE_SCHEDULER_CONSTANT__ */
index 5269560..1d0de25 100644 (file)
@@ -3,13 +3,14 @@ optimizer_sources = [
   'optimizer_devel.cpp',
   'optimizer_impl.cpp',
   'sgd.cpp',
-  'optimizer_context.cpp'
+  'optimizer_context.cpp',
+  'lr_scheduler_constant.cpp'
 ]
 
 optimizer_headers = [
   'optimizer_devel.h',
   'optimizer_impl.h',
-  'optimizer_context.h'
+  'optimizer_context.h',
 ]
 
 foreach s : optimizer_sources