[WeightSharing] enable weight sharing from manager
author Jihoon Lee <jhoon.it.lee@samsung.com>
Tue, 5 Oct 2021 04:43:37 +0000 (13:43 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Thu, 7 Oct 2021 10:20:47 +0000 (19:20 +0900)
This patch enables weight sharing from the manager: when shared names are given for the requested weights, the variable tensor (and, for trainable weights that need a gradient, the gradient tensor) is requested from the pool under the shared name instead of being freshly allocated. A toy sketch of the idea follows the changed-file list below.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/tensor/manager.cpp
nntrainer/tensor/tensor_pool.cpp
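
For context, a minimal sketch of what weight sharing buys at the layer level (toy code, not nntrainer's classes): two layers reference one underlying tensor, so an update made through either view is observed through the other.

```cpp
#include <iostream>
#include <vector>

int main() {
  std::vector<float> shared_weight(4, 0.0f); // one underlying buffer
  std::vector<float> &enc_w = shared_weight; // first owner's view
  std::vector<float> &dec_w = shared_weight; // tied second view

  enc_w[0] = 1.5f;               // update through one view...
  std::cout << dec_w[0] << '\n'; // ...observed through the other: 1.5
}
```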

nntrainer/tensor/manager.cpp
index a2038c1..3bfec15 100644
@@ -265,25 +265,44 @@ std::vector<Weight *> Manager::requestWeights(
   std::vector<Weight *> ret;
   size_t current_size = weights_v2.size();
 
-  for (auto const &ws : std::as_const(weights_spec)) {
-    auto &[dim, t_initializer, w_reg, w_reg_const, need_gradient, name] = ws;
+  for (unsigned int i = 0; i < weights_spec.size(); ++i) {
+    auto &[dim, t_initializer, w_reg, w_reg_const, need_gradient, name] =
+      weights_spec.at(i);
+
+    Tensor *var = nullptr, *grad = nullptr;
 
     if (!shared_names.empty()) {
-      /** @todo add case when shared names are given */
+      const auto &shared_name = shared_names.at(i);
+      /** case when shared names are given */
+      var = weight_pool.requestPrerequestedTensor(
+        dim, var_exec_order, var_ls,
+        name,         /// name
+        shared_name,  /// shared name
+        t_initializer /// tensor initializer
+      );
+
+      if (trainable && need_gradient) {
+        grad = tensor_pool.requestPrerequestedTensor(
+          dim, grad_exec_order, grad_ls,
+          name + Var_Grad::grad_suffix,        /// name
+          shared_name + Var_Grad::grad_suffix, /// shared name
+          Tensor::Initializer::ZEROS           /// tensor initializer
+        );
+      }
 
     } else {
-      Tensor *var = weight_pool.requestTensor(dim, var_exec_order, var_ls, name,
-                                              t_initializer);
+      /** case when fresh weights are requested */
+      var = weight_pool.requestTensor(dim, var_exec_order, var_ls, name,
+                                      t_initializer);
 
-      Tensor *grad = nullptr;
       if (trainable && need_gradient)
         grad = tensor_pool.requestTensor(dim, grad_exec_order, grad_ls,
                                          name + Var_Grad::grad_suffix,
                                          Tensor::Initializer::ZEROS);
-
-      weights_v2.emplace_back(
-        std::make_unique<Weight>(var, grad, w_reg, w_reg_const));
     }
+
+    weights_v2.emplace_back(
+      std::make_unique<Weight>(var, grad, w_reg, w_reg_const));
   }
 
   std::transform(weights_v2.begin() + current_size, weights_v2.end(),
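
The shared path introduced here relies on the pool mapping a new name onto an already-registered one. A toy sketch of that aliasing (illustrative only, not nntrainer's TensorPool; the real requestPrerequestedTensor also takes the dimension, execution orders, lifespan, and initializer, as the calls above show):

```cpp
#include <cassert>
#include <map>
#include <string>

struct ToyPool {
  std::map<std::string, int> name_map; // tensor name -> spec index
  int next_spec = 0;

  // fresh request: register a brand-new spec under `name`
  int requestTensor(const std::string &name) {
    return name_map[name] = next_spec++;
  }
  // prerequested request: map `name` onto the spec already registered
  // under `shared_name`, so both names resolve to the same tensor
  int requestPrerequestedTensor(const std::string &name,
                                const std::string &shared_name) {
    return name_map[name] = name_map.at(shared_name);
  }
};

int main() {
  ToyPool pool;
  pool.requestTensor("fc0:weight");                           // fresh
  pool.requestPrerequestedTensor("fc1:weight", "fc0:weight"); // shared
  assert(pool.name_map["fc1:weight"] == pool.name_map["fc0:weight"]);
}
```

Because both names resolve to the same spec, allocation later hands both requests the same underlying memory.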
@@ -383,7 +402,7 @@ Manager::requestInputs(const GraphNode &node,
         dim, /// tensor dim
         var_exec_order, var_ls,
         var_name,                 /// name
-        outputs_name[idx],        /// name
+        outputs_name[idx],        /// shared name
         Tensor::Initializer::NONE /// tensor initializer
       );
 
@@ -391,7 +410,7 @@ Manager::requestInputs(const GraphNode &node,
         dim, /// tensor dim
         grad_exec_order, grad_ls,
         var_name + Var_Grad::grad_suffix,          /// name
-        outputs_name[idx] + Var_Grad::grad_suffix, /// name
+        outputs_name[idx] + Var_Grad::grad_suffix, /// shared name
         Tensor::Initializer::ZEROS                 /// tensor initializer
       );
     } else if (!node.getInputConnections().empty()) {
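
Both hunks above only correct comments: the argument built from outputs_name[idx] is the shared name, i.e. an input tensor of this node shares the previous node's output. A small sketch of that naming contract (the ":grad" suffix value is an assumption; the code spells it Var_Grad::grad_suffix):

```cpp
#include <iostream>
#include <string>

int main() {
  const std::string grad_suffix = ":grad";   // assumed value of Var_Grad::grad_suffix
  const std::string var_name = "fc1:input0"; // this node's input tensor
  const std::string shared = "fc0:output0";  // previous node's output tensor

  // variable pairing: (name, shared name)
  std::cout << var_name << " shares " << shared << '\n';
  // gradient pairing: the same suffix is appended to both names
  std::cout << var_name + grad_suffix << " shares "
            << shared + grad_suffix << '\n';
}
```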
nntrainer/tensor/tensor_pool.cpp
index 59ad6a2..a15f05a 100644
@@ -77,7 +77,8 @@ Tensor *TensorPool::requestPrerequestedTensor(
   TensorLifespan lifespan, const std::string &name,
   const std::string &shared_name, const Tensor::Initializer &init) {
   if (name_map.find(shared_name) == name_map.end())
-    throw std::invalid_argument("Requested shared tensor not found");
+    throw std::invalid_argument("Requested shared tensor not found, name: " +
+                                shared_name);
 
   /** find the parent non-dependent node where the spec is stored */
   int parent_spec_idx = name_map[shared_name];
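
A minimal sketch of the improved diagnostic (a toy lookup mirroring the check above, not nntrainer's TensorPool): the exception now carries the missing shared name, so a failed share can be traced back to the spec that requested it.

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

static std::map<std::string, int> name_map; // tensor name -> parent spec index

int lookupShared(const std::string &shared_name) {
  if (name_map.find(shared_name) == name_map.end())
    throw std::invalid_argument("Requested shared tensor not found, name: " +
                                shared_name);
  return name_map[shared_name];
}

int main() {
  try {
    lookupShared("conv0:weight"); // never registered
  } catch (const std::invalid_argument &e) {
    std::cout << e.what() << '\n';
    // prints: Requested shared tensor not found, name: conv0:weight
  }
}
```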