[manager] Support initialization/allocation of weights
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 17 Jun 2021 07:33:33 +0000 (16:33 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 23 Jun 2021 07:42:19 +0000 (16:42 +0900)
Support initialization and allocation of weights for LayerV2

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/tensor/manager.cpp

index 516fd81..9ae6d70 100644 (file)
@@ -719,6 +719,7 @@ NetworkGraph::updateRunContext(std::shared_ptr<Manager> &manager,
    */
   // const RunLayerContext &run_context = lnode->getRunContext();
   lnode->updateRunContext(RunLayerContext(
+    // TODO: update weights spec for trainable based on layer trainable prop
     manager->requestWeights(gnode, init_context.getWeightsSpec()), inputs,
     outputs, manager->requestTensors(gnode, init_context.getTensorsSpec())));
 
index 3aa8262..5793ee1 100644 (file)
@@ -31,6 +31,7 @@
 
 #include <activation_layer.h>
 #include <flatten_layer.h>
+#include <layer_node.h>
 #include <manager.h>
 #include <nntrainer_log.h>
 #include <rnn.h>
@@ -261,15 +262,28 @@ void Manager::initializeWeights() {
 
   unsigned int weight_offset = 0;
 
-  for (auto &l_w : weights) {
-    for (auto &w : l_w) {
-      Weight &weight = w.get();
-      auto dim = weight.getDim();
+  if (LAYER_V2) {
+    for (auto &l_w : weights_v2) {
+      for (auto &w : l_w) {
+        const auto &dim = w->getDim();
+
+        Tensor weight_prealloc = allocate_weight(dim, weight_offset);
+        weight_offset += dim.getDataLen();
 
-      Tensor weight_prealloc = allocate_weight(dim, weight_offset);
-      weight_offset += dim.getDataLen();
+        w->initializeVariable(weight_prealloc);
+      }
+    }
+  } else {
+    for (auto &l_w : weights) {
+      for (auto &w : l_w) {
+        Weight &weight = w.get();
+        auto dim = weight.getDim();
 
-      weight.initializeVariable(weight_prealloc);
+        Tensor weight_prealloc = allocate_weight(dim, weight_offset);
+        weight_offset += dim.getDataLen();
+
+        weight.initializeVariable(weight_prealloc);
+      }
     }
   }
 
@@ -280,24 +294,38 @@ void Manager::allocateWeights() {
   if (weights_allocated)
     return;
 
-  for (auto &l_w : weights) {
-    for (auto &w : l_w) {
-      Weight &weight = w.get();
-      weight.allocateVariable();
+  if (LAYER_V2) {
+    for (auto &l_w : weights_v2) {
+      for (auto &w : l_w) {
+        w->allocateVariable();
+      }
+    }
+  } else {
+    for (auto &l_w : weights) {
+      for (auto &w : l_w) {
+        Weight &weight = w.get();
+        weight.allocateVariable();
+      }
     }
   }
-
   weights_allocated = true;
 }
 
 void Manager::deallocateWeights() {
-  for (auto &l_w : weights) {
-    for (auto &w : l_w) {
-      Weight &weight = w.get();
-      weight.deallocateVariable();
+  if (LAYER_V2) {
+    for (auto &l_w : weights_v2) {
+      for (auto &w : l_w) {
+        w->deallocateVariable();
+      }
+    }
+  } else {
+    for (auto &l_w : weights) {
+      for (auto &w : l_w) {
+        Weight &weight = w.get();
+        weight.deallocateVariable();
+      }
     }
   }
-
   weights_allocated = false;
 }
 
@@ -306,10 +334,18 @@ void Manager::allocateGradients() {
   if (!shared_grad.uninitialized())
     shared_grad.allocate();
 
-  for (auto &l_w : weights) {
-    for (auto &w : l_w) {
-      Weight &weight = w.get();
-      weight.allocateGradient();
+  if (LAYER_V2) {
+    for (auto &l_w : weights_v2) {
+      for (auto &w : l_w) {
+        w->allocateGradient();
+      }
+    }
+  } else {
+    for (auto &l_w : weights) {
+      for (auto &w : l_w) {
+        Weight &weight = w.get();
+        weight.allocateGradient();
+      }
     }
   }
 }
@@ -317,10 +353,18 @@ void Manager::allocateGradients() {
 void Manager::deallocateGradients() {
   shared_grad.deallocate();
 
-  for (auto &l_w : weights) {
-    for (auto &w : l_w) {
-      Weight &weight = w.get();
-      weight.deallocateGradient();
+  if (LAYER_V2) {
+    for (auto &l_w : weights_v2) {
+      for (auto &w : l_w) {
+        w->deallocateGradient();
+      }
+    }
+  } else {
+    for (auto &l_w : weights) {
+      for (auto &w : l_w) {
+        Weight &weight = w.get();
+        weight.deallocateGradient();
+      }
     }
   }
 }