[planner] Add optimized v1 planner
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 3 Sep 2021 04:50:57 +0000 (13:50 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 5 Oct 2021 04:54:00 +0000 (13:54 +0900)
This patch add optimized v1 planner for memory sharing.
This planner assigns memory in the order of start of the validity,
and then by decreasing order of the end of the validity. This matches
the memory use pattern while training the model.
This planner is supposed to work better for training than for inference.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/tensor/basic_planner.cpp
nntrainer/tensor/meson.build
nntrainer/tensor/optimized_v1_planner.cpp [new file with mode: 0644]
nntrainer/tensor/optimized_v1_planner.h [new file with mode: 0644]

index 9b1bac7..3122549 100644 (file)
@@ -2,7 +2,7 @@
 /**
  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
  *
- * @file   basic_planner.h
+ * @file   basic_planner.cpp
  * @date   11 August 2021
  * @see    https://github.com/nnstreamer/nntrainer
  * @author Parichay Kapoor <pk.kapoor@samsung.com>
index 68fb715..deb4b84 100644 (file)
@@ -8,7 +8,8 @@ tensor_sources = [
   'weight.cpp',
   'basic_planner.cpp',
   'memory_pool.cpp',
-  'tensor_pool.cpp'
+  'tensor_pool.cpp',
+  'optimized_v1_planner.cpp'
 ]
 
 tensor_headers = [
diff --git a/nntrainer/tensor/optimized_v1_planner.cpp b/nntrainer/tensor/optimized_v1_planner.cpp
new file mode 100644 (file)
index 0000000..9b8dfb6
--- /dev/null
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   optimized_v1_planner.cpp
+ * @date   3 September 2021
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Optimized V1 Memory Planner
+ *
+ */
+
+#include <algorithm>
+#include <queue>
+#include <vector>
+
+#include <optimized_v1_planner.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Memory Request data structure clubbing all the requests
+ *
+ */
+struct MemoryRequest {
+  unsigned int start; /**< start of the validity (inclusive) */
+  unsigned int end;   /**< end of the validity (exclusive) */
+  unsigned int loc;   /**< index/location of the this request */
+  size_t size;        /**< size of the request */
+  size_t offset;      /**< offset for this request */
+
+  /**
+   * @brief Constructor for the Memory Request
+   *
+   */
+  MemoryRequest(size_t s, std::pair<unsigned int, unsigned int> valid,
+                unsigned int idx) :
+    start(valid.first),
+    end(valid.second),
+    loc(idx),
+    size(s),
+    offset(0) {}
+};
+
+/**
+ * @copydoc MemoryPlanner::planLayout(
+ * const std::vector<size_t> &memory_size,
+ * const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+ * std::vector<size_t> &memory_offset);
+ *
+ * @details The optimized v1 memory planner assigns memory to the requests whose
+ * validity starts first.
+ * The requested memories are sorted based on the ascending order of the start
+ * timestamps, and descending order using the end timestamps. The
+ * sorted memories are given increasing offset based on the memory size.
+ * At the end of each timestamp, invalid memories are freed, and offset updated
+ * for reuse. This planner allocates overlapping memory for all the required
+ * memories.
+ *
+ */
+size_t OptimizedV1Planner::planLayout(
+  const std::vector<size_t> &memory_size,
+  const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+  std::vector<size_t> &memory_offset) const {
+
+  /** create memory requests structure array for easier management */
+  std::vector<MemoryRequest> requests;
+  requests.reserve(memory_size.size());
+
+  for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+    requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
+  }
+
+  /**
+   * sort the memory requests with ascending order of start time first, and
+   * then end time
+   */
+  std::sort(requests.begin(), requests.end(),
+            [](auto const &v1, auto const &v2) -> int {
+              if (v1.start == v2.start)
+                return v1.end < v2.end;
+              return v1.start < v2.start;
+              /** TODO: try this */
+              //   if (v1.end == v2.end)
+              //     return v1.start < v2.start;
+              //   return v1.end > v2.end;
+            });
+
+  /** all the memories in use sorted by their assigned offset */
+  auto cmp = [](const MemoryRequest *v1, const MemoryRequest *v2) -> bool {
+    return v1->offset < v2->offset;
+  };
+  std::priority_queue<MemoryRequest *, std::vector<MemoryRequest *>,
+                      decltype(cmp)>
+    pq(cmp);
+
+  /** iterate over the sorted requests and start allocation of the requests */
+  size_t memory_req = 0;
+  for (auto &req : requests) {
+    /** remove expired memories and update offset */
+    while (!pq.empty() && pq.top()->end <= req.start)
+      pq.pop();
+
+    /** get the offset based on the max valid offset */
+    size_t offset = 0;
+    if (!pq.empty())
+      offset = pq.top()->offset + pq.top()->size;
+
+    /** assign offset to the new request and push to queue */
+    req.offset = offset;
+    memory_req = std::max(memory_req, req.offset + req.size);
+    pq.push(&req);
+  }
+
+  /** set the memory offset in the return array */
+  memory_offset.resize(memory_size.size());
+  for (auto const &req : requests)
+    memory_offset[req.loc] = req.offset;
+
+  return memory_req;
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/tensor/optimized_v1_planner.h b/nntrainer/tensor/optimized_v1_planner.h
new file mode 100644 (file)
index 0000000..6834f24
--- /dev/null
@@ -0,0 +1,75 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   optimzied_v1_planner.h
+ * @date   2 September 2021
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Optimized V1 Memory Planner
+ *
+ * @note This planner has been design to give reduced memory usage for training
+ * and might not perform very well for inference.
+ *
+ * @details The principle for this planner is to give memory to the requests in
+ * the order of the start of their validity.
+ * This takes advantage of the pattern that the outputs of the layer nodes
+ * allocated during forwarding are also used during backwarding as well.
+ *
+ * If two memory requests have the same start time, then the memory request with
+ * higher end is allocated first. This is to minimize the fragmentation once the
+ * memory is being freed.
+ *
+ * The assigned memories are cached, and once their validity is finished, they
+ * are freed and reused for the next allocations.
+ */
+
+#ifndef __OPTIMIZED_V1_PLANNER_H_
+#define __OPTIMIZED_V1_PLANNER_H_
+
+#include <vector>
+
+#include <memory_planner.h>
+
+namespace nntrainer {
+
+/**
+ * @class   OptimizedV1Planner
+ * @brief   Optimized V1 Memory Planner provides the optimized plan for memory
+ * layout
+ * @details optimized planner performs sharing of overlapping memory sharing
+ * upto certain extent
+ */
+class OptimizedV1Planner : public MemoryPlanner {
+public:
+  /**
+   * @brief OptimizedV1Planner destructor
+   *
+   */
+  OptimizedV1Planner() = default;
+
+  /**
+   * @copydoc MemoryPlanner::planLayout(
+   * const std::vector<size_t> &memory_size,
+   * const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+   * std::vector<size_t> &memory_offset);
+   *
+   */
+  size_t planLayout(
+    const std::vector<size_t> &memory_size,
+    const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+    std::vector<size_t> &memory_offset) const;
+
+  /**
+   * @copydoc MemoryPlanner::getType() const
+   *
+   */
+  const std::string &getType() const { return type; }
+
+  inline static const std::string type = "optimized_v1_planner";
+};
+
+} // namespace nntrainer
+
+#endif /** __OPTIMIZED_V1_PLANNER_H_ */