Policy based tensor request for tensor pool.
authorJihoon Lee <jhoon.it.lee@samsung.com>
Fri, 5 Nov 2021 08:38:02 +0000 (17:38 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 16 Nov 2021 07:46:20 +0000 (16:46 +0900)
This commit specifically implement createOrExtend().

From this commit, basic tensor pool request will be possible except
`reidentifySource()`.
`reidentifySource()` will come as a separate PR to get better review.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/tensor/tensor_pool.cpp
nntrainer/tensor/tensor_pool.h
test/unittest/unittest_nntrainer_tensor_pool.cpp

index fabbdd3..b262137 100644 (file)
 
 namespace nntrainer {
 
-/// lambda overload helper
+/**
+ * @brief lambda overload helper
+ *
+ */
 template <class... Ts> struct overloaded_ : Ts... { using Ts::operator()...; };
 template <class... Ts> overloaded_(Ts...)->overloaded_<Ts...>;
 
@@ -326,6 +329,26 @@ Tensor *TensorPool::extend(const std::string &name,
   return getTensor(name);
 }
 
+Tensor *TensorPool::createOrExtend(const std::string &name,
+                                   const TensorDim &dim,
+                                   const std::vector<unsigned int> &exec_order,
+                                   TensorLifespan lifespan,
+                                   const Tensor::Initializer &init) {
+  NNTR_THROW_IF(lifespan == TensorLifespan::UNMANAGED, std::invalid_argument)
+    << "unmanaged life span is not supported";
+
+  if (tensorExist(name)) {
+    Tensor *t = getTensor(name);
+    NNTR_THROW_IF(t->getDim() != dim, std::invalid_argument)
+      << "tensor dimension mismatch for createOrExtend name: " << name;
+    NNTR_THROW_IF(t->getInitializer() != init, std::invalid_argument)
+      << "tensor initializer mismatch for createOrExtend name: " << name;
+    return extend(name, exec_order, lifespan);
+  } else {
+    return create(name, dim, exec_order, lifespan, init);
+  }
+}
+
 bool TensorPool::tensorExist(const std::string &name) {
   return name_map.count(name);
 }
index 29381eb..07faa34 100644 (file)
@@ -259,6 +259,9 @@ public:
   /**
    * @brief create a new tensor if tensor does not exist else return the tensor
    * while extending the tensor's life according to the given arguments.
+   * @note Created (or extended) tensor is considered identical and managed. It
+   * is invalid to create a tensor with lifespan::UNMANAGED or dimension and
+   * initializer is different unon extension.
    *
    * @param name Name of the tensor
    * @param dim dimension
index 2141a27..0d06be9 100644 (file)
@@ -659,6 +659,42 @@ TEST(TensorPool, extend_unmanged_n) {
   EXPECT_ANY_THROW(pool.extend("t1", {2}, max_ls));
 }
 
+TEST(TensorPool, createOrExtend_p) {
+  nntrainer::TensorPool pool;
+  auto t1 = pool.createOrExtend("t", {10}, {0}, max_ls);
+  auto t2 = pool.createOrExtend("t", {10}, {1}, max_ls);
+
+  auto &exec_order = pool.getExecutionOrder("t");
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 0),
+            exec_order.end());
+  EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 1),
+            exec_order.end());
+  EXPECT_EQ(t1, t2);
+  pool.finalize(nntrainer::BasicPlanner(), 0, 2);
+  pool.allocate();
+  EXPECT_EQ(*t1, *t2);
+  pool.deallocate();
+}
+
+TEST(TensorPool, createOrExtend_different_dim_n) {
+  nntrainer::TensorPool pool;
+  pool.createOrExtend("t", {10, 1}, {0}, max_ls);
+  EXPECT_ANY_THROW(pool.createOrExtend("t", {1, 10}, {1}, max_ls));
+}
+
+TEST(TensorPool, createOrExtend_init_n) {
+  nntrainer::TensorPool pool;
+  pool.createOrExtend("t", {10}, {0}, max_ls,
+                      nntrainer::Tensor::Initializer::ONES);
+  EXPECT_ANY_THROW(pool.createOrExtend("t", {10}, {1}, max_ls,
+                                       nntrainer::Tensor::Initializer::ZEROS));
+}
+TEST(TensorPool, createOrExtend_unmanaged_n) {
+  nntrainer::TensorPool pool;
+  EXPECT_ANY_THROW(
+    pool.createOrExtend("t", {10}, {0}, nntrainer::TensorLifespan::UNMANAGED));
+}
+
 /**
  * @brief Main gtest
  */