[tp] Implement reidentifySource()
authorJihoon Lee <jhoon.it.lee@samsung.com>
Fri, 5 Nov 2021 09:42:23 +0000 (18:42 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 17 Nov 2021 06:56:45 +0000 (15:56 +0900)
This patch implements reidentifySource(). Test will be followed

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/tensor/tensor_pool.cpp

index e4d628b..b7be963 100644 (file)
@@ -99,10 +99,10 @@ Tensor *TensorPool::requestPrerequestedTensor(
   /** @note below invalidates spec reference */
   /** @note in case of view of view, internal datastructure saves the src to
    * view index, not view to view reference in order to flatten depth */
-  auto parent_name = name_map.at(spec.tensor->getName());
+  auto parent_idx = name_map.at(spec.tensor->getName());
   return registerRequestSpec(
     {std::make_unique<Tensor>(dim, false, init, name),
-     TensorPool::DependentDetails{parent_name, adjusted_offset}});
+     TensorPool::DependentDetails{parent_idx, adjusted_offset}});
 }
 
 /**
@@ -351,6 +351,48 @@ Tensor *TensorPool::requestOrExtend(const std::string &name,
   }
 }
 
+void TensorPool::reidentifySource(const std::string &dest,
+                                  const std::string &new_src,
+                                  unsigned int offset) {
+  /// @todo add test
+  /// source tensor of dest tensor becomes a view of new_src
+  auto &old_spec = getSourceSpec(dest);
+  auto &old_details = std::get<SourceDetails>(old_spec.details);
+
+  /// 1. extend new_src with old src
+  auto &new_spec = getSourceSpec(new_src);
+  expandLifespan(new_spec, old_details.exec_order, old_details.lifespan);
+  auto &new_dependents = std::get<SourceDetails>(new_spec.details).dependents;
+  new_dependents.insert(new_dependents.end(), old_details.dependents.begin(),
+                        old_details.dependents.end());
+
+  /// 2. calcaulate base offset from the new_src
+  auto new_parent_idx = name_map.at(new_src);
+  unsigned base_offset = std::visit(
+    [](const auto &s) {
+      using T = std::decay_t<decltype(s)>;
+      if constexpr (std::is_same_v<T, SourceDetails>) {
+        return 0u;
+      } else if constexpr (std::is_same_v<T, DependentDetails>) {
+        return s.offset;
+      }
+      return 0u;
+    },
+    pool[new_parent_idx].details);
+  base_offset += offset;
+
+  /// 3. transform parent idx/offset of old src's dependents base on the offset
+  for (auto &dep : old_details.dependents) {
+    auto &dep_spec = pool.at(dep);
+    auto &details = std::get<DependentDetails>(dep_spec.details);
+    details.offset += base_offset;
+    details.parent_idx = new_parent_idx;
+  }
+
+  /// 4. replace old details to dependent srcs
+  old_spec.details = DependentDetails{new_parent_idx, base_offset};
+}
+
 bool TensorPool::tensorExist(const std::string &name) {
   /// @todo consider use a helper function to check, eg) something like
   /// getTensor()