[optimizer] Move optimizer variables to weights
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 3 Dec 2020 05:43:09 +0000 (14:43 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 10 Dec 2020 10:20:41 +0000 (19:20 +0900)
Move optimizer variables to weights
Now all the weight related tensors are handled by weights themselves
So, optimizer can be shared across all layers, no need to create new
copies for all layers

**Self evaluation:**
1. Build test: [x]Passed [ ]Failed [ ]Skipped
2. Run test: [x]Passed [ ]Failed [ ]Skipped

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/layer.cpp
nntrainer/optimizers/adam.cpp
nntrainer/optimizers/adam.h
nntrainer/tensor/weight.h

index 40f84c1..4342b7a 100644 (file)
@@ -38,7 +38,7 @@ void Layer::setActivation(ActivationType acti) {
 }
 
 int Layer::setOptimizer(std::shared_ptr<Optimizer> opt) {
-  this->opt = createOptimizer(opt->getType(), *opt);
+  this->opt = opt;
   return this->opt->initialize(weights, true);
 }
 
index 4a0e808..df955ea 100644 (file)
@@ -24,23 +24,21 @@ namespace nntrainer {
 
 const std::string Adam::type = "adam";
 
+enum AdamParams { wm, wv };
+
 int Adam::initialize(std::vector<Weight> &weight_list, bool set_tensor) {
   int status = ML_ERROR_NONE;
-  weight_mv.clear();
 
   if (set_tensor) {
-    for (auto const &w : weight_list) {
+    for (auto &w : weight_list) {
+      w.clearOptimizerVariables();
+
       // TODO: only trainable weights must be sent to optimizer
       if (!w.getTrainable())
         continue;
 
-      Tensor m = Tensor(w.getDim());
-      m.setZero();
-      Tensor v = Tensor(w.getDim());
-      v.setZero();
-      std::pair<Tensor, Tensor> p =
-        std::pair<Tensor, Tensor>(std::move(m), std::move(v));
-      weight_mv.push_back(std::move(p));
+      w.addOptimizerVariable(w.getDim()); /** Add wm */
+      w.addOptimizerVariable(w.getDim()); /** Add wv */
     }
   }
   return status;
@@ -68,8 +66,8 @@ void Adam::apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
   // This is not deleted intentionally.
   // float biasCorrection1 = 1 - pow(beta1, iteration + 1);
   // float biasCorrection2 = 1 - pow(beta2, iteration + 1);
-  // Tensor &wm = weight_mv[idx].first;
-  // Tensor &wv = weight_mv[idx].second;
+  // Tensor &wm = weight.getOptimizerVariableRef(AdamParams::wm);
+  // Tensor &wv = weight.getOptimizerVariableRef(AdamParams::wv);
 
   // wm.multiply_i(beta1);
   // wm.add_i(x_grad, 1.0f - beta1);
@@ -86,8 +84,8 @@ void Adam::apply_gradient(Weight &weight, int tensor_idx, double updated_lr,
     return 1 / (sqrtDouble(f) + this->epsilon);
   };
 
-  Tensor &wm = weight_mv[tensor_idx].first;
-  Tensor &wv = weight_mv[tensor_idx].second;
+  Tensor &wm = weight.getOptimizerVariableRef(AdamParams::wm);
+  Tensor &wv = weight.getOptimizerVariableRef(AdamParams::wv);
 
   wm.multiply_i(beta1);
   wm.add_i(x_grad, 1.0f - beta1);
@@ -123,31 +121,4 @@ void Adam::setProperty(const PropertyType type, const std::string &value) {
   throw_status(status);
 }
 
-void Adam::read(std::ifstream &file) {
-  /// @todo need strong exception safety guarantee
-  Optimizer::read(file);
-
-  if (continue_train) {
-    for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) {
-      (*iter).first.read(file);
-      (*iter).second.read(file);
-    }
-  } else {
-    size_t total_size = 0;
-    for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++)
-      total_size += (*iter).first.getSize() + (*iter).second.getSize();
-
-    file.seekg(total_size, std::ifstream::cur);
-  }
-}
-
-void Adam::save(std::ofstream &file) {
-  Optimizer::save(file);
-
-  for (auto iter = weight_mv.begin(); iter != weight_mv.end(); iter++) {
-    (*iter).first.save(file);
-    (*iter).second.save(file);
-  }
-}
-
 } // namespace nntrainer
index a43c57c..2590fa5 100644 (file)
@@ -64,16 +64,6 @@ public:
   int initialize(std::vector<Weight> &params, bool setTensor);
 
   /**
-   * @copydoc read(std::ifstream &file)
-   */
-  void read(std::ifstream &file);
-
-  /**
-   * @copydoc save(std::ofstream &file)
-   */
-  void save(std::ofstream &file);
-
-  /**
    * @brief get beta1
    */
   double getBeta1() { return beta1; };
@@ -91,10 +81,6 @@ public:
   static const std::string type;
 
 private:
-  /**
-   * @brief Internal Tensors for adam Optimizer
-   */
-  std::vector<std::pair<Tensor, Tensor>> weight_mv;
 
   double beta1;   /** momentum for grad */
   double beta2;   /** momentum for grad**2 */
index f52005f..e196cd8 100644 (file)
@@ -150,14 +150,38 @@ public:
    *
    * @note New dimension must maintain the shape of the variable
    */
-
   void reset(const TensorDim &dim, const WeightInitializer init, bool train) {
     initializer = init;
     Var_Grad::reset(dim, train);
   }
 
+  /**
+   * @brief Clear optimizer variables
+   */
+  void clearOptimizerVariables() { opt_vars.clear(); }
+
+  /**
+   * @brief Add optimizer variables
+   * @param dim Optimizer variable dimension
+   */
+  void addOptimizerVariable(const TensorDim &dim) {
+    opt_vars.emplace_back(dim);
+    opt_vars.back().setZero();
+  }
+
+  /**
+   * @brief Get optimizer variable reference
+   * @param idx Index of the optimizer variable to get
+   * @retval Reference of the optimizer variable
+   */
+  Tensor &getOptimizerVariableRef(unsigned int idx) {
+    return opt_vars[idx];
+  }
+
 private:
   WeightInitializer initializer; /**< initializer for this variable */
+
+  std::vector<Tensor> opt_vars;  /**< optimizer variables */
 };
 
 } // namespace nntrainer