[optimizer] Refactor optimizer
author Parichay Kapoor <pk.kapoor@samsung.com>
Thu, 8 Oct 2020 06:42:32 +0000 (15:42 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 13 Oct 2020 03:16:18 +0000 (12:16 +0900)
This patch refactors the optimizer in the following fashion:
- split the optimizer implementations of the different types into derived classes Adam and SGD
- create an optimizer_factory that builds optimizer objects from their type (a usage sketch follows below)
  - this can be used directly with ccapi
- remove the OptParam struct
- break apply_gradients down into smaller, per-optimizer methods
- update the associated unittests
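
A minimal usage sketch of the new factory path (illustrative only, not part of
the diff below; it assumes <optimizer_factory.h> is on the include path and the
main() wrapper is just for demonstration):

    #include <memory>
    #include <string>
    #include <vector>

    #include <optimizer_factory.h>

    int main() {
      // createOptimizer() replaces the removed Optimizer::setType()/setOptParam();
      // it throws std::invalid_argument for OptType::unknown.
      std::shared_ptr<nntrainer::Optimizer> opt =
        nntrainer::createOptimizer(nntrainer::OptType::adam);

      // Hyper-parameters that used to live in OptParam are now set as properties.
      // Adam-specific keys (beta1/beta2/epsilon) are handled by Adam::setProperty(),
      // common keys (learning_rate, decay_rate, ...) by the Optimizer base class.
      int status = opt->setProperty(
        {"learning_rate=0.001", "beta1=0.9", "beta2=0.999", "epsilon=1e-7"});

      return status;
    }

During training, the base Optimizer::apply_gradients() keeps the loop over the
trainable weights and the learning-rate decay, and dispatches the per-weight
update to the pure-virtual apply_gradient() that SGD and Adam override.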

**Self evaluation:**
1. Build test: [x]Passed [ ]Failed [ ]Skipped
2. Run test: [x]Passed [ ]Failed [ ]Skipped

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
33 files changed:
api/capi/src/nntrainer.cpp
jni/Android.mk
nntrainer/include/activation_layer.h
nntrainer/include/adam.h [new file with mode: 0644]
nntrainer/include/addition_layer.h
nntrainer/include/flatten_layer.h
nntrainer/include/input_layer.h
nntrainer/include/layer.h
nntrainer/include/neuralnet.h
nntrainer/include/optimizer.h
nntrainer/include/optimizer_factory.h [new file with mode: 0644]
nntrainer/include/sgd.h [new file with mode: 0644]
nntrainer/include/weight.h
nntrainer/meson.build
nntrainer/src/activation_layer.cpp
nntrainer/src/adam.cpp [new file with mode: 0644]
nntrainer/src/addition_layer.cpp
nntrainer/src/bn_layer.cpp
nntrainer/src/conv2d_layer.cpp
nntrainer/src/fc_layer.cpp
nntrainer/src/flatten_layer.cpp
nntrainer/src/input_layer.cpp
nntrainer/src/layer.cpp
nntrainer/src/loss_layer.cpp
nntrainer/src/model_loader.cpp
nntrainer/src/neuralnet.cpp
nntrainer/src/optimizer.cpp
nntrainer/src/optimizer_factory.cpp [new file with mode: 0644]
nntrainer/src/pooling2d_layer.cpp
nntrainer/src/sgd.cpp [new file with mode: 0644]
packaging/nntrainer.spec
test/unittest/unittest_nntrainer_internal.cpp
test/unittest/unittest_nntrainer_layers.cpp

index 15a5770..05d0287 100644 (file)
@@ -28,6 +28,7 @@
 #include <nntrainer_error.h>
 #include <nntrainer_internal.h>
 #include <nntrainer_log.h>
+#include <optimizer_factory.h>
 #include <parse_util.h>
 #include <sstream>
 #include <stdarg.h>
@@ -610,26 +611,19 @@ int ml_train_optimizer_create(ml_train_optimizer_h *optimizer,
 
   ml_train_optimizer *nnopt = new ml_train_optimizer;
   nnopt->magic = ML_NNTRAINER_MAGIC;
-
-  status =
-    exception_bounded_make_shared<nntrainer::Optimizer>(nnopt->optimizer);
-  if (status != ML_ERROR_NONE) {
-    delete nnopt;
-    ml_loge("creating optimizer failed");
-    return status;
-  }
-
   nnopt->in_use = false;
 
-  *optimizer = nnopt;
-
   returnable f = [&]() {
-    return nnopt->optimizer->setType(ml_optimizer_to_nntrainer_type(type));
+    nnopt->optimizer = createOptimizer(ml_optimizer_to_nntrainer_type(type));
+    return ML_ERROR_NONE;
   };
-  status = nntrainer_exception_boundary(f);
 
+  status = nntrainer_exception_boundary(f);
   if (status != ML_ERROR_NONE) {
     delete nnopt;
+    ml_loge("creating optimizer failed");
+  } else {
+    *optimizer = nnopt;
   }
 
   return status;
index 5923411..ae424d8 100644 (file)
@@ -42,7 +42,10 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/src/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/src/model_loader.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/src/addition_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/src/blas_interface.cpp \
-                  $(NNTRAINER_ROOT)/nntrainer/src/weight.cpp
+                  $(NNTRAINER_ROOT)/nntrainer/src/weight.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/src/adam.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/src/sgd.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/src/optimizer_factory.cpp
 
 NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer/include \
                       $(NNTRAINER_ROOT)/api \
index e9df137..26118ea 100644 (file)
@@ -68,12 +68,6 @@ public:
   sharedConstTensor backwarding(sharedConstTensor in, int iteration);
 
   /**
-   * @brief     copy layer
-   * @param[in] l layer to copy
-   */
-  void copy(std::shared_ptr<Layer> l);
-
-  /**
    * @brief setActivation by preset ActivationType
    *
    * @param[in] type ActivationType to be set
diff --git a/nntrainer/include/adam.h b/nntrainer/include/adam.h
new file mode 100644 (file)
index 0000000..32c55de
--- /dev/null
@@ -0,0 +1,107 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file       adam.h
+ * @date       6 October 2020
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Jijoong Moon <jijoong.moon@samsung.com>
+ * @author     Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug                No known bugs except for NYI items
+ * @brief      This is the Adam optimizer.
+ */
+#ifndef __ADAM_H__
+#define __ADAM_H__
+#ifdef __cplusplus
+
+#include <optimizer.h>
+
+namespace nntrainer {
+
+/**
+ * @class   Adam optimizer class
+ * @brief   Adam optimizer
+ */
+class Adam : public Optimizer {
+public:
+  /**
+   * @brief     Constructor of the Adam optimizer class
+   */
+  template <typename... Args>
+  Adam(float lr = 0.001f, double b1 = 0.9f, double b2 = 0.999f,
+       double ep = 1.0e-7f, Args... args) :
+    Optimizer(OptType::adam, lr, args...),
+    beta1(b1),
+    beta2(b2),
+    epsilon(ep) {}
+
+  /**
+   * @copydoc apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+   * int iteration)
+   */
+  void apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+                      int iteration);
+
+  /**
+   * @brief     get the base name for the optimizer
+   * @retval    base name of the optimizer
+   */
+  std::string getBaseName() { return "Adam"; };
+
+  /**
+   * @copydoc   getLearningRate(int iteration)
+   */
+  double getLearningRate(int iteration);
+
+  /**
+   * @copydoc setProperty(const PropertyType type,
+                           const std::string &value = "")
+   */
+  void setProperty(const PropertyType type, const std::string &value = "");
+
+  /**
+   * @copydoc Optimizer::initialize(std::shared_ptr<Weight> params, unsigned int
+   num_weights, bool setTensor)
+   */
+  int initialize(std::shared_ptr<Weight> params, unsigned int num_weights,
+                 bool setTensor);
+
+  /**
+   * @copydoc read(std::ifstream &file)
+   */
+  void read(std::ifstream &file);
+
+  /**
+   * @copydoc save(std::ofstream &file)
+   */
+  void save(std::ofstream &file);
+
+  /**
+   * @brief get beta1
+   */
+  double getBeta1() { return beta1; };
+
+  /**
+   * @brief get beta2
+   */
+  double getBeta2() { return beta2; };
+
+  /**
+   * @brief get epsilon
+   */
+  double getEpsilon() { return epsilon; }
+
+private:
+  /**
+   * @brief Internal Tensors for adam Optimizer
+   */
+  std::vector<std::pair<Tensor, Tensor>> weight_mv;
+
+  double beta1;   /** momentum for grad */
+  double beta2;   /** momentum for grad**2 */
+  double epsilon; /** epsilon to protect overflow */
+};
+} /* namespace nntrainer */
+
+#endif /* __cplusplus */
+#endif /* __ADAM_H__ */
index 356efd4..e5407ce 100644 (file)
@@ -82,12 +82,6 @@ public:
   sharedConstTensor backwarding(sharedConstTensor in, int iteration);
 
   /**
-   * @brief     copy layer
-   * @param[in] l layer to copy
-   */
-  void copy(std::shared_ptr<Layer> l);
-
-  /**
    * @brief     get the base name for the layer
    * @retval    base name of the layer
    */
index 0c085c8..54501c2 100644 (file)
@@ -78,12 +78,6 @@ public:
   sharedConstTensor backwarding(sharedConstTensor in, int iteration);
 
   /**
-   * @brief     copy layer
-   * @param[in] l layer to copy
-   */
-  void copy(std::shared_ptr<Layer> l);
-
-  /**
    * @brief     get the base name for the layer
    * @retval    base name of the layer
    */
index 1693357..ac1ae2d 100644 (file)
@@ -87,12 +87,6 @@ public:
   int initialize();
 
   /**
-   * @brief     Copy Layer
-   * @param[in] l layer to copy
-   */
-  void copy(std::shared_ptr<Layer> l);
-
-  /**
    * @brief     get the base name for the layer
    * @retval    base name of the layer
    */
index 2fb6cee..f3caaf1 100644 (file)
@@ -230,7 +230,7 @@ public:
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  int setOptimizer(Optimizer &opt);
+  int setOptimizer(std::shared_ptr<Optimizer> opt);
 
   /**
    * @brief     Activation Type Getter
@@ -394,7 +394,8 @@ protected:
   /**
    * @brief     Optimizer for this layer
    */
-  Optimizer opt;
+  // TODO: fix with #630
+  std::shared_ptr<Optimizer> opt;
 
   /**
    * @brief     Layer type
index 16b5db7..eb0c894 100644 (file)
@@ -115,7 +115,7 @@ public:
    * @brief     Get Learning rate
    * @retval    Learning rate
    */
-  float getLearningRate() { return opt.getLearningRate(); };
+  float getLearningRate() { return opt->getLearningRate(); };
 
   /**
    * @brief     Create and load the Network with ini configuration file.
@@ -305,8 +305,8 @@ private:
 
   std::string save_path; /**< Model path to save / read */
 
-  Optimizer opt; /**< Optimizer, This gets copied into each layer, do not use
-                    this directly */
+  std::shared_ptr<Optimizer> opt; /**< Optimizer; this gets copied into each
+                    layer, do not use this directly */
 
   NetType net_type; /**< Network Type */
 
index 0bb3ec0..f85e27b 100644 (file)
@@ -37,45 +37,29 @@ namespace nntrainer {
  */
 enum class OptType { sgd = 0, adam = 1, unknown = 2 };
 
-/**
- * @brief     type for the Optimizor to save hyper-parameter
- */
-typedef struct _OptParam {
-  float learning_rate;
-  double beta1;
-  double beta2;
-  double epsilon;
-  float decay_rate;
-  float decay_steps;
-  bool continue_train; /** Continue training with previous tensors for adam */
+class Optimizer {
 
-  _OptParam(OptType type = OptType::adam) :
-    learning_rate(0.001f),
-    beta1(0.9f),
-    beta2(0.999f),
-    epsilon(1.0e-7f),
-    decay_rate(1.0f),
-    decay_steps(-1.0f),
-    continue_train(false) {
-    if (type == OptType::sgd) {
-      learning_rate = 0.01f;
-    }
-  }
-} OptParam;
+  /** Allow layer to initialize optimizer with itself */
+  friend class Layer;
 
-class Optimizer {
 public:
   /**
-   * @brief     Constructor of Optimizer Class
+   * @brief     Constructor of Optimizer Class
    */
-  Optimizer() : type(OptType::unknown), popt() {}
-
-  Optimizer(const OptType type, OptParam popt);
+  Optimizer(const OptType t, float lr, float decay_rate = 1.0f,
+            float decay_steps = -1.0f, bool continue_train = false) :
+    type(t),
+    learning_rate(lr),
+    decay_rate(decay_rate),
+    decay_steps(decay_steps),
+    continue_train(continue_train) {
+    checkValidation();
+  }
 
   /**
    * @brief     Destructor of Optimizer Class
    */
-  ~Optimizer() {}
+  virtual ~Optimizer() {}
 
   /**
    * @brief  copy constructor
@@ -102,14 +86,6 @@ public:
   Optimizer &operator=(Optimizer &&rhs) = default;
 
   /**
-   * @brief     set Optimizer Type
-   * @param[in] t Optimizer type
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   */
-  int setType(OptType t);
-
-  /**
    * @brief     get Optimizer Type
    * @retval    Optimizer type
    */
@@ -119,27 +95,19 @@ public:
    * @brief     get Learning Rate
    * @retval    Learning rate
    */
-  float getLearningRate() { return popt.learning_rate; };
+  float getLearningRate() { return learning_rate; };
 
   /**
    * @brief     get Decay Rate for learning rate decay
    * @retval    decay rate
    */
-  float getDecayRate() { return popt.decay_rate; };
+  float getDecayRate() { return decay_rate; };
 
   /**
    * @brief     get Decay Steps for learning rate decay
    * @retval    decay steps
    */
-  float getDecaySteps() { return popt.decay_steps; };
-
-  /**
-   * @brief     set Optimizer Parameters
-   * @param[in] p Optimizer Parameter : OptParam
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   */
-  int setOptParam(OptParam p);
+  float getDecaySteps() { return decay_steps; };
 
   /**
    * @brief     set Optimizer Parameters
@@ -150,25 +118,6 @@ public:
   int setProperty(std::vector<std::string> values);
 
   /**
-   * @brief     get Optimizer Parameters
-   * @retval OptParam
-   */
-  OptParam getOptParam() { return popt; };
-
-  /**
-   * @brief     initialize optimizer. Initialize Weight if it is adam
-   * @param[in] params Weight list
-   * @param[in] num_weights size of the array
-   * @param[in] setTensor true if the layer need weight update.
-   *            Input Layer and Batch Normalization layer won't need it.
-   *            Therefore, it sets false.
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   */
-  int initialize(std::shared_ptr<Weight> params, unsigned int num_weights,
-                 bool setTensor);
-
-  /**
    * @brief     apply gradient to weight_list
    * @param[in] params Weight list
    * @param[in] num_weights size of the array
@@ -201,36 +150,82 @@ public:
   * @brief     Read training optimizer parameters from file
    * @param[in] file input stream file
    */
-  void read(std::ifstream &file);
+  virtual void read(std::ifstream &file);
 
   /**
   * @brief     Save training optimizer parameters to file
    * @param[in] file output stream file
    */
-  void save(std::ofstream &file);
+  virtual void save(std::ofstream &file);
 
   /**
-   * @brief     get the base name for the layer
-   * @retval    base name of the layer
+   * @brief setProperty by PropertyType
+   * @note By passing empty string, this can validate if @a type is valid
+   * @param[in] type property type to be passed
+   * @param[in] value value to be passed, if empty string is passed, do nothing
+   * but throws error when @a type is invalid
+   * @exception exception::not_supported     when property type is not valid for
+   * the particular optimizer
+   * @exception std::invalid_argument invalid argument
    */
-  std::string getBaseName() { return "Optimizer"; };
+  virtual void setProperty(const PropertyType type,
+                           const std::string &value = "");
+
+  /**
+   * @brief     get the base name for the optimizer
+   * @retval    base name of the optimizer
+   */
+  virtual std::string getBaseName() = 0;
 
-private:
+  /**
+   * @brief     validate the optimizer
+   */
+  virtual void checkValidation();
+
+protected:
   /**
    * @brief Optimizer Type
    */
   OptType type;
 
   /**
-   * @brief Optimizer Hyper Parmeters
+   * @brief     get Learning Rate for the given iteration
+   * @param[in] iteration Iteration for the learning rate
+   * @retval    Learning rate
    */
-  OptParam popt;
+  virtual double getLearningRate(int iteration);
 
+  float learning_rate; /** learning rate */
+  float decay_rate;    /** decay rate for learning rate */
+  float decay_steps;   /** decay steps for learning rate */
+  bool continue_train; /** Continue training with previous tensors for adam */
+
+private:
   /**
-   * @brief Internal Tensors for adam Optimizer
+   * @brief     initialize optimizer. Initialize Weight if it is adam
+   * @param[in] params Weight list
+   * @param[in] num_weights size of the array
+   * @param[in] setTensor true if the layer needs weight update.
+   *            Input Layer and Batch Normalization layer won't need it.
+   *            Therefore, it is set to false for them.
+   * @retval #ML_ERROR_NONE Successful.
+   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  std::vector<std::pair<Tensor, Tensor>> weight_mv;
+  virtual int initialize(std::shared_ptr<Weight> params,
+                         unsigned int num_weights, bool setTensor);
+
+  /**
+   * @brief     apply gradient to the given weight
+   * @param[in] weight Weight and gradient set to be updated
+   * @param[in] tensor_idx Idx of this tensor in the tensors list
+   * @param[in] updated_lr learning rate for this iteration after applying decay
+   * @param[in] iteration nth epoch number
+   * @note weight which is called upon can be assumed to be trainable
+   */
+  virtual void apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+                              int iteration) = 0;
 };
+
 } /* namespace nntrainer */
 
 #endif /* __cplusplus */
diff --git a/nntrainer/include/optimizer_factory.h b/nntrainer/include/optimizer_factory.h
new file mode 100644 (file)
index 0000000..f25abcf
--- /dev/null
@@ -0,0 +1,48 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file       optimizer_factory.h
+ * @date       7 October 2020
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug                No known bugs except for NYI items
+ * @brief      This is the optimizer factory.
+ */
+
+#ifndef __OPTIMIZER_FACTORY_H__
+#define __OPTIMIZER_FACTORY_H__
+#ifdef __cplusplus
+
+#include <adam.h>
+#include <optimizer.h>
+#include <sgd.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Factory creator with copy constructor
+ */
+std::unique_ptr<Optimizer> createOptimizer(OptType type, const Optimizer &opt);
+
+/**
+ * @brief Factory creator with constructor
+ */
+template <typename... Args>
+std::unique_ptr<Optimizer> createOptimizer(OptType type, Args... args) {
+  switch (type) {
+  case OptType::sgd:
+    return std::make_unique<SGD>(args...);
+  case OptType::adam:
+    return std::make_unique<Adam>(args...);
+  case OptType::unknown:
+    /** fallthrough intended */
+  default:
+    throw std::invalid_argument("Unknown type for the optimizer");
+  }
+}
+
+} // namespace nntrainer
+
+#endif // __cplusplus
+#endif // __OPTIMIZER_FACTORY_H__
diff --git a/nntrainer/include/sgd.h b/nntrainer/include/sgd.h
new file mode 100644 (file)
index 0000000..30834a8
--- /dev/null
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file       sgd.h
+ * @date       6 October 2020
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Jijoong Moon <jijoong.moon@samsung.com>
+ * @author     Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug                No known bugs except for NYI items
+ * @brief      This is the SGD optimizer.
+ */
+#ifndef __SGD_H__
+#define __SGD_H__
+#ifdef __cplusplus
+
+#include <optimizer.h>
+
+namespace nntrainer {
+
+/**
+ * @class   SGD optimizer class
+ * @brief   Stochastic Gradient Descent optimizer class
+ */
+class SGD : public Optimizer {
+public:
+  /**
+   * @brief     Constructor of the SGD optimizer class
+   */
+  template <typename... Args>
+  SGD(float lr = 0.0001f, Args... args) :
+    Optimizer(OptType::sgd, lr, args...) {}
+
+  /**
+   * @copydoc apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+   * int iteration)
+   */
+  void apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+                      int iteration);
+
+  /**
+   * @brief     get the base name for the optimizer
+   * @retval    base name of the optimizer
+   */
+  std::string getBaseName() { return "SGD"; };
+};
+} /* namespace nntrainer */
+
+#endif /* __cplusplus */
+#endif /* __SGD_H__ */
index af0714b..d62d364 100644 (file)
@@ -55,6 +55,8 @@ class Weight {
 
   /** Declare optimizer as friend to get variable/gradient reference */
   friend class Optimizer;
+  friend class SGD;
+  friend class Adam;
 
 public:
   /**
index 9366592..b656cb1 100644 (file)
@@ -43,8 +43,11 @@ nntrainer_sources = [
   'src/neuralnet.cpp',
   'src/nntrainer_logger.cpp',
   'src/optimizer.cpp',
+  'src/optimizer_factory.cpp',
   'src/parse_util.cpp',
   'src/pooling2d_layer.cpp',
+  'src/sgd.cpp',
+  'src/adam.cpp',
   'src/tensor.cpp',
   'src/tensor_dim.cpp',
   'src/util_func.cpp',
@@ -73,10 +76,13 @@ nntrainer_headers = [
   'include/optimizer.h',
   'include/parse_util.h',
   'include/pooling2d_layer.h',
+  'include/sgd.h',
+  'include/adam.h',
   'include/tensor.h',
   'include/tensor_dim.h',
   'include/util_func.h',
   'include/weight.h',
+  'include/optimizer_factory.h',
   '../api/nntrainer-api-common.h'
 ]
 
index 2604d1d..f0c798d 100644 (file)
@@ -72,18 +72,6 @@ sharedConstTensor ActivationLayer::backwarding(sharedConstTensor derivative,
   return MAKE_SHARED_TENSOR(std::move(ret));
 }
 
-/**
- * @brief     copy layer
- * @param[in] l layer to copy
- */
-void ActivationLayer::copy(std::shared_ptr<Layer> l) {
-  std::shared_ptr<ActivationLayer> from =
-    std::static_pointer_cast<ActivationLayer>(l);
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-  this->activation_type = from->activation_type;
-};
-
 int ActivationLayer::setActivation(
   std::function<Tensor(Tensor const &)> const &activation_fn,
   std::function<Tensor(Tensor const &, Tensor const &)> const
diff --git a/nntrainer/src/adam.cpp b/nntrainer/src/adam.cpp
new file mode 100644 (file)
index 0000000..a95252e
--- /dev/null
@@ -0,0 +1,155 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file       adam.cpp
+ * @date       6 October 2020
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Jijoong Moon <jijoong.moon@samsung.com>
+ * @author     Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug                No known bugs except for NYI items
+ * @brief      This is the Adam optimizer.
+ */
+
+#include <cmath>
+#include <fstream>
+
+#include <adam.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <parse_util.h>
+#include <util_func.h>
+
+namespace nntrainer {
+
+int Adam::initialize(std::shared_ptr<Weight> weight_list,
+                     unsigned int num_weights, bool set_tensor) {
+  int status = ML_ERROR_NONE;
+  weight_mv.clear();
+
+  if (set_tensor) {
+    for (unsigned int i = 0; i < num_weights; ++i) {
+      Weight &w = weight_list.get()[i];
+
+      // TODO: only trainable weights must be sent to optimizer
+      if (!w.getTrainable())
+        continue;
+
+      Tensor m = Tensor(w.getDim());
+      m.setZero();
+      Tensor v = Tensor(w.getDim());
+      v.setZero();
+      std::pair<Tensor, Tensor> p =
+        std::pair<Tensor, Tensor>(std::move(m), std::move(v));
+      weight_mv.push_back(std::move(p));
+    }
+  }
+  return status;
+}
+
+double Adam::getLearningRate(int iteration) {
+  double ll = Optimizer::getLearningRate(iteration);
+
+  std::function<float(double)> biasCorrection = [&](float f) {
+    return 1.0f - pow(f, iteration + 1);
+  };
+
+  ll *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
+
+  return ll;
+}
+
+void Adam::apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+                          int iteration) {
+
+  Tensor &x = weight.getVariableRef();
+  const Tensor &x_grad = weight.getGradientRef();
+
+  // This is implementation of adam from original paper.
+  // This is not deleted intentionally.
+  // float biasCorrection1 = 1 - pow(beta1, iteration + 1);
+  // float biasCorrection2 = 1 - pow(beta2, iteration + 1);
+  // Tensor &wm = weight_mv[idx].first;
+  // Tensor &wv = weight_mv[idx].second;
+
+  // wm.multiply_i(beta1);
+  // wm.add_i(x_grad, 1.0f - beta1);
+
+  // wv.multiply_i(beta2);
+  // wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
+
+  // Tensor denom = wv.apply(sqrtFloat)
+  //                  .divide(sqrtFloat(biasCorrection2))
+  //                  .add(epsilon);
+  // x.add_i(wm.divide(denom), -ll / biasCorrection1);
+
+  std::function<double(double)> sqrtEps = [&](double f) {
+    return sqrtDouble(f) + this->epsilon;
+  };
+
+  Tensor &wm = weight_mv[tensor_idx].first;
+  Tensor &wv = weight_mv[tensor_idx].second;
+
+  wm.multiply_i(beta1);
+  wm.add_i(x_grad, 1.0f - beta1);
+
+  wv.multiply_i(beta2);
+  wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
+
+  x.add_i(wm.divide(wv.apply(sqrtEps)), -updated_lr);
+}
+
+void Adam::setProperty(const PropertyType type, const std::string &value) {
+  int status = ML_ERROR_NONE;
+
+  switch (type) {
+  case PropertyType::beta1:
+    status = setDouble(beta1, value);
+    break;
+  case PropertyType::beta2:
+    status = setDouble(beta2, value);
+    break;
+  case PropertyType::epsilon:
+    status = setDouble(epsilon, value);
+    break;
+  default:
+    Optimizer::setProperty(type, value);
+    status = ML_ERROR_NONE;
+    break;
+  }
+
+  throw_status(status);
+}
+
+void Adam::read(std::ifstream &file) {
+  OptType loaded_type;
+  file.read((char *)&loaded_type, sizeof(OptType));
+
+  if (loaded_type == type) {
+    if (continue_train) {
+      for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) {
+        (*iter).first.read(file);
+        (*iter).second.read(file);
+      }
+    } else {
+      size_t total_size = 0;
+      for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++)
+        total_size += (*iter).first.getSize() + (*iter).second.getSize();
+
+      file.seekg(total_size, std::ifstream::cur);
+    }
+  } else {
+    ml_logw("Not loading saved optimizer parameters due to mismatched type");
+  }
+}
+
+void Adam::save(std::ofstream &file) {
+  Optimizer::save(file);
+
+  for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) {
+    (*iter).first.save(file);
+    (*iter).second.save(file);
+  }
+}
+
+} // namespace nntrainer
index bb73ea8..ea9de6d 100644 (file)
@@ -83,13 +83,4 @@ void AdditionLayer::setProperty(const PropertyType type,
   }
 }
 
-void AdditionLayer::copy(std::shared_ptr<Layer> l) {
-  std::shared_ptr<AdditionLayer> from =
-    std::static_pointer_cast<AdditionLayer>(l);
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
-}
-
 } /* namespace nntrainer */
index 828a3cb..3c65840 100644 (file)
@@ -176,7 +176,7 @@ BatchNormalizationLayer::backwarding(sharedConstTensor derivative,
   Tensor dx = dx_2.multiply(dx_1);
   dx.divide_i(N);
 
-  opt.apply_gradients(weight_list, num_weights, iteration);
+  opt->apply_gradients(weight_list, num_weights, iteration);
 
   return MAKE_SHARED_TENSOR(std::move(dx));
 }
@@ -186,11 +186,6 @@ void BatchNormalizationLayer::copy(std::shared_ptr<Layer> l) {
 
   std::shared_ptr<BatchNormalizationLayer> from =
     std::static_pointer_cast<BatchNormalizationLayer>(l);
-  this->opt = from->opt;
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
   this->cvar.copy(from->cvar);
 }
 
index 6cecc33..17808f6 100644 (file)
@@ -348,7 +348,7 @@ sharedConstTensor Conv2DLayer::backwarding(sharedConstTensor derivative,
       }
     }
 
-    opt.apply_gradients(weight_list, num_weights, iteration);
+    opt->apply_gradients(weight_list, num_weights, iteration);
   }
 
   return MAKE_SHARED_TENSOR(std::move(strip_pad(ret, padding)));
@@ -356,6 +356,7 @@ sharedConstTensor Conv2DLayer::backwarding(sharedConstTensor derivative,
 
 void Conv2DLayer::copy(std::shared_ptr<Layer> l) {
   Layer::copy(l);
+
   std::shared_ptr<Conv2DLayer> from = std::static_pointer_cast<Conv2DLayer>(l);
   this->filter_size = from->filter_size;
   for (unsigned int i = 0; i < CONV2D_DIM; ++i) {
@@ -363,11 +364,6 @@ void Conv2DLayer::copy(std::shared_ptr<Layer> l) {
     this->stride[i] = from->stride[i];
     this->padding[i] = from->padding[i];
   }
-
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
 }
 
 int Conv2DLayer::setSize(int *size, PropertyType type) {
index 7db4f7a..ff9d451 100644 (file)
@@ -88,12 +88,14 @@ sharedConstTensor FullyConnectedLayer::forwarding(sharedConstTensor in) {
 
 void FullyConnectedLayer::read(std::ifstream &file) {
   Layer::read(file);
-  opt.read(file);
+  if (opt)
+    opt->read(file);
 }
 
 void FullyConnectedLayer::save(std::ofstream &file) {
   Layer::save(file);
-  opt.save(file);
+  if (opt)
+    opt->save(file);
 }
 
 void FullyConnectedLayer::copy(std::shared_ptr<Layer> l) {
@@ -101,13 +103,7 @@ void FullyConnectedLayer::copy(std::shared_ptr<Layer> l) {
 
   std::shared_ptr<FullyConnectedLayer> from =
     std::static_pointer_cast<FullyConnectedLayer>(l);
-  this->opt = from->opt;
   this->unit = from->unit;
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-  this->loss = from->loss;
 }
 
 sharedConstTensor FullyConnectedLayer::backwarding(sharedConstTensor derivative,
@@ -127,7 +123,7 @@ sharedConstTensor FullyConnectedLayer::backwarding(sharedConstTensor derivative,
   djdw = djdw.sum(0);
 
   if (trainable) {
-    opt.apply_gradients(weight_list, num_weights, iteration);
+    opt->apply_gradients(weight_list, num_weights, iteration);
   }
 
   return MAKE_SHARED_TENSOR(std::move(ret));
index 66dc2e1..7808db0 100644 (file)
@@ -51,13 +51,4 @@ sharedConstTensor FlattenLayer::backwarding(sharedConstTensor in,
   return MAKE_SHARED_TENSOR(std::move(temp));
 }
 
-void FlattenLayer::copy(std::shared_ptr<Layer> l) {
-  std::shared_ptr<FlattenLayer> from =
-    std::static_pointer_cast<FlattenLayer>(l);
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
-}
-
 } /* namespace nntrainer */
index 89a1843..d1cfec7 100644 (file)
@@ -51,15 +51,6 @@ void InputLayer::setProperty(const PropertyType type,
   }
 }
 
-void InputLayer::copy(std::shared_ptr<Layer> l) {
-  std::shared_ptr<InputLayer> from = std::static_pointer_cast<InputLayer>(l);
-  this->opt = from->opt;
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-}
-
 sharedConstTensor InputLayer::forwarding(sharedConstTensor in) {
   input = *in;
 
index 5c20bba..daccd27 100644 (file)
@@ -24,6 +24,7 @@
 #include <layer.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
+#include <optimizer_factory.h>
 #include <parse_util.h>
 #include <util_func.h>
 
@@ -40,11 +41,9 @@ int Layer::setActivation(ActivationType acti) {
   return status;
 }
 
-int Layer::setOptimizer(Optimizer &opt) {
-  this->opt.setType(opt.getType());
-  this->opt.setOptParam(opt.getOptParam());
-
-  return this->opt.initialize(weight_list, num_weights, true);
+int Layer::setOptimizer(std::shared_ptr<Optimizer> opt) {
+  this->opt = createOptimizer(opt->getType(), *opt.get());
+  return this->opt->initialize(weight_list, num_weights, true);
 }
 
 int Layer::checkValidation() {
@@ -72,6 +71,23 @@ void Layer::copy(std::shared_ptr<Layer> l) {
   for (unsigned int i = 0; i < num_weights; ++i) {
     weightAt(i) = l->weightAt(i);
   }
+
+  // TODO: fix this #630
+  this->opt = l->opt;
+  this->input_dim = l->input_dim;
+  this->output_dim = l->output_dim;
+  this->input.copy(l->input);
+  this->hidden.copy(l->hidden);
+  this->activation_type = l->activation_type;
+  this->loss = l->loss;
+  this->type = l->type;
+  this->weight_regularizer = l->weight_regularizer;
+  this->weight_regularizer_constant = l->weight_regularizer_constant;
+  this->weight_initializer = l->weight_initializer;
+  this->flatten = l->flatten;
+  this->trainable = l->trainable;
+  this->num_inputs = l->num_inputs;
+  this->num_outputs = l->num_outputs;
 }
 
 void Layer::read(std::ifstream &file) {
index 3cb05a0..692a55c 100644 (file)
@@ -126,10 +126,10 @@ void LossLayer::updateLoss(const Tensor &l) {
 }
 
 void LossLayer::copy(std::shared_ptr<Layer> l) {
+  Layer::copy(l);
+
   std::shared_ptr<LossLayer> from = std::static_pointer_cast<LossLayer>(l);
-  this->input.copy(from->input);
   this->loss_type = from->loss_type;
-  this->loss = from->loss;
 }
 
 sharedConstTensor LossLayer::backwarding(sharedConstTensor derivative,
index 4014333..36c3b21 100644 (file)
@@ -18,6 +18,7 @@
 #include <neuralnet.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
+#include <optimizer_factory.h>
 #include <parse_util.h>
 #include <sstream>
 #include <util_func.h>
@@ -54,22 +55,53 @@ int ModelLoader::loadModelConfigIni(dictionary *ini, NeuralNetwork &model) {
     iniparser_getint(ini, "Model:Batch_Size", model.batch_size);
 
   /** Default to adam optimizer */
-  status = model.opt.setType((OptType)parseType(
-    iniparser_getstring(ini, "Model:Optimizer", "adam"), TOKEN_OPT));
-  NN_RETURN_STATUS();
+  OptType opt_type = (OptType)parseType(
+    iniparser_getstring(ini, "Model:Optimizer", "adam"), TOKEN_OPT);
+
+  try {
+    model.opt = createOptimizer(opt_type);
+  } catch (std::exception &e) {
+    ml_loge("%s %s", typeid(e).name(), e.what());
+    return ML_ERROR_INVALID_PARAMETER;
+  } catch (...) {
+    ml_loge("Creating the optimizer failed");
+    return ML_ERROR_INVALID_PARAMETER;
+  }
+
+  std::vector<std::string> optimizer_prop = {};
+  optimizer_prop.push_back(
+    {"learning_rate=" +
+     std::string(iniparser_getstring(
+       ini, "Model:Learning_rate",
+       std::to_string(model.opt->getLearningRate()).c_str()))});
+
+  optimizer_prop.push_back(
+    {"decay_steps=" + std::string(iniparser_getstring(
+                        ini, "Model:Decay_steps",
+                        std::to_string(model.opt->getDecaySteps()).c_str()))});
+  optimizer_prop.push_back(
+    {"decay_rate=" + std::string(iniparser_getstring(
+                       ini, "Model:Decay_rate",
+                       std::to_string(model.opt->getDecayRate()).c_str()))});
+
+  if (model.opt->getType() == OptType::adam) {
+    std::shared_ptr<Adam> opt_adam = std::static_pointer_cast<Adam>(model.opt);
+
+    optimizer_prop.push_back(
+      {"beta1=" +
+       std::string(iniparser_getstring(
+         ini, "Model:Beta1", std::to_string(opt_adam->getBeta1()).c_str()))});
+    optimizer_prop.push_back(
+      {"beta2=" +
+       std::string(iniparser_getstring(
+         ini, "Model:Beta2", std::to_string(opt_adam->getBeta2()).c_str()))});
+    optimizer_prop.push_back(
+      {"epsilon=" + std::string(iniparser_getstring(
+                      ini, "Model:Epsilon",
+                      std::to_string(opt_adam->getEpsilon()).c_str()))});
+  }
 
-  OptParam popt(model.opt.getType());
-  popt.learning_rate =
-    iniparser_getdouble(ini, "Model:Learning_rate", popt.learning_rate);
-  popt.decay_steps =
-    iniparser_getint(ini, "Model:Decay_steps", popt.decay_steps);
-  popt.decay_rate =
-    iniparser_getdouble(ini, "Model:Decay_rate", popt.decay_rate);
-  popt.beta1 = iniparser_getdouble(ini, "Model:beta1", popt.beta1);
-  popt.beta2 = iniparser_getdouble(ini, "Model:beta2", popt.beta2);
-  popt.epsilon = iniparser_getdouble(ini, "Model:epsilon", popt.epsilon);
-
-  status = model.opt.setOptParam(popt);
+  status = model.opt->setProperty(optimizer_prop);
   NN_RETURN_STATUS();
 
   return status;
index 3eb1f04..1a4a27a 100644 (file)
@@ -162,7 +162,7 @@ int NeuralNetwork::setTrainConfig(std::vector<std::string> values) {
       status = setBoolean(cont_train, value);
       NN_RETURN_STATUS();
       continue_train = cont_train;
-      opt.setProperty({values[i]});
+      opt->setProperty({values[i]});
     } break;
     case PropertyType::batch_size: {
       status = setUint(batch_size, value);
@@ -597,7 +597,7 @@ int NeuralNetwork::setOptimizer(std::shared_ptr<Optimizer> optimizer) {
     return ML_ERROR_NOT_SUPPORTED;
   }
 
-  opt = *optimizer.get();
+  opt = optimizer;
 
   return ML_ERROR_NONE;
 }
index 4fe43d7..8065f97 100644 (file)
 
 namespace nntrainer {
 
-Optimizer::Optimizer(const OptType t, const OptParam p) {
-  type = t;
-  popt = p;
+int Optimizer::initialize(std::shared_ptr<Weight> weight_list,
+                          unsigned int num_weights, bool set_tensor) {
+  return ML_ERROR_NONE;
 }
 
-int Optimizer::setType(OptType t) {
-  int status = ML_ERROR_NONE;
-  if (t == OptType::unknown) {
-    ml_loge("Error: Optimizer is unknown");
-    return ML_ERROR_INVALID_PARAMETER;
-  }
-  type = t;
-  return status;
-}
+double Optimizer::getLearningRate(int iteration) {
+  double ll = learning_rate;
 
-int Optimizer::setOptParam(OptParam p) {
-  int status = ML_ERROR_NONE;
-  if (p.learning_rate <= 0) {
-    ml_loge("Error: learning_rate should be grater than 0 (%f)",
-            p.learning_rate);
-    return ML_ERROR_INVALID_PARAMETER;
+  if (decay_steps != -1) {
+    ll = ll * pow(decay_rate, (iteration / decay_steps));
   }
 
-  popt = p;
-  return status;
-}
-
-int Optimizer::initialize(std::shared_ptr<Weight> weight_list,
-                          unsigned int num_weights, bool set_tensor) {
-  int status = ML_ERROR_NONE;
-
-  if (type == OptType::adam && set_tensor) {
-    for (unsigned int i = 0; i < num_weights; ++i) {
-      Weight &w = weight_list.get()[i];
-
-      // TODO: only trainable weights must be sent to optimizer
-      if (!w.getTrainable())
-        continue;
-
-      Tensor m = Tensor(w.getDim());
-      m.setZero();
-      Tensor v = Tensor(w.getDim());
-      v.setZero();
-      std::pair<Tensor, Tensor> p =
-        std::pair<Tensor, Tensor>(std::move(m), std::move(v));
-      weight_mv.push_back(std::move(p));
-    }
-  }
-  return status;
+  return ll;
 }
 
 void Optimizer::apply_gradients(std::shared_ptr<Weight> weight_list,
                                 unsigned int num_weights, int iteration) {
 
-  double ll = popt.learning_rate;
-
-  if (popt.decay_steps != -1) {
-    ll = ll * pow(popt.decay_rate, (iteration / popt.decay_steps));
-  }
-
-  if (type == OptType::adam) {
-    std::function<float(double)> biasCorrection = [&](float f) {
-      return 1.0f - pow(f, iteration + 1);
-    };
-
-    ll *= sqrt(biasCorrection(popt.beta2)) / biasCorrection(popt.beta1);
-  }
+  double ll = getLearningRate(iteration);
 
   int idx = 0;
   for (unsigned int i = 0; i < num_weights; ++i) {
@@ -109,55 +61,7 @@ void Optimizer::apply_gradients(std::shared_ptr<Weight> weight_list,
     if (!weight.getTrainable())
       continue;
 
-    Tensor &x = weight.getVariableRef();
-    const Tensor &x_grad = weight.getGradientRef();
-    switch (type) {
-    case OptType::sgd:
-      x.add_i(x_grad, -ll);
-      break;
-    case OptType::adam: {
-
-      // This is implementation of adam from original paper.
-      // This is not deleted intentionally.
-      // float biasCorrection1 = 1 - pow(popt.beta1, iteration + 1);
-      // float biasCorrection2 = 1 - pow(popt.beta2, iteration + 1);
-      // Tensor &wm = weight_mv[idx].first;
-      // Tensor &wv = weight_mv[idx].second;
-
-      // wm.multiply_i(popt.beta1);
-      // wm.add_i(x_grad, 1.0f - popt.beta1);
-
-      // wv.multiply_i(popt.beta2);
-      // wv.add_i(x_grad.multiply(x_grad), 1.0f - popt.beta2);
-
-      // Tensor denom = wv.apply(sqrtFloat)
-      //                  .divide(sqrtFloat(biasCorrection2))
-      //                  .add(popt.epsilon);
-      // x.add_i(wm.divide(denom), -ll / biasCorrection1);
-
-      std::function<double(double)> sqrtEps = [&](double f) {
-        return sqrtDouble(f) + this->popt.epsilon;
-      };
-
-      Tensor &wm = weight_mv[idx].first;
-      Tensor &wv = weight_mv[idx].second;
-
-      wm.multiply_i(popt.beta1);
-      wm.add_i(x_grad, 1.0f - popt.beta1);
-
-      wv.multiply_i(popt.beta2);
-      wv.add_i(x_grad.multiply(x_grad), 1.0f - popt.beta2);
-
-      x.add_i(wm.divide(wv.apply(sqrtEps)), -ll);
-
-      break;
-    }
-    case OptType::unknown:
-    default:
-      throw std::runtime_error("Unknown optimizer.");
-      break;
-    }
-
+    apply_gradient(weight, idx, ll, iteration);
     idx += 1;
   }
 }
@@ -168,75 +72,74 @@ int Optimizer::setProperty(std::vector<std::string> values) {
   for (unsigned int i = 0; i < values.size(); ++i) {
     std::string key;
     std::string value;
+
     status = getKeyValue(values[i], key, value);
+    NN_RETURN_STATUS();
 
-    unsigned int type = parseOptProperty(key.c_str());
-
-    switch (static_cast<PropertyType>(type)) {
-    case PropertyType::learning_rate:
-      status = setFloat(popt.learning_rate, value);
-      NN_RETURN_STATUS();
-      break;
-    case PropertyType::decay_steps:
-      status = setFloat(popt.decay_steps, value);
-      NN_RETURN_STATUS();
-      break;
-    case PropertyType::decay_rate:
-      status = setFloat(popt.decay_rate, value);
-      NN_RETURN_STATUS();
-      break;
-    case PropertyType::beta1:
-      status = setDouble(popt.beta1, value);
-      NN_RETURN_STATUS();
-      break;
-    case PropertyType::beta2:
-      status = setDouble(popt.beta2, value);
-      NN_RETURN_STATUS();
-      break;
-    case PropertyType::epsilon:
-      status = setDouble(popt.epsilon, value);
-      NN_RETURN_STATUS();
-      break;
-    case PropertyType::continue_train:
-      status = setBoolean(popt.continue_train, value);
-      NN_RETURN_STATUS();
-      break;
-    default:
-      ml_loge("Error: Unknown Optimizer Property Key");
-      status = ML_ERROR_INVALID_PARAMETER;
-      break;
+    unsigned int type = parseOptProperty(key);
+
+    if (value.empty()) {
+      return ML_ERROR_INVALID_PARAMETER;
+    }
+
+    try {
+      /// @note this calls derived setProperty if available
+      setProperty(static_cast<PropertyType>(type), value);
+    } catch (...) {
+      return ML_ERROR_INVALID_PARAMETER;
     }
   }
 
+  try {
+    checkValidation();
+  } catch (...) {
+    return ML_ERROR_INVALID_PARAMETER;
+  }
   return status;
 }
 
+void Optimizer::checkValidation() {
+  if (learning_rate <= 0.0f)
+    throw std::invalid_argument("Learning rate must be positive");
+}
+
+void Optimizer::setProperty(const PropertyType type, const std::string &value) {
+  int status = ML_ERROR_NONE;
+
+  switch (type) {
+  case PropertyType::learning_rate:
+    status = setFloat(learning_rate, value);
+    break;
+  case PropertyType::decay_steps:
+    status = setFloat(decay_steps, value);
+    break;
+  case PropertyType::decay_rate:
+    status = setFloat(decay_rate, value);
+    break;
+  case PropertyType::continue_train:
+    status = setBoolean(continue_train, value);
+    break;
+  default:
+    ml_loge("Error: Unknown Optimizer Property Key");
+    status = ML_ERROR_INVALID_PARAMETER;
+    break;
+  }
+
+  throw_status(status);
+}
+
 void Optimizer::read(std::ifstream &file) {
   OptType loaded_type;
   file.read((char *)&loaded_type, sizeof(OptType));
-  if (type == OptType::adam and loaded_type == type) {
-    if (popt.continue_train) {
-      for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) {
-        (*iter).first.read(file);
-        (*iter).second.read(file);
-      }
-    } else {
-      size_t total_size = 0;
-      for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++)
-        total_size += (*iter).first.getSize() + (*iter).second.getSize();
-
-      file.seekg(total_size, std::ifstream::cur);
-    }
-  }
+
+  if (loaded_type >= OptType::unknown)
+    throw std::runtime_error("Saved file has unknown optimizer");
 }
 
 void Optimizer::save(std::ofstream &file) {
+  if (type >= OptType::unknown)
+    throw std::runtime_error("Cannot save unknown optimizer");
+
   file.write((char *)&type, sizeof(OptType));
-  if (type == OptType::adam) {
-    for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) {
-      (*iter).first.save(file);
-      (*iter).second.save(file);
-    }
-  }
 }
 } // namespace nntrainer
diff --git a/nntrainer/src/optimizer_factory.cpp b/nntrainer/src/optimizer_factory.cpp
new file mode 100644 (file)
index 0000000..825b06d
--- /dev/null
@@ -0,0 +1,35 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file       optimizer_factory.cpp
+ * @date       7 October 2020
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug                No known bugs except for NYI items
+ * @brief      This is the optimizer factory.
+ */
+
+#include <adam.h>
+#include <optimizer.h>
+#include <sgd.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Factory creator with copy constructor
+ */
+std::unique_ptr<Optimizer> createOptimizer(OptType type, const Optimizer &opt) {
+  switch (type) {
+  case OptType::sgd:
+    return std::make_unique<SGD>(static_cast<const SGD &>(opt));
+  case OptType::adam:
+    return std::make_unique<Adam>(static_cast<const Adam &>(opt));
+  case OptType::unknown:
+    /** fallthrough intended */
+  default:
+    throw std::invalid_argument("Unknown type for the optimizer");
+  }
+}
+
+} // namespace nntrainer
index ba0078a..f2deb74 100644 (file)
@@ -179,6 +179,8 @@ void Pooling2DLayer::setBatch(unsigned int batch) {
 }
 
 void Pooling2DLayer::copy(std::shared_ptr<Layer> l) {
+  Layer::copy(l);
+
   std::shared_ptr<Pooling2DLayer> from =
     std::static_pointer_cast<Pooling2DLayer>(l);
 
@@ -189,11 +191,6 @@ void Pooling2DLayer::copy(std::shared_ptr<Layer> l) {
     this->stride[i] = from->stride[i];
     this->padding[i] = from->padding[i];
   }
-
-  this->input.copy(from->input);
-  this->hidden.copy(from->hidden);
-  this->input_dim = from->input_dim;
-  this->output_dim = from->output_dim;
 }
 
 void Pooling2DLayer::setProperty(const PropertyType type,
diff --git a/nntrainer/src/sgd.cpp b/nntrainer/src/sgd.cpp
new file mode 100644 (file)
index 0000000..27594e2
--- /dev/null
@@ -0,0 +1,25 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file       sgd.cpp
+ * @date       6 October 2020
+ * @see                https://github.com/nnstreamer/nntrainer
+ * @author     Jijoong Moon <jijoong.moon@samsung.com>
+ * @author     Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug                No known bugs except for NYI items
+ * @brief      This is the SGD optimizer.
+ */
+
+#include <sgd.h>
+
+namespace nntrainer {
+
+void SGD::apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
+                         int iteration) {
+  Tensor &x = weight.getVariableRef();
+  const Tensor &x_grad = weight.getGradientRef();
+  x.add_i(x_grad, -updated_lr);
+}
+
+} // namespace nntrainer
index 831284a..282c085 100644 (file)
@@ -338,6 +338,9 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/
 %{_includedir}/nntrainer/nntrainer-api-common.h
 %{_includedir}/nntrainer/blas_interface.h
 %{_includedir}/nntrainer/weight.h
+%{_includedir}/nntrainer/adam.h
+%{_includedir}/nntrainer/sgd.h
+%{_includedir}/nntrainer/optimizer_factory.h
 %{_libdir}/pkgconfig/nntrainer.pc
 
 %files devel-static
index 33b0141..4a02351 100644 (file)
  * @author      Jijoong Moon <jijoong.moon@samsung.com>
  * @bug         No known bugs
  */
-#include "databuffer_file.h"
-#include "databuffer_func.h"
-#include "neuralnet.h"
-#include "nntrainer_test_util.h"
-#include "util_func.h"
 #include <fstream>
+
+#include <databuffer_file.h>
+#include <databuffer_func.h>
+#include <neuralnet.h>
 #include <nntrainer_error.h>
+#include <optimizer_factory.h>
+#include <util_func.h>
+
+#include <nntrainer_test_util.h>
 
 /**
  * @brief Neural Network Model initialization
@@ -196,54 +199,28 @@ TEST(nntrainer_NeuralNetwork, init_03_p) {
 }
 
 /**
- * @brief Optimizer set type
+ * @brief Optimizer create
  */
-TEST(nntrainer_Optimizer, setType_01_p) {
-  int status = ML_ERROR_NONE;
-  nntrainer::Optimizer op;
-  nntrainer::OptType t = nntrainer::OptType::adam;
-  status = op.setType(t);
-  EXPECT_EQ(status, ML_ERROR_NONE);
+TEST(nntrainer_Optimizer, create_01_p) {
+  std::shared_ptr<nntrainer::Optimizer> op;
+  EXPECT_NO_THROW(op = createOptimizer(nntrainer::OptType::adam));
 }
 
 /**
- * @brief Optimizer set type
+ * @brief Optimizer create
  */
 TEST(nntrainer_Optimizer, setType_02_p) {
-  int status = ML_ERROR_NONE;
-  nntrainer::Optimizer op;
-  nntrainer::OptType t = nntrainer::OptType::sgd;
-  status = op.setType(t);
-  EXPECT_EQ(status, ML_ERROR_NONE);
+  std::shared_ptr<nntrainer::Optimizer> op;
+  EXPECT_NO_THROW(op = createOptimizer(nntrainer::OptType::sgd));
 }
 
 /**
- * @brief Optimizer set type
+ * @brief Optimizer create
  */
 TEST(nntrainer_Optimizer, setType_03_n) {
-  int status = ML_ERROR_NONE;
-  nntrainer::Optimizer op;
-  nntrainer::OptType t = nntrainer::OptType::unknown;
-  status = op.setType(t);
-  EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
-}
-
-/**
- * @brief Optimizer set Opt Param
- */
-TEST(nntrainer_Optimizer, setOptParam_01_p) {
-  int status = ML_ERROR_NONE;
-  nntrainer::Optimizer op;
-  nntrainer::OptType t = nntrainer::OptType::adam;
-  nntrainer::OptParam p;
-  status = op.setType(t);
-  EXPECT_EQ(status, ML_ERROR_NONE);
-  p.learning_rate = -0.001;
-  p.beta1 = 0.9;
-  p.beta2 = 0.9999;
-  p.epsilon = 1e-7;
-  status = op.setOptParam(p);
-  EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
+  std::shared_ptr<nntrainer::Optimizer> op;
+  EXPECT_THROW(op = createOptimizer(nntrainer::OptType::unknown),
+               std::invalid_argument);
 }
 
 /**
index 89545b6..56ef5b4 100644 (file)
@@ -23,7 +23,7 @@
 #include <loss_layer.h>
 #include <nntrainer_error.h>
 #include <nntrainer_test_util.h>
-#include <optimizer.h>
+#include <optimizer_factory.h>
 #include <pooling2d_layer.h>
 #include <tensor_dim.h>
 #include <util_func.h>
@@ -134,10 +134,10 @@ protected:
       input_str.push_back((*i).str());
     }
 
-    nntrainer::Optimizer op;
-    int status = op.setType(type);
-    EXPECT_EQ(status, ML_ERROR_NONE);
-    status = op.setProperty(input_str);
+    std::shared_ptr<nntrainer::Optimizer> op;
+    EXPECT_NO_THROW(op = createOptimizer(type));
+
+    status = op->setProperty(input_str);
     EXPECT_EQ(status, ML_ERROR_NONE);
     status = layer.setOptimizer(op);
     EXPECT_EQ(status, ML_ERROR_NONE);
@@ -1288,18 +1288,16 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_03_p) {
   status = layer2.initialize();
   EXPECT_EQ(status, ML_ERROR_NONE);
 
-  nntrainer::Optimizer op;
-  int status = op.setType(nntrainer::OptType::sgd);
-  EXPECT_EQ(status, ML_ERROR_NONE);
-  status = op.setProperty({"learning_rate=1.0"});
+  std::shared_ptr<nntrainer::Optimizer> op;
+  EXPECT_NO_THROW(op = createOptimizer(nntrainer::OptType::sgd));
+  status = op->setProperty({"learning_rate=1.0"});
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = layer1.setOptimizer(op);
   EXPECT_EQ(status, ML_ERROR_NONE);
 
-  nntrainer::Optimizer op2;
-  status = op2.setType(nntrainer::OptType::sgd);
-  EXPECT_EQ(status, ML_ERROR_NONE);
-  status = op2.setProperty({"learning_rate=1.0"});
+  std::shared_ptr<nntrainer::Optimizer> op2;
+  EXPECT_NO_THROW(op2 = createOptimizer(nntrainer::OptType::sgd));
+  status = op2->setProperty({"learning_rate=1.0"});
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = layer2.setOptimizer(op2);
   EXPECT_EQ(status, ML_ERROR_NONE);