[TP] Implement extend()
authorJihoon Lee <jhoon.it.lee@samsung.com>
Fri, 5 Nov 2021 08:08:13 +0000 (17:08 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 16 Nov 2021 07:46:20 +0000 (16:46 +0900)
This patch implements extend() and corresponding tests

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/tensor/tensor_pool.cpp
test/unittest/unittest_nntrainer_tensor_pool.cpp

index 65cf401..fabbdd3 100644 (file)
@@ -316,6 +316,16 @@ Tensor *TensorPool::view(const std::string &name, const std::string &reference,
                                    Tensor::Initializer::NONE, offset);
 }
 
+Tensor *TensorPool::extend(const std::string &name,
+                           const std::vector<unsigned int> &exec_order,
+                           TensorLifespan lifespan) {
+  NNTR_THROW_IF(!tensorExist(name), std::invalid_argument)
+    << " cannot extend tensor which does not exist, name: " << name;
+  auto &spec = getSourceSpec(name);
+  expandLifespan(spec, exec_order, lifespan);
+  return getTensor(name);
+}
+
 bool TensorPool::tensorExist(const std::string &name) {
   return name_map.count(name);
 }
index d6c3d2d..2141a27 100644 (file)
@@ -590,6 +590,75 @@ TEST(TensorPool, view_of_placeholder_out_of_range_n) {
   EXPECT_ANY_THROW(pool.view("t1", "t0", {1}, {0}, max_ls, 11));
 }
 
+TEST(TensorPool, extend_source_p) {
+  nntrainer::TensorPool pool;
+  pool.create("t0", {10}, {0},
+              nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);
+  pool.extend("t0", {1}, nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);
+
+  auto &exec_order = pool.getExecutionOrder("t0");
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 0),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 1),
+            exec_order.end());
+}
+
+TEST(TensorPool, extend_view_p) {
+  nntrainer::TensorPool pool;
+  pool.create("t0", {10}, {0},
+              nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);
+  pool.view("t1", "t0", {10}, {1},
+            nntrainer::TensorLifespan::BACKWARD_FUNC_LIFESPAN);
+  pool.extend("t1", {2}, max_ls);
+
+  auto &exec_order = pool.getExecutionOrder("t0");
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 0),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 1),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 2),
+            exec_order.end());
+}
+
+TEST(TensorPool, extend_placeholder_p) {
+  nntrainer::TensorPool pool;
+  pool.placeholder("t0", {10});
+  pool.extend("t0", {2}, max_ls);
+
+  auto &exec_order = pool.getExecutionOrder("t0");
+  EXPECT_EQ(std::find(exec_order.begin(), exec_order.end(), 0),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 2),
+            exec_order.end());
+}
+
+TEST(TensorPool, extend_view_of_placeholder_p) {
+  nntrainer::TensorPool pool;
+  pool.placeholder("t0", {10});
+  pool.view("t1", "t0", {10}, {1},
+            nntrainer::TensorLifespan::BACKWARD_FUNC_LIFESPAN);
+  pool.extend("t1", {2}, max_ls);
+
+  auto &exec_order = pool.getExecutionOrder("t0");
+  EXPECT_EQ(std::find(exec_order.begin(), exec_order.end(), 0),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 1),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 2),
+            exec_order.end());
+}
+
+TEST(TensorPool, extend_out_of_range_n) {
+  nntrainer::TensorPool pool;
+  EXPECT_ANY_THROW(pool.extend("t1", {2}, max_ls));
+}
+
+TEST(TensorPool, extend_unmanged_n) {
+  nntrainer::TensorPool pool;
+  pool.create("t0", {10}, {0}, nntrainer::TensorLifespan::UNMANAGED);
+  EXPECT_ANY_THROW(pool.extend("t1", {2}, max_ls));
+}
+
 /**
  * @brief Main gtest
  */