namespace nntrainer {
-/// lambda overload helper
+/**
+ * @brief lambda overload helper
+ *
+ */
template <class... Ts> struct overloaded_ : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded_(Ts...)->overloaded_<Ts...>;
return getTensor(name);
}
+Tensor *TensorPool::createOrExtend(const std::string &name,
+ const TensorDim &dim,
+ const std::vector<unsigned int> &exec_order,
+ TensorLifespan lifespan,
+ const Tensor::Initializer &init) {
+ NNTR_THROW_IF(lifespan == TensorLifespan::UNMANAGED, std::invalid_argument)
+ << "unmanaged life span is not supported";
+
+ if (tensorExist(name)) {
+ Tensor *t = getTensor(name);
+ NNTR_THROW_IF(t->getDim() != dim, std::invalid_argument)
+ << "tensor dimension mismatch for createOrExtend name: " << name;
+ NNTR_THROW_IF(t->getInitializer() != init, std::invalid_argument)
+ << "tensor initializer mismatch for createOrExtend name: " << name;
+ return extend(name, exec_order, lifespan);
+ } else {
+ return create(name, dim, exec_order, lifespan, init);
+ }
+}
+
bool TensorPool::tensorExist(const std::string &name) {
return name_map.count(name);
}
EXPECT_ANY_THROW(pool.extend("t1", {2}, max_ls));
}
+TEST(TensorPool, createOrExtend_p) {
+ nntrainer::TensorPool pool;
+ auto t1 = pool.createOrExtend("t", {10}, {0}, max_ls);
+ auto t2 = pool.createOrExtend("t", {10}, {1}, max_ls);
+
+ auto &exec_order = pool.getExecutionOrder("t");
+ EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 0),
+ exec_order.end());
+ EXPECT_NE(std::find(exec_order.begin(), exec_order.end(), 1),
+ exec_order.end());
+ EXPECT_EQ(t1, t2);
+ pool.finalize(nntrainer::BasicPlanner(), 0, 2);
+ pool.allocate();
+ EXPECT_EQ(*t1, *t2);
+ pool.deallocate();
+}
+
+TEST(TensorPool, createOrExtend_different_dim_n) {
+ nntrainer::TensorPool pool;
+ pool.createOrExtend("t", {10, 1}, {0}, max_ls);
+ EXPECT_ANY_THROW(pool.createOrExtend("t", {1, 10}, {1}, max_ls));
+}
+
+TEST(TensorPool, createOrExtend_init_n) {
+ nntrainer::TensorPool pool;
+ pool.createOrExtend("t", {10}, {0}, max_ls,
+ nntrainer::Tensor::Initializer::ONES);
+ EXPECT_ANY_THROW(pool.createOrExtend("t", {10}, {1}, max_ls,
+ nntrainer::Tensor::Initializer::ZEROS));
+}
+TEST(TensorPool, createOrExtend_unmanaged_n) {
+ nntrainer::TensorPool pool;
+ EXPECT_ANY_THROW(
+ pool.createOrExtend("t", {10}, {0}, nntrainer::TensorLifespan::UNMANAGED));
+}
+
/**
* @brief Main gtest
*/