From: Jihoon Lee Date: Fri, 5 Nov 2021 08:38:02 +0000 (+0900) Subject: Policy based tensor request for tensor pool. X-Git-Tag: accepted/tizen/unified/20220323.062643~216 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e46917b2811f84be7f318f32142ef0d07298a84f;p=platform%2Fcore%2Fml%2Fnntrainer.git Policy based tensor request for tensor pool. This commit specifically implement createOrExtend(). From this commit, basic tensor pool request will be possible except `reidentifySource()`. `reidentifySource()` will come as a separate PR to get better review. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- diff --git a/nntrainer/tensor/tensor_pool.cpp b/nntrainer/tensor/tensor_pool.cpp index fabbdd3..b262137 100644 --- a/nntrainer/tensor/tensor_pool.cpp +++ b/nntrainer/tensor/tensor_pool.cpp @@ -23,7 +23,10 @@ namespace nntrainer { -/// lambda overload helper +/** + * @brief lambda overload helper + * + */ template struct overloaded_ : Ts... { using Ts::operator()...; }; template overloaded_(Ts...)->overloaded_; @@ -326,6 +329,26 @@ Tensor *TensorPool::extend(const std::string &name, return getTensor(name); } +Tensor *TensorPool::createOrExtend(const std::string &name, + const TensorDim &dim, + const std::vector &exec_order, + TensorLifespan lifespan, + const Tensor::Initializer &init) { + NNTR_THROW_IF(lifespan == TensorLifespan::UNMANAGED, std::invalid_argument) + << "unmanaged life span is not supported"; + + if (tensorExist(name)) { + Tensor *t = getTensor(name); + NNTR_THROW_IF(t->getDim() != dim, std::invalid_argument) + << "tensor dimension mismatch for createOrExtend name: " << name; + NNTR_THROW_IF(t->getInitializer() != init, std::invalid_argument) + << "tensor initializer mismatch for createOrExtend name: " << name; + return extend(name, exec_order, lifespan); + } else { + return create(name, dim, exec_order, lifespan, init); + } +} + bool TensorPool::tensorExist(const std::string &name) { return name_map.count(name); } diff --git a/nntrainer/tensor/tensor_pool.h b/nntrainer/tensor/tensor_pool.h index 29381eb..07faa34 100644 --- a/nntrainer/tensor/tensor_pool.h +++ b/nntrainer/tensor/tensor_pool.h @@ -259,6 +259,9 @@ public: /** * @brief create a new tensor if tensor does not exist else return the tensor * while extending the tensor's life according to the given arguments. + * @note Created (or extended) tensor is considered identical and managed. It + * is invalid to create a tensor with lifespan::UNMANAGED or dimension and + * initializer is different unon extension. * * @param name Name of the tensor * @param dim dimension diff --git a/test/unittest/unittest_nntrainer_tensor_pool.cpp b/test/unittest/unittest_nntrainer_tensor_pool.cpp index 2141a27..0d06be9 100644 --- a/test/unittest/unittest_nntrainer_tensor_pool.cpp +++ b/test/unittest/unittest_nntrainer_tensor_pool.cpp @@ -659,6 +659,42 @@ TEST(TensorPool, extend_unmanged_n) { EXPECT_ANY_THROW(pool.extend("t1", {2}, max_ls)); } +TEST(TensorPool, createOrExtend_p) { + nntrainer::TensorPool pool; + auto t1 = pool.createOrExtend("t", {10}, {0}, max_ls); + auto t2 = pool.createOrExtend("t", {10}, {1}, max_ls); + + auto &exec_order = pool.getExecutionOrder("t"); + EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 0), + exec_order.end()); + EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 1), + exec_order.end()); + EXPECT_EQ(t1, t2); + pool.finalize(nntrainer::BasicPlanner(), 0, 2); + pool.allocate(); + EXPECT_EQ(*t1, *t2); + pool.deallocate(); +} + +TEST(TensorPool, createOrExtend_different_dim_n) { + nntrainer::TensorPool pool; + pool.createOrExtend("t", {10, 1}, {0}, max_ls); + EXPECT_ANY_THROW(pool.createOrExtend("t", {1, 10}, {1}, max_ls)); +} + +TEST(TensorPool, createOrExtend_init_n) { + nntrainer::TensorPool pool; + pool.createOrExtend("t", {10}, {0}, max_ls, + nntrainer::Tensor::Initializer::ONES); + EXPECT_ANY_THROW(pool.createOrExtend("t", {10}, {1}, max_ls, + nntrainer::Tensor::Initializer::ZEROS)); +} +TEST(TensorPool, createOrExtend_unmanaged_n) { + nntrainer::TensorPool pool; + EXPECT_ANY_THROW( + pool.createOrExtend("t", {10}, {0}, nntrainer::TensorLifespan::UNMANAGED)); +} + /** * @brief Main gtest */