[Load/Save] Update to load and save the adam momentum variables
authorjijoong.moon <jijoong.moon@samsung.com>
Thu, 20 Jan 2022 09:06:41 +0000 (18:06 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 18 Feb 2022 07:03:54 +0000 (16:03 +0900)
In this PR, Load and save is updated for the adam optimizer variables.
In order to do it, optimizer load twice, one for the weight, and one
for the optimizer variables.

**Changes proposed in this PR:**
- Added TOC generator for README.md

Resolves:

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
nntrainer/layers/layer_node.cpp
nntrainer/models/neuralnet.cpp

index 1a593b0..f9e6d50 100644 (file)
@@ -420,10 +420,13 @@ void LayerNode::read(std::ifstream &file, bool opt_var) {
     << __func__ << " layer needs to be finalized first!";
   if (opt_var) {
     for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
-      if (run_context->isGradientLastAccess(i)) {
+      if (run_context->isGradientLastAccess(i) && getTrainable()) {
         // @note read optimizer variables
-        for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i); ++j) {
-          run_context->getWeightOptVar(i, j).read(file);
+        if (run_context->weightHasGradient(i)) {
+          for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i);
+               ++j) {
+            run_context->getWeightOptVar(i, j).read(file);
+          }
         }
       }
     }
@@ -440,12 +443,16 @@ void LayerNode::read(std::ifstream &file, bool opt_var) {
 void LayerNode::save(std::ofstream &file, bool opt_var) const {
   NNTR_THROW_IF(!run_context, std::runtime_error)
     << __func__ << " layer needs to be finalized first!";
+
   if (opt_var) {
     for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
-      if (run_context->isGradientLastAccess(i)) {
+      if (run_context->isGradientLastAccess(i) && getTrainable()) {
         // @note save optimizer variables
-        for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i); ++j) {
-          run_context->getWeightOptVar(i, j).save(file);
+        if (run_context->weightHasGradient(i)) {
+          for (unsigned int j = 0; j < run_context->getNumWeightOptVar(i);
+               ++j) {
+            run_context->getWeightOptVar(i, j).save(file);
+          }
         }
       }
     }
index 6f9e0c7..85180c6 100644 (file)
@@ -205,7 +205,11 @@ int NeuralNetwork::initialize() {
 
   initialized = true;
 
-  if (!load_path.empty()) {
+  // @note we need check loadedWeight for the case of multiple call of load to
+  // load weight. Only the weight needs to be loaded here. Becuase the buffer
+  // for the optimizer is not allocated yet.
+  // loadedWeight check is just for the duplicate load of weight.
+  if (!load_path.empty() && !loadedWeight) {
     load(load_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
   }
 
@@ -328,7 +332,7 @@ void NeuralNetwork::save(const std::string &file_path,
 
     opt->save(model_file);
 
-    if (opt->getType() == "adam") {
+    if (istrequal(opt->getType(), "adam")) {
       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
            iter++) {
         (*iter)->save(model_file, true);
@@ -381,18 +385,22 @@ void NeuralNetwork::load(const std::string &file_path,
       }
       loadedWeight = true;
       bin_file_pos = model_file.tellg();
+      load_path = file_path;
+      return;
     }
     try {
       /// this is assuming that the failure is allowed at the end of the file
       /// read. so, after this line, additional read shouldn't be called
       model_file.seekg(bin_file_pos);
 
-      std::string opt_type;
-      opt_type = readString(model_file);
-      if (istrequal(opt_type, "adam") && istrequal(opt->getType(), "adam")) {
-        for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
-             iter++) {
-          (*iter)->read(model_file, true);
+      if (istrequal(opt->getType(), "adam")) {
+        char opt_type[4];
+        model_file.read(opt_type, 4);
+        if (istrequal(opt_type, "adam")) {
+          for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
+               iter++) {
+            (*iter)->read(model_file, true);
+          }
         }
       }