[planner] Optimized v1 planner unittests
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 3 Sep 2021 04:53:29 +0000 (13:53 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 5 Oct 2021 04:54:00 +0000 (13:54 +0900)
This patch adds unittests for the optimized v1 planner.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
test/unittest/memory/memory_planner_validate.cpp
test/unittest/memory/unittest_memory_planner.cpp
test/unittest/memory/unittest_memory_pool.cpp

index 46a5c26..5097c8e 100644 (file)
@@ -17,6 +17,7 @@
 
 #include <basic_planner.h>
 #include <memory_planner_validate.h>
+#include <optimized_v1_planner.h>
 
 constexpr unsigned int MEM_BYTES = 50;
 constexpr unsigned int MEM_QUANT = 100;
@@ -28,6 +29,8 @@ void MemoryPlannerValidate::SetUp() {
   planner = nullptr;
   if (plan_type == nntrainer::BasicPlanner::type)
     planner = std::make_unique<nntrainer::BasicPlanner>();
+  else if (plan_type == nntrainer::OptimizedV1Planner::type)
+    planner = std::make_unique<nntrainer::OptimizedV1Planner>();
   else
     throw std::invalid_argument("Invalid planner type");
 }
@@ -69,6 +72,58 @@ static bool validateAllOverlap(const std::vector<size_t> &memory_size,
 }
 
 /**
+ * @brief Validate the provided layout does partially overlap
+ *
+ * @note this test assumes that the memory validity is sorted such that start
+ * validity at index idx = idx, and end validity at index idx > idx.
+ */
+static bool validateIntervalOverlap(
+  const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+  const std::vector<size_t> &memory_size,
+  const std::vector<size_t> &memory_offset) {
+  std::vector<unsigned int> valid_intervals;
+  for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+    /**
+     * intervals which have finished before the start of the current intervals
+     * must be popped
+     */
+    auto expired_intervals = std::remove_if(
+      valid_intervals.begin(), valid_intervals.end(),
+      [memory_validity, idx](auto const &jdx) {
+        return memory_validity[jdx].second <= memory_validity[idx].first;
+      });
+    valid_intervals.erase(expired_intervals, valid_intervals.end());
+
+    /** get the max of the existing intervals */
+    size_t max_allocated = 0;
+    if (!valid_intervals.empty()) {
+      auto max_idx = *std::max_element(
+        valid_intervals.begin(), valid_intervals.end(),
+        [memory_offset, memory_size](auto const &v1, auto const &v2) {
+          return memory_offset[v1] + memory_size[v1] <
+                 memory_offset[v2] + memory_size[v2];
+        });
+      max_allocated = memory_offset[max_idx] + memory_size[max_idx];
+    }
+
+    /** the memory planner must allocate after the max_allocated */
+    EXPECT_GE(memory_offset[idx], max_allocated);
+    /**
+     * if feeling confident about the planner, then it must allocate
+     * from exactly the max_allocated location
+     *
+     * @note this still does not guarantee the most optimal allocation
+     * @note this can be disabled for some of the planners
+     */
+    EXPECT_EQ(memory_offset[idx], max_allocated);
+
+    valid_intervals.push_back(idx);
+  }
+
+  return true;
+}
+
+/**
  * @brief Validate the provided layout does not overflow outside the given
  * size
  */
@@ -171,8 +226,11 @@ TEST_P(MemoryPlannerValidate, partial_overlap) {
               std::accumulate(memory_size.begin(), memory_size.end(), 0u));
     EXPECT_TRUE(validateNoOverlap(memory_size, memory_offset, pool_size));
   } else {
-    EXPECT_EQ(pool_size,
+    EXPECT_GE(pool_size,
               *std::max_element(memory_size.begin(), memory_size.end()));
-    // EXPECT_TRUE(validateIntervalOverlap(memory_size, memory_offset));
+    EXPECT_LE(pool_size,
+              std::accumulate(memory_size.begin(), memory_size.end(), 0));
+    EXPECT_TRUE(
+      validateIntervalOverlap(memory_validity, memory_size, memory_offset));
   }
 }
index 32eba9b..20b1ba2 100644 (file)
 #include <memory_planner_validate.h>
 
 #include <basic_planner.h>
+#include <optimized_v1_planner.h>
 
 INSTANTIATE_TEST_CASE_P(BasicPlanner, MemoryPlannerValidate,
                         ::testing::Values(nntrainer::BasicPlanner::type));
+
+INSTANTIATE_TEST_CASE_P(OptimizedV1Planner, MemoryPlannerValidate,
+                        ::testing::Values(nntrainer::OptimizedV1Planner::type));
index 3c1f361..bafbe4d 100644 (file)
@@ -480,8 +480,13 @@ TEST_P(MemoryPlannerValidate, validate_memory_no_overlap) {
   }
 
   EXPECT_NO_THROW(pool.planLayout(*planner.get()));
-  EXPECT_EQ(pool.size(),
-            std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+  if (planner->getType() == nntrainer::BasicPlanner::type) {
+    EXPECT_EQ(pool.size(),
+              std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+  } else {
+    EXPECT_EQ(pool.size(),
+              *std::max_element(memory_size.begin(), memory_size.end()));
+  }
   EXPECT_NO_THROW(pool.allocate());
 
   for (unsigned int idx = 0; idx < MEM_QUANT; idx++)
@@ -523,8 +528,15 @@ TEST_P(MemoryPlannerValidate, validate_memory_partial_overlap) {
   }
 
   EXPECT_NO_THROW(pool.planLayout(*planner.get()));
-  EXPECT_EQ(pool.size(),
-            std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+  if (planner->getType() == nntrainer::BasicPlanner::type) {
+    EXPECT_EQ(pool.size(),
+              std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+  } else {
+    EXPECT_GE(pool.size(),
+              *std::max_element(memory_size.begin(), memory_size.end()));
+    EXPECT_LE(pool.size(),
+              std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+  }
   EXPECT_NO_THROW(pool.allocate());
 
   for (unsigned int idx = 0; idx < MEM_QUANT; idx++)