#include <basic_planner.h>
#include <memory_planner_validate.h>
+#include <optimized_v1_planner.h>
constexpr unsigned int MEM_BYTES = 50;
constexpr unsigned int MEM_QUANT = 100;
planner = nullptr;
if (plan_type == nntrainer::BasicPlanner::type)
planner = std::make_unique<nntrainer::BasicPlanner>();
+ else if (plan_type == nntrainer::OptimizedV1Planner::type)
+ planner = std::make_unique<nntrainer::OptimizedV1Planner>();
else
throw std::invalid_argument("Invalid planner type");
}
}
/**
+ * @brief Validate the provided layout does partially overlap
+ *
+ * @note this test assumes that the memory validity is sorted such that start
+ * validity at index idx = idx, and end validity at index idx > idx.
+ */
+static bool validateIntervalOverlap(
+ const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+ const std::vector<size_t> &memory_size,
+ const std::vector<size_t> &memory_offset) {
+ std::vector<unsigned int> valid_intervals;
+ for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+ /**
+ * intervals which have finished before the start of the current intervals
+ * must be popped
+ */
+ auto expired_intervals = std::remove_if(
+ valid_intervals.begin(), valid_intervals.end(),
+ [memory_validity, idx](auto const &jdx) {
+ return memory_validity[jdx].second <= memory_validity[idx].first;
+ });
+ valid_intervals.erase(expired_intervals, valid_intervals.end());
+
+ /** get the max of the existing intervals */
+ size_t max_allocated = 0;
+ if (!valid_intervals.empty()) {
+ auto max_idx = *std::max_element(
+ valid_intervals.begin(), valid_intervals.end(),
+ [memory_offset, memory_size](auto const &v1, auto const &v2) {
+ return memory_offset[v1] + memory_size[v1] <
+ memory_offset[v2] + memory_size[v2];
+ });
+ max_allocated = memory_offset[max_idx] + memory_size[max_idx];
+ }
+
+ /** the memory planner must allocate after the max_allocated */
+ EXPECT_GE(memory_offset[idx], max_allocated);
+ /**
+ * if feeling confident about the planner, then it must allocate
+ * from exactly the max_allocated location
+ *
+ * @note this still does not guarantee the most optimal allocation
+ * @note this can be disabled for some of the planners
+ */
+ EXPECT_EQ(memory_offset[idx], max_allocated);
+
+ valid_intervals.push_back(idx);
+ }
+
+ return true;
+}
+
+/**
* @brief Validate the provided layout does not overflow outside the given
* size
*/
std::accumulate(memory_size.begin(), memory_size.end(), 0u));
EXPECT_TRUE(validateNoOverlap(memory_size, memory_offset, pool_size));
} else {
- EXPECT_EQ(pool_size,
+ EXPECT_GE(pool_size,
*std::max_element(memory_size.begin(), memory_size.end()));
- // EXPECT_TRUE(validateIntervalOverlap(memory_size, memory_offset));
+ EXPECT_LE(pool_size,
+ std::accumulate(memory_size.begin(), memory_size.end(), 0));
+ EXPECT_TRUE(
+ validateIntervalOverlap(memory_validity, memory_size, memory_offset));
}
}
}
EXPECT_NO_THROW(pool.planLayout(*planner.get()));
- EXPECT_EQ(pool.size(),
- std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+ if (planner->getType() == nntrainer::BasicPlanner::type) {
+ EXPECT_EQ(pool.size(),
+ std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+ } else {
+ EXPECT_EQ(pool.size(),
+ *std::max_element(memory_size.begin(), memory_size.end()));
+ }
EXPECT_NO_THROW(pool.allocate());
for (unsigned int idx = 0; idx < MEM_QUANT; idx++)
}
EXPECT_NO_THROW(pool.planLayout(*planner.get()));
- EXPECT_EQ(pool.size(),
- std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+ if (planner->getType() == nntrainer::BasicPlanner::type) {
+ EXPECT_EQ(pool.size(),
+ std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+ } else {
+ EXPECT_GE(pool.size(),
+ *std::max_element(memory_size.begin(), memory_size.end()));
+ EXPECT_LE(pool.size(),
+ std::accumulate(memory_size.begin(), memory_size.end(), 0u));
+ }
EXPECT_NO_THROW(pool.allocate());
for (unsigned int idx = 0; idx < MEM_QUANT; idx++)