From bee89a7ce4bb6d2d608ab490249834ffb65f7a5c Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Fri, 3 Sep 2021 13:53:29 +0900 Subject: [PATCH] [planner] Optimized v1 planner unittests This patch adds unittests for the optimized v1 planner. Signed-off-by: Parichay Kapoor --- test/unittest/memory/memory_planner_validate.cpp | 62 +++++++++++++++++++++++- test/unittest/memory/unittest_memory_planner.cpp | 4 ++ test/unittest/memory/unittest_memory_pool.cpp | 20 ++++++-- 3 files changed, 80 insertions(+), 6 deletions(-) diff --git a/test/unittest/memory/memory_planner_validate.cpp b/test/unittest/memory/memory_planner_validate.cpp index 46a5c26..5097c8e 100644 --- a/test/unittest/memory/memory_planner_validate.cpp +++ b/test/unittest/memory/memory_planner_validate.cpp @@ -17,6 +17,7 @@ #include #include +#include constexpr unsigned int MEM_BYTES = 50; constexpr unsigned int MEM_QUANT = 100; @@ -28,6 +29,8 @@ void MemoryPlannerValidate::SetUp() { planner = nullptr; if (plan_type == nntrainer::BasicPlanner::type) planner = std::make_unique(); + else if (plan_type == nntrainer::OptimizedV1Planner::type) + planner = std::make_unique(); else throw std::invalid_argument("Invalid planner type"); } @@ -69,6 +72,58 @@ static bool validateAllOverlap(const std::vector &memory_size, } /** + * @brief Validate the provided layout does partially overlap + * + * @note this test assumes that the memory validity is sorted such that start + * validity at index idx = idx, and end validity at index idx > idx. + */ +static bool validateIntervalOverlap( + const std::vector> &memory_validity, + const std::vector &memory_size, + const std::vector &memory_offset) { + std::vector valid_intervals; + for (unsigned int idx = 0; idx < memory_size.size(); idx++) { + /** + * intervals which have finished before the start of the current intervals + * must be popped + */ + auto expired_intervals = std::remove_if( + valid_intervals.begin(), valid_intervals.end(), + [memory_validity, idx](auto const &jdx) { + return memory_validity[jdx].second <= memory_validity[idx].first; + }); + valid_intervals.erase(expired_intervals, valid_intervals.end()); + + /** get the max of the existing intervals */ + size_t max_allocated = 0; + if (!valid_intervals.empty()) { + auto max_idx = *std::max_element( + valid_intervals.begin(), valid_intervals.end(), + [memory_offset, memory_size](auto const &v1, auto const &v2) { + return memory_offset[v1] + memory_size[v1] < + memory_offset[v2] + memory_size[v2]; + }); + max_allocated = memory_offset[max_idx] + memory_size[max_idx]; + } + + /** the memory planner must allocate after the max_allocated */ + EXPECT_GE(memory_offset[idx], max_allocated); + /** + * if feeling confident about the planner, then it must allocate + * from exactly the max_allocated location + * + * @note this still does not guarantee the most optimal allocation + * @note this can be disabled for some of the planners + */ + EXPECT_EQ(memory_offset[idx], max_allocated); + + valid_intervals.push_back(idx); + } + + return true; +} + +/** * @brief Validate the provided layout does not overflow outside the given * size */ @@ -171,8 +226,11 @@ TEST_P(MemoryPlannerValidate, partial_overlap) { std::accumulate(memory_size.begin(), memory_size.end(), 0u)); EXPECT_TRUE(validateNoOverlap(memory_size, memory_offset, pool_size)); } else { - EXPECT_EQ(pool_size, + EXPECT_GE(pool_size, *std::max_element(memory_size.begin(), memory_size.end())); - // EXPECT_TRUE(validateIntervalOverlap(memory_size, memory_offset)); + EXPECT_LE(pool_size, + std::accumulate(memory_size.begin(), memory_size.end(), 0)); + EXPECT_TRUE( + validateIntervalOverlap(memory_validity, memory_size, memory_offset)); } } diff --git a/test/unittest/memory/unittest_memory_planner.cpp b/test/unittest/memory/unittest_memory_planner.cpp index 32eba9b..20b1ba2 100644 --- a/test/unittest/memory/unittest_memory_planner.cpp +++ b/test/unittest/memory/unittest_memory_planner.cpp @@ -15,6 +15,10 @@ #include #include +#include INSTANTIATE_TEST_CASE_P(BasicPlanner, MemoryPlannerValidate, ::testing::Values(nntrainer::BasicPlanner::type)); + +INSTANTIATE_TEST_CASE_P(OptimizedV1Planner, MemoryPlannerValidate, + ::testing::Values(nntrainer::OptimizedV1Planner::type)); diff --git a/test/unittest/memory/unittest_memory_pool.cpp b/test/unittest/memory/unittest_memory_pool.cpp index 3c1f361..bafbe4d 100644 --- a/test/unittest/memory/unittest_memory_pool.cpp +++ b/test/unittest/memory/unittest_memory_pool.cpp @@ -480,8 +480,13 @@ TEST_P(MemoryPlannerValidate, validate_memory_no_overlap) { } EXPECT_NO_THROW(pool.planLayout(*planner.get())); - EXPECT_EQ(pool.size(), - std::accumulate(memory_size.begin(), memory_size.end(), 0u)); + if (planner->getType() == nntrainer::BasicPlanner::type) { + EXPECT_EQ(pool.size(), + std::accumulate(memory_size.begin(), memory_size.end(), 0u)); + } else { + EXPECT_EQ(pool.size(), + *std::max_element(memory_size.begin(), memory_size.end())); + } EXPECT_NO_THROW(pool.allocate()); for (unsigned int idx = 0; idx < MEM_QUANT; idx++) @@ -523,8 +528,15 @@ TEST_P(MemoryPlannerValidate, validate_memory_partial_overlap) { } EXPECT_NO_THROW(pool.planLayout(*planner.get())); - EXPECT_EQ(pool.size(), - std::accumulate(memory_size.begin(), memory_size.end(), 0u)); + if (planner->getType() == nntrainer::BasicPlanner::type) { + EXPECT_EQ(pool.size(), + std::accumulate(memory_size.begin(), memory_size.end(), 0u)); + } else { + EXPECT_GE(pool.size(), + *std::max_element(memory_size.begin(), memory_size.end())); + EXPECT_LE(pool.size(), + std::accumulate(memory_size.begin(), memory_size.end(), 0u)); + } EXPECT_NO_THROW(pool.allocate()); for (unsigned int idx = 0; idx < MEM_QUANT; idx++) -- 2.7.4