[optimizer] Add optimizer wrapped
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 10 Dec 2021 12:29:28 +0000 (21:29 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 20 Apr 2022 10:47:00 +0000 (19:47 +0900)
Add optimizer wrapped which wraps the opitmizer and the learning rate
scheduler.
In order to be backward compatible, each optimizer must support setting
the learning rate, decay rate and decay steps, even for new optimizers.
To make this extensible without each optimizer storing this information
and merging with the learning rate schedulers, and not creating new
interfaces, optimizer wrapped is added.
Optimizer wraps around optimizer, and owns both the optimizer and
learning rate scheduler. If the properties of LR or decay are passed to
the optimizer, they are intercepted by the optimizer wrapped and passed
to the learning rate scheduler appropriately.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/optimizers/meson.build
nntrainer/optimizers/optimizer_wrapped.cpp [new file with mode: 0644]
nntrainer/optimizers/optimizer_wrapped.h [new file with mode: 0644]

index 9f5e337..0e0431b 100644 (file)
@@ -5,7 +5,8 @@ optimizer_sources = [
   'sgd.cpp',
   'optimizer_context.cpp',
   'lr_scheduler_constant.cpp',
-  'lr_scheduler_exponential.cpp'
+  'lr_scheduler_exponential.cpp',
+  'optimizer_wrapped.cpp'
 ]
 
 optimizer_headers = [
diff --git a/nntrainer/optimizers/optimizer_wrapped.cpp b/nntrainer/optimizers/optimizer_wrapped.cpp
new file mode 100644 (file)
index 0000000..9f7d983
--- /dev/null
@@ -0,0 +1,144 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   optimizer_wrapped.cpp
+ * @date   10 December 2021
+ * @brief  This is Optimizer Wrapped interface class
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ * @details wraps the optimizer and learning rate scheduler together
+ */
+
+#include <app_context.h>
+#include <common_properties.h>
+#include <lr_scheduler_constant.h>
+#include <lr_scheduler_exponential.h>
+#include <node_exporter.h>
+#include <optimizer_wrapped.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Optimizer wrapped creator with constructor for optimizer
+ */
+std::unique_ptr<OptimizerWrapped>
+createOptimizerWrapped(const ml::train::OptimizerType &type,
+                       const std::vector<std::string> &properties) {
+  auto &ac = nntrainer::AppContext::Global();
+  return createOptimizerWrapped(ac.createObject<OptimizerCore>(type),
+                                properties);
+}
+
+/**
+ * @brief Optimizer wrapped creator with constructor for optimizer
+ */
+std::unique_ptr<OptimizerWrapped>
+createOptimizerWrapped(const std::string &type,
+                       const std::vector<std::string> &properties) {
+  auto &ac = nntrainer::AppContext::Global();
+  return createOptimizerWrapped(ac.createObject<OptimizerCore>(type),
+                                properties);
+}
+
+/**
+ * @brief Optimizer wrapped creator with constructor for optimizer
+ */
+std::unique_ptr<OptimizerWrapped>
+createOptimizerWrapped(std::unique_ptr<OptimizerCore> &&opt,
+                       const std::vector<std::string> &properties) {
+  auto opt_wrapped = std::make_unique<OptimizerWrapped>(std::move(opt));
+
+  opt_wrapped->setProperty(properties);
+  return opt_wrapped;
+}
+
+OptimizerWrapped::OptimizerWrapped(std::unique_ptr<OptimizerCore> &&opt) :
+  optimizer(std::move(opt)),
+  lr_sched(),
+  props(props::LearningRate(), props::DecayRate(), props::DecaySteps()) {}
+
+const std::string OptimizerWrapped::getType() const {
+  return optimizer->getType();
+}
+
+void OptimizerWrapped::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, props);
+  optimizer->setProperty(remain_props);
+}
+
+double OptimizerWrapped::getLearningRate(size_t iteration) {
+  return lr_sched->getLearningRate(iteration);
+}
+
+void OptimizerWrapped::applyGradient(RunOptimizerContext &context) {
+  // optimizer->applyGradient(context);
+}
+
+void OptimizerWrapped::exportTo(Exporter &exporter,
+                                const ExportMethods &method) const {
+  // optimizer->exportTo(exporter, method);
+  lr_sched->exportTo(exporter, method);
+}
+
+void OptimizerWrapped::finalize() {
+  auto const &props_lr = std::get<props::LearningRate>(props);
+  auto const &props_dr = std::get<props::DecayRate>(props);
+  auto const &props_ds = std::get<props::DecaySteps>(props);
+
+  /** if lr_sched already set and property not empty, error */
+  bool props_empty = props_lr.empty() & props_dr.empty() & props_ds.empty();
+
+  NNTR_THROW_IF(props_empty && !lr_sched, std::invalid_argument)
+    << "Learning rate scheduler not set for the optimizer " << getType();
+  NNTR_THROW_IF(!props_empty && lr_sched, std::invalid_argument)
+    << "Multiple learning rate schedulers set for the optimizer " << getType();
+
+  /** if lr_sched not set, make lr_sched from properties */
+  if (!props_empty) {
+    if (!props_dr.empty() || !props_ds.empty()) {
+      lr_sched = std::make_unique<ExponentialLearningRateScheduler>();
+      if (!props_dr.empty())
+        lr_sched->setProperty({"decay_rate=" + std::to_string(props_dr.get())});
+      if (!props_ds.empty())
+        lr_sched->setProperty(
+          {"decay_steps=" + std::to_string(props_ds.get())});
+    } else {
+      lr_sched = std::make_unique<ConstantLearningRateScheduler>();
+    }
+
+    if (!props_lr.empty())
+      lr_sched->setProperty(
+        {"learning_rate=" + std::to_string(props_lr.get())});
+  }
+
+  lr_sched->finalize();
+  // optimizer->finalize();
+}
+
+void OptimizerWrapped::read(std::ifstream &file) {
+  // optimizer->read(file);
+}
+
+void OptimizerWrapped::save(std::ofstream &file) {
+  // optimizer->save(file);
+}
+
+std::vector<TensorDim>
+OptimizerWrapped::getOptimizerVariableDim(const TensorDim &dim) {
+  return {};
+  // return optimizer->getOptimizerVariableDim(dim);
+}
+
+void OptimizerWrapped::setLearningRateScheduler(
+  std::unique_ptr<nntrainer::LearningRateScheduler> &&lrs) {
+  lr_sched = std::move(lrs);
+}
+
+nntrainer::LearningRateScheduler *OptimizerWrapped::setLearningRateScheduler() {
+  return lr_sched.get();
+}
+
+} // namespace nntrainer
\ No newline at end of file
diff --git a/nntrainer/optimizers/optimizer_wrapped.h b/nntrainer/optimizers/optimizer_wrapped.h
new file mode 100644 (file)
index 0000000..2fd4d3d
--- /dev/null
@@ -0,0 +1,198 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   optimizer_wrapped.h
+ * @date   10 December 2021
+ * @brief  This is Optimizer Wrapped interface class
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ * @details wraps the optimizer and learning rate scheduler together
+ */
+
+#ifndef __OPTIMIZER_WRAPPER_H__
+#define __OPTIMIZER_WRAPPER_H__
+
+#if __cplusplus
+
+#include <string>
+#include <vector>
+
+#include <lr_scheduler.h>
+#include <optimizer.h>
+#include <optimizer_devel.h>
+
+namespace nntrainer {
+
+namespace props {
+class LearningRate;
+class DecaySteps;
+class DecayRate;
+} // namespace props
+
+/** TODO: change to nntrainer::Optimizer */
+using OptimizerCore = ml::train::Optimizer;
+
+/**
+ * @class   Optimizer Base class for optimizers
+ * @brief   Base class for all optimizers
+ */
+class OptimizerWrapped : public ml::train::Optimizer {
+public:
+  /**
+   * @brief Constructor of OptimizerWrapped class
+   * @param opt optimizer to wrap
+   *
+   */
+  OptimizerWrapped(std::unique_ptr<OptimizerCore> &&opt);
+
+  /**
+   * @brief     Destructor of Optimizer Class
+   */
+  ~OptimizerWrapped() = default;
+
+  /**
+   * Support all the interface requirements by ml::train::Optimizer
+   */
+
+  /**
+   * @brief     get Optimizer Type
+   * @retval    Optimizer type
+   */
+  const std::string getType() const;
+
+  /**
+   * @brief     Default allowed properties
+   * Available for all optimizers
+   * - learning_rate : float
+   *
+   * Available for SGD and Adam optimizers
+   * - decay_rate : float,
+   * - decay_steps : float,
+   *
+   * Available for Adam optimizer
+   * - beta1 : float,
+   * - beta2 : float,
+   * - epsilon : float,
+   */
+
+  /**
+   * @brief     set Optimizer Parameters
+   * @param[in] values Optimizer Parameter list
+   * @details   This function accepts vector of properties in the format -
+   *  { std::string property_name, void * property_val, ...}
+   */
+  void setProperty(const std::vector<std::string> &values);
+
+  /**
+   * Support all the interface requirements by nntrainer::Optimizer
+   */
+
+  /**
+   * @brief     get Learning Rate for the given iteration
+   * @param[in] iteration Iteration for the learning rate
+   * @retval    Learning rate in double
+   * @detail    the return value of this function and getLearningRate() must
+   * match for iteration == 0.
+   */
+  double getLearningRate(size_t iteration);
+
+  /**
+   * @brief     apply gradient to weight
+   * @param[in] context Optimizer context
+   */
+  void applyGradient(RunOptimizerContext &context);
+
+  /**
+   * @brief this function helps exporting the optimizer in a predefined format,
+   * while workarounding issue caused by templated function type eraser
+   *
+   * @param     exporter exporter that conatins exporting logic
+   * @param     method enum value to identify how it should be exported to
+   */
+  void exportTo(Exporter &exporter, const ExportMethods &method) const;
+
+  /**
+   * @brief     finalize optimizer.
+   */
+  void finalize();
+
+  /**
+   * @brief     Read Training optimizer paramters from file
+   * @param[in] file input stream file
+   */
+  void read(std::ifstream &file);
+
+  /**
+   * @brief     Save Training optimizer paramters from file
+   * @param[in] file output stream file
+   */
+  void save(std::ofstream &file);
+
+  /**
+   * @brief     Get dimension of extra variables if the optimizer needs any.
+   * @param dim Dimension of tensor to be added as a optimizer variable
+   * @return    Vector of dimensions
+   */
+  std::vector<TensorDim> getOptimizerVariableDim(const TensorDim &dim);
+
+  /**
+   * @brief Set the Learning Rate Scheduler object
+   *
+   * @param lrs the learning rate scheduler object
+   */
+  void setLearningRateScheduler(
+    std::unique_ptr<nntrainer::LearningRateScheduler> &&lrs);
+
+  /**
+   * @brief Get the Learning Rate Scheduler object
+   *
+   * @return the learning rate scheduler object
+   */
+  nntrainer::LearningRateScheduler *setLearningRateScheduler();
+
+private:
+  std::unique_ptr<OptimizerCore> optimizer; /**< the underlying optimizer */
+  std::unique_ptr<nntrainer::LearningRateScheduler>
+    lr_sched; /**< the underlying learning rate scheduler */
+
+  std::tuple<props::LearningRate, props::DecayRate, props::DecaySteps>
+    props; /**< lr scheduler props for backward compatibility */
+};
+
+/**
+ * @brief Optimizer wrapped creator with constructor for optimizer
+ *
+ * @params[in] type Type of the optimizer to be constructed
+ * @params[in] properties Properties of the optimizer
+ */
+std::unique_ptr<OptimizerWrapped>
+createOptimizerWrapped(const ml::train::OptimizerType &type,
+                       const std::vector<std::string> &properties = {});
+
+/**
+ * @brief Optimizer wrapped creator with constructor for optimizer
+ *
+ * @params[in] type Type of the optimizer to be constructed
+ * @params[in] properties Properties of the optimizer
+ */
+std::unique_ptr<OptimizerWrapped>
+createOptimizerWrapped(const std::string &type,
+                       const std::vector<std::string> &properties = {});
+
+/**
+ * @brief Optimizer wrapped creator with constructor for optimizer
+ *
+ * @params[in] type Type of the optimizer to be constructed
+ * @params[in] properties Properties of the optimizer
+ */
+std::unique_ptr<OptimizerWrapped>
+createOptimizerWrapped(std::unique_ptr<OptimizerCore> &&opt,
+                       const std::vector<std::string> &properties = {});
+
+} // namespace nntrainer
+
+#endif // __cpluscplus
+#endif // __OPTIMIZER_WRAPPER_H__