From d4a3f097fc5998bc8c271e80e8d1ba562129582b Mon Sep 17 00:00:00 2001
From: "jijoong.moon" <jijoong.moon@samsung.com>
Date: Tue, 27 Dec 2022 11:18:13 +0900
Subject: [PATCH] [Execution Order] Set execution order properly for Opt Variables.

This patch includes the proper assignment of execution order for
optimizer variables, e.g., M and V for the Adam optimizer. Only
applying the gradient requires the optimizer variables.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
---
 nntrainer/graph/network_graph.cpp                |   2 +-
 nntrainer/graph/network_graph.h                  |   2 +-
 nntrainer/tensor/basic_planner.cpp               |   4 +-
 nntrainer/tensor/basic_planner.h                 |   4 +-
 nntrainer/tensor/manager.cpp                     |  24 +-
 nntrainer/tensor/manager.h                       |   2 +-
 nntrainer/tensor/memory_planner.h                |   4 +-
 nntrainer/tensor/memory_pool.cpp                 |   5 +-
 nntrainer/tensor/memory_pool.h                   |   8 +-
 nntrainer/tensor/optimized_v1_planner.cpp        |  96 +------
 nntrainer/tensor/optimized_v1_planner.h          |   4 +-
 nntrainer/tensor/optimized_v2_planner.cpp        | 302 +++++++++++++++++++++++
 nntrainer/tensor/optimized_v2_planner.h          |  78 ++++++
 nntrainer/tensor/tensor_pool.cpp                 |  19 ++
 test/unittest/memory/unittest_memory_planner.cpp |   6 +-
 15 files changed, 456 insertions(+), 104 deletions(-)
 create mode 100644 nntrainer/tensor/optimized_v2_planner.cpp
 create mode 100644 nntrainer/tensor/optimized_v2_planner.h

diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 8cb956d..565c118 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -1111,7 +1111,7 @@ void NetworkGraph::requestOptimizerVariable(
       std::vector<TensorDim> dims = cb(dim);
       w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
         dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
-        Tensor::Initializer::ZEROS));
+        w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS));
     }
   }
 }
diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h
index cc22c5e..9f97446 100644
--- a/nntrainer/graph/network_graph.h
+++ b/nntrainer/graph/network_graph.h
@@ -312,7 +312,7 @@ public:
    */
  void allocateWeights() {
    tensor_manager->allocateWeights(
-      std::get<0>((*(cend() - 1))->getExecutionOrder()));
+      std::get<2>(backward_iter_end->getExecutionOrder()));
  }

  /**
diff --git a/nntrainer/tensor/basic_planner.cpp b/nntrainer/tensor/basic_planner.cpp
index c2a2da2..3bdc746 100644
--- a/nntrainer/tensor/basic_planner.cpp
+++ b/nntrainer/tensor/basic_planner.cpp
@@ -31,8 +31,8 @@ namespace nntrainer {
 size_t BasicPlanner::planLayout(
   const std::vector<size_t> &memory_size,
   const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
-  std::vector<size_t> &memory_offset,
-  std::vector<bool> &memory_is_wgrad) const {
+  std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+  size_t n_wgrad) const {

   memory_offset.resize(memory_size.size());
   size_t csum = 0;
diff --git a/nntrainer/tensor/basic_planner.h b/nntrainer/tensor/basic_planner.h
index 1fea56f..716a291 100644
--- a/nntrainer/tensor/basic_planner.h
+++ b/nntrainer/tensor/basic_planner.h
@@ -45,8 +45,8 @@ public:
   size_t planLayout(
     const std::vector<size_t> &memory_size,
     const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
-    std::vector<size_t> &memory_offset,
-    std::vector<bool> &memory_is_wgrad) const;
+    std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+    size_t n_wgrad = 0) const;

   /**
    * @copydoc MemoryPlanner::getType() const
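
The allocateWeights() change above moves the upper bound of weight lifetime
from the forwarding order of the last node to the third element of the
execution-order tuple of the node that finishes backwarding. A minimal,
self-contained sketch of the idea follows, assuming a three-element tuple
(forwarding, calcGradient, calcDerivative) with made-up order values; it is
illustrative only and not the actual NetworkGraph API:

    #include <cstdio>
    #include <tuple>

    // Hypothetical mirror of a node's execution-order tuple:
    // (forwarding, calcGradient, calcDerivative). Values are made up.
    using ExecOrder = std::tuple<unsigned int, unsigned int, unsigned int>;

    int main() {
      ExecOrder last_node = {10, 0, 0};     // last graph node, forward-only
      ExecOrder backward_end = {1, 11, 12}; // node where backwarding ends
      // Before: weights stayed alive up to std::get<0>(last_node) == 10,
      // which can be smaller than the orders where gradients are applied.
      // After: the bound is std::get<2>(backward_end) == 12.
      std::printf("old bound: %u, new bound: %u\n", std::get<0>(last_node),
                  std::get<2>(backward_end));
      return 0;
    }
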
diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index 9d52ad0..1bb14ea 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -360,6 +360,12 @@ std::vector<Weight *> Manager::requestWeights(
   std::vector<unsigned int> default_var_exec_order(
     {forwarding_order, calcDerivative_order});
+  /**
+   * TODO: This needs to be fixed. calcDerivative does not need the gradient.
+   * However, the current implementation of loss needs the gradient
+   * computation, and therefore, if we remove the calcDerivative order, tests
+   * fail.
+   */
+
   TensorLifespan var_ls = TensorLifespan::MAX_LIFESPAN;
   TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN;

@@ -384,8 +390,10 @@ std::vector<Weight *> Manager::requestWeights(
      * order with the max exec order where it will be used for clipping and then
      * applied to the weight.
      */
-    if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
+    if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) {
       grad_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
+      var_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
+    }

     Tensor *var = nullptr, *grad = nullptr;
     bool is_dependent = !shared_names.empty();
@@ -622,18 +630,26 @@ bool Manager::isSecondLastAccess(const std::string &name,
  */
 std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
   const std::vector<TensorDim> &dims, const std::string &name,
-  const TensorLifespan &lifespan, Tensor::Initializer initializer) {
+  const TensorLifespan &lifespan, bool is_grad_clip,
+  Tensor::Initializer initializer) {

   auto const exec_order = weight_pool.getExecutionOrder(name);

   std::vector<Tensor *> ret;
   ret.reserve(dims.size());

+  std::vector<unsigned int> exec;
+  exec.reserve(1);
+  if (is_grad_clip) {
+    exec.emplace_back(TensorPool::PERSIST_END_ORDER);
+  } else {
+    exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second);
+  }
+
   /// @note this is assuming weight optimizer variables is treated as weight, if
   /// not, there is room to optimize below behavior
   for (unsigned int idx = 0; idx < dims.size(); idx++)
     ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx),
-                                      dims[idx], exec_order, lifespan,
-                                      initializer));
+                                      dims[idx], exec, lifespan, initializer));

   return ret;
 }
diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h
index aa31258..eec90ff 100644
--- a/nntrainer/tensor/manager.h
+++ b/nntrainer/tensor/manager.h
@@ -214,7 +214,7 @@ public:
    */
   std::vector<Tensor *> requestWeightOptimizerVariables(
     const std::vector<TensorDim> &dims, const std::string &name,
-    const TensorLifespan &lifespan,
+    const TensorLifespan &lifespan, bool is_grad_clip,
     Tensor::Initializer initializer = Tensor::Initializer::NONE);

   /**
diff --git a/nntrainer/tensor/memory_planner.h b/nntrainer/tensor/memory_planner.h
index 22532f0..033050f 100644
--- a/nntrainer/tensor/memory_planner.h
+++ b/nntrainer/tensor/memory_planner.h
@@ -47,8 +47,8 @@ public:
   virtual size_t planLayout(
     const std::vector<size_t> &memory_size,
     const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
-    std::vector<size_t> &memory_offset,
-    std::vector<bool> &memory_is_wgrad) const = 0;
+    std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+    size_t n_wgrad) const = 0;

   /**
    * @brief Get type of the planner
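
The core of the manager change is that each optimizer variable now gets
exactly one execution order: the PERSIST_END_ORDER sentinel when the owning
weight is clipped by global norm (its gradient is applied after the
global-norm pass), and otherwise the last access of the weight's gradient.
A condensed sketch of that decision; the sentinel value and the min/max pair
are stand-ins for TensorPool::PERSIST_END_ORDER and
getMinMaxTensorExecutionOrder(), so only the shape of the logic is real:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Stand-in for TensorPool::PERSIST_END_ORDER (the real value differs).
    constexpr unsigned int PERSIST_END_ORDER = UINT32_MAX - 1;

    // Condensed form of the logic added to requestWeightOptimizerVariables():
    // clipped weights need their optimizer variables (e.g. Adam M/V) alive at
    // the deferred apply-gradient step; otherwise the gradient's last access
    // is enough.
    std::vector<unsigned int>
    optVarExecOrder(bool is_grad_clip,
                    std::pair<unsigned int, unsigned int> grad_min_max_order) {
      std::vector<unsigned int> exec;
      exec.reserve(1);
      if (is_grad_clip) {
        exec.emplace_back(PERSIST_END_ORDER); // applied after global-norm pass
      } else {
        exec.emplace_back(grad_min_max_order.second); // last gradient access
      }
      return exec;
    }

    int main() {
      auto clipped = optVarExecOrder(true, {3, 7});
      auto plain = optVarExecOrder(false, {3, 7});
      return clipped[0] != PERSIST_END_ORDER || plain[0] != 7;
    }
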
diff --git a/nntrainer/tensor/memory_pool.cpp b/nntrainer/tensor/memory_pool.cpp
index 1193cb2..457ff14 100644
--- a/nntrainer/tensor/memory_pool.cpp
+++ b/nntrainer/tensor/memory_pool.cpp
@@ -44,6 +44,8 @@ unsigned int MemoryPool::requestMemory(size_t bytes, unsigned int start_time,
   memory_validity.push_back({start_time, end_time});
   memory_exec_order.push_back(exec_order);
   memory_is_wgrad.push_back(is_wgrad);
+  if (is_wgrad)
+    n_wgrad++;

   /** invalidate min_pool_size if already there */
   min_pool_size = 0;
@@ -75,7 +77,7 @@ double MemoryPool::planLayout(const MemoryPlanner &planner) {
   min_pool_size = calcMinMemoryRequirement();

   pool_size = planner.planLayout(memory_size, memory_validity, memory_offset,
-                                 memory_is_wgrad);
+                                 memory_is_wgrad, n_wgrad);

   if (pool_size < min_pool_size || !validateLayout())
     throw std::runtime_error("Planned layout is not feasible");
@@ -323,6 +325,7 @@ void MemoryPool::clear() {

   pool_size = 0;
   min_pool_size = 0;
+  n_wgrad = 0;
 }

 /**
diff --git a/nntrainer/tensor/memory_pool.h b/nntrainer/tensor/memory_pool.h
index 6c209d7..1d409c5 100644
--- a/nntrainer/tensor/memory_pool.h
+++ b/nntrainer/tensor/memory_pool.h
@@ -39,7 +39,11 @@ public:
    * @brief MemoryPool default constructor
    *
    */
-  explicit MemoryPool() : mem_pool(nullptr), pool_size(0), min_pool_size(0) {}
+  explicit MemoryPool() :
+    mem_pool(nullptr),
+    pool_size(0),
+    min_pool_size(0),
+    n_wgrad(0) {}

   /**
    * @brief MemoryPool destructor
@@ -209,6 +213,8 @@ private:

   size_t pool_size;     /**< memory requirement for this pool */
   size_t min_pool_size; /**< minimum theoretical memory requirement */
+
+  size_t n_wgrad; /**< number of weight-gradient memory requests */
 };

 } // namespace nntrainer
diff --git a/nntrainer/tensor/optimized_v1_planner.cpp b/nntrainer/tensor/optimized_v1_planner.cpp
index a0166f8..d9ad486 100644
--- a/nntrainer/tensor/optimized_v1_planner.cpp
+++ b/nntrainer/tensor/optimized_v1_planner.cpp
@@ -42,18 +42,8 @@ struct MemoryRequest {
     end(valid.second),
     loc(idx),
     size(s),
-    offset(0) {}
-};
-
-/**
- * @brief Memory Request data structure clubbing for the weight gradient
- * requests
- */
-struct WGradMemoryRequest {
-  MemoryRequest *mem_req;
-  std::vector<std::pair<unsigned int, unsigned int>> start_end;
-
-  WGradMemoryRequest(MemoryRequest *req) : mem_req(req) {}
+    offset(0),
+    size_from_offset(0) {}
 };

 /**
@@ -144,20 +134,10 @@ size_t OptimizedV1Planner::planLayout(

   /** create memory requests structure array for easier management */
   std::vector<MemoryRequest> requests;
-  requests.reserve(memory_size.size() - n_wgrad);
-  if (n_wgrad) {
-    for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
-      if (!memory_is_wgrad[idx]) {
-        requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
-      } else {
-        wgrad_requests.emplace_back(memory_size[idx], memory_validity[idx],
-                                    idx);
-      }
-    }
-  } else {
-    for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
-      requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
-    }
+  requests.reserve(memory_size.size());
+
+  for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+    requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
   }

   /**
@@ -213,70 +193,18 @@ size_t OptimizedV1Planner::planLayout(
     memory_offset[req.loc] = offset;
     memory_req = std::max(memory_req, req.offset + req.size);
     sorted_req.push_back(&req);
-  }
-
-  if (wgrad_requests.size()) {
-    /** TODO: We donot need to start from memory_req. We might find proper
-     * offset considering execution order */
-    size_t last_offset = memory_req;
-
-    /* sort the memory request with ascending order of size */
-    std::sort(
-      wgrad_requests.begin(), wgrad_requests.end(),
-      [](auto const &v1, auto const &v2) -> int { return v1.size > v2.size; });
-
-    std::vector<WGradMemoryRequest> wgrad_sorted_req;
-
-    bool replace_and_fill = false;
-    for (auto &req : wgrad_requests) {
-      for (unsigned int idx = 0; idx < wgrad_sorted_req.size(); idx++) {
-        auto const sr = wgrad_sorted_req[idx];
-        bool merge = true;
-        if (sr.mem_req->size >= req.size) {
-          for (auto &interval : sr.start_end) {
-            if ((interval.first < req.start && interval.first < req.end &&
-                 req.end < interval.second) ||
-                (req.start > interval.first && req.start < interval.second &&
-                 req.end > interval.second) ||
-                (req.start == interval.first && req.end == interval.second)) {
-              merge = false;
-              break;
-            }
-          }
-        }
-
-        if (merge) {
-          req.offset = sr.mem_req->offset;
-          memory_offset[req.loc] = req.offset;
-          replace_and_fill = true;
-          wgrad_sorted_req[idx].start_end.push_back(
-            std::make_pair(req.start, req.end));
-          break;
-        } else {
-          replace_and_fill = false;
-        }
-      }
-      if (replace_and_fill) {
-        continue;
-      }
-
-      size_t offset = last_offset;
-      if (!wgrad_sorted_req.empty())
-        offset = wgrad_sorted_req.back().mem_req->offset +
-                 wgrad_sorted_req.back().mem_req->size;
-
-      req.offset = offset;
-      memory_offset[req.loc] = offset;
-      memory_req = std::max(memory_req, req.offset + req.size);
-      wgrad_sorted_req.push_back(WGradMemoryRequest(&req));
-      wgrad_sorted_req.back().start_end.push_back(
-        std::make_pair(req.start, req.end));
+#ifdef DEBUG
+    if (n_wgrad && memory_is_wgrad[req.loc]) {
+      new_grad_cnt++;
+      new_grad_size += req.size;
     }
+#endif
   }

   // validateIntervalOverlap(memory_validity, memory_size, memory_offset,
   //                         memory_req);
-
   return memory_req;
 }
diff --git a/nntrainer/tensor/optimized_v1_planner.h b/nntrainer/tensor/optimized_v1_planner.h
index 9e44616..0158701 100644
--- a/nntrainer/tensor/optimized_v1_planner.h
+++ b/nntrainer/tensor/optimized_v1_planner.h
@@ -60,8 +60,8 @@ public:
   size_t planLayout(
     const std::vector<size_t> &memory_size,
     const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
-    std::vector<size_t> &memory_offset,
-    std::vector<bool> &memory_is_wgrad) const;
+    std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+    size_t n_wgrad = 0) const;

   /**
    * @copydoc MemoryPlanner::getType() const
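
Before moving on to the new planner, here is a hedged usage sketch of the
extended interface: callers now pass the per-request wgrad flags together
with the wgrad count, exactly as MemoryPool::planLayout() forwards its
n_wgrad counter. The sizes and validities below are made up, and the sketch
assumes BasicPlanner is default-constructible and that the nntrainer tensor
headers are on the include path:

    #include <cstdio>
    #include <utility>
    #include <vector>

    #include <basic_planner.h> // any MemoryPlanner implementation would do

    int main() {
      std::vector<size_t> size = {1024, 512, 512};
      std::vector<std::pair<unsigned int, unsigned int>> validity = {
        {1, 5}, {2, 3}, {4, 6}};
      std::vector<size_t> offset;
      // the last two requests are weight gradients; n_wgrad must agree
      std::vector<bool> is_wgrad = {false, true, true};

      size_t pool = nntrainer::BasicPlanner().planLayout(
        size, validity, offset, is_wgrad, /*n_wgrad=*/2);
      std::printf("planned pool size: %zu\n", pool);
      return 0;
    }
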
diff --git a/nntrainer/tensor/optimized_v2_planner.cpp b/nntrainer/tensor/optimized_v2_planner.cpp
new file mode 100644
index 0000000..e686f14
--- /dev/null
+++ b/nntrainer/tensor/optimized_v2_planner.cpp
@@ -0,0 +1,302 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 Jijoong Moon <jijoong.moon@samsung.com>
+ *
+ * @file   optimized_v2_planner.cpp
+ * @date   29 December 2022
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Optimized V2 Memory Planner
+ *
+ */
+
+#include <algorithm>
+#include <memory>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <stdexcept>
+#include <vector>
+
+#include <optimized_v2_planner.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Memory Request data structure clubbing all the requests
+ *
+ */
+struct MemoryRequest {
+  unsigned int start; /**< start of the validity (inclusive) */
+  unsigned int end;   /**< end of the validity (exclusive) */
+  unsigned int loc;   /**< index/location of this request */
+  size_t size;        /**< size of the request */
+  size_t offset;      /**< offset for this request */
+
+  /**
+   * @brief Constructor for the Memory Request
+   *
+   */
+  MemoryRequest(size_t s, const std::pair<unsigned int, unsigned int> &valid,
+                unsigned int idx) :
+    start(valid.first),
+    end(valid.second),
+    loc(idx),
+    size(s),
+    offset(0) {}
+};
+
+/**
+ * @brief Memory Request data structure clubbing for the weight gradient
+ * requests
+ */
+struct WGradMemoryRequest {
+  MemoryRequest *mem_req;
+  std::vector<std::pair<unsigned int, unsigned int>> start_end;
+
+  WGradMemoryRequest(MemoryRequest *req) : mem_req(req) {}
+};
+
+/**
+ * @brief check if validate interval is overlapping in a very naive way.
+ *
+ * @param memory_validity validity
+ * @param memory_size size
+ * @param memory_offset offset
+ * @param memory_req request
+ */
+[[maybe_unused]] static void validateIntervalOverlap(
+  const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+  const std::vector<size_t> &memory_size,
+  const std::vector<size_t> &memory_offset, size_t memory_req) {
+  auto bits = std::make_unique<bool[]>(memory_req);
+
+  for (size_t i = 0; i < memory_req; ++i) {
+    bits[i] = 0;
+  }
+
+  auto exec_start =
+    std::min_element(memory_validity.begin(), memory_validity.end(),
+                     [](auto &a, auto &b) { return a.first < b.first; });
+
+  auto exec_end =
+    std::max_element(memory_validity.begin(), memory_validity.end(),
+                     [](auto &a, auto &b) { return a.second < b.second; });
+
+  auto set = [&](int offset, size_t size, int idx) {
+    for (unsigned int i = offset; i < size; ++i) {
+      NNTR_THROW_IF(bits[i], std::invalid_argument)
+        << " bits taken at i: " << i << " offset: " << offset
+        << " size: " << size << " idx: " << idx;
+      bits[i] = 1;
+    }
+  };
+
+  auto unset = [&](int offset, size_t size, int idx) {
+    for (unsigned int i = offset; i < size; ++i) {
+      NNTR_THROW_IF(!bits[i], std::invalid_argument)
+        << "double freeing bits at i: " << i << " offset: " << offset
+        << " size: " << size << " idx: " << idx;
+      bits[i] = 0;
+    }
+  };
+
+  for (unsigned int exec = exec_start->first; exec <= exec_end->second;
+       ++exec) {
+
+    for (unsigned int idx = 0; idx < memory_validity.size(); ++idx) {
+      auto &validity = memory_validity.at(idx);
+      auto &sz = memory_size.at(idx);
+      auto &offset = memory_offset.at(idx);
+      if (validity.first == exec) {
+        set(offset, sz, idx);
+      }
+      if (validity.second == exec) {
+        unset(offset, sz, idx);
+      }
+    }
+  }
+  // check if there is any dangling memory
+  set(0, memory_req, memory_validity.size());
+}
+
+/**
+ * @copydoc MemoryPlanner::planLayout(
+ * const std::vector<size_t> &memory_size,
+ * const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+ * std::vector<size_t> &memory_offset,
+ * std::vector<bool> &memory_is_wgrad);
+ *
+ * @details The optimized v2 memory planner assigns memory to the requests
+ * whose validity starts first.
+ * The requested memories are sorted based on the ascending order of the start
+ * timestamps, and descending order using the end timestamps. The sorted
+ * memories are given increasing offset based on the memory size.
+ * At the end of each timestamp, invalid memories are freed, and the offset is
+ * updated for reuse. This planner allocates overlapping memory for all the
+ * required memories.
+ *
+ */
+size_t OptimizedV2Planner::planLayout(
+  const std::vector<size_t> &memory_size,
+  const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+  std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+  size_t n_wgrad) const {
+
+  std::vector<MemoryRequest> wgrad_requests;
+  wgrad_requests.reserve(n_wgrad);
+
+  /** create memory requests structure array for easier management */
+  std::vector<MemoryRequest> requests;
+  requests.reserve(memory_size.size() - n_wgrad);
+  if (n_wgrad) {
+    for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+      if (!memory_is_wgrad[idx]) {
+        requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
+      } else {
+        wgrad_requests.emplace_back(memory_size[idx], memory_validity[idx],
+                                    idx);
+      }
+    }
+  } else {
+    for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+      requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
+    }
+  }
+
+  /**
+   * sort the memory requests with ascending order of start time first, and
+   * then end time
+   */
+  std::sort(requests.begin(), requests.end(),
+            [](auto const &v1, auto const &v2) -> int {
+              if (v1.start == v2.start)
+                return v1.end < v2.end;
+              return v1.start < v2.start;
+              /** TODO: try this */
+              // if (v1.end == v2.end)
+              //   return v1.start < v2.start;
+              // return v1.end > v2.end;
+            });
+
+  /** all the memories in use sorted by their assigned offset and size */
+  std::vector<MemoryRequest *> sorted_req;
+
+  /** iterate over the sorted requests and start allocation of the requests */
+  memory_offset.resize(memory_size.size());
+  size_t memory_req = 0;
+  for (auto &req : requests) {
+    /** remove expired memories and update offset */
+    while (!sorted_req.empty() && sorted_req.back()->end <= req.start)
+      sorted_req.pop_back();
+
+    /** if there exists an expired memory with same size (not at the edge),
+     * reuse it */
+    bool replace_and_fill = false;
+    for (int idx = sorted_req.size() - 1; idx >= 0; idx--) {
+      auto const &sr = sorted_req[idx];
+      /** TODO: reuse if memory size not exactly match */
+      if (sr->end <= req.start && sr->size == req.size) {
+        req.offset = sr->offset;
+        memory_offset[req.loc] = req.offset;
+        sorted_req[idx] = &req;
+        replace_and_fill = true;
+        break;
+      }
+    }
+    if (replace_and_fill) {
+      continue;
+    }
+
+    size_t offset = 0;
+    if (!sorted_req.empty())
+      offset = sorted_req.back()->offset + sorted_req.back()->size;
+
+    /** assign offset to the new request and push to queue */
+    req.offset = offset;
+    memory_offset[req.loc] = offset;
+    memory_req = std::max(memory_req, req.offset + req.size);
+    sorted_req.push_back(&req);
+  }
+
+  if (wgrad_requests.size()) {
+    /** TODO: We do not need to start from memory_req. We might find a proper
+     * offset considering execution order */
+    size_t last_offset = memory_req;
+
+    /* sort the memory request with ascending order of size */
+    std::sort(
+      wgrad_requests.begin(), wgrad_requests.end(),
+      [](auto const &v1, auto const &v2) -> int { return v1.size > v2.size; });
+
+    std::vector<WGradMemoryRequest> wgrad_sorted_req;
+
+    bool replace_and_fill = false;
+    unsigned int new_grad_cnt = 0;
+    unsigned int reused_grad_cnt = 0;
+    size_t new_grad_size = 0;
+    size_t reused_grad_size = 0;
+    for (auto &req : wgrad_requests) {
+      for (unsigned int idx = 0; idx < wgrad_sorted_req.size(); idx++) {
+        auto const sr = wgrad_sorted_req[idx];
+        bool merge = true;
+        if (sr.mem_req->size >= req.size) {
+          for (auto &interval : sr.start_end) {
+            if ((interval.first < req.start && interval.first < req.end &&
+                 req.end < interval.second) ||
+                (req.start > interval.first && req.start < interval.second &&
+                 req.end > interval.second) ||
+                (req.start == interval.first && req.end == interval.second)) {
+              merge = false;
+              break;
+            }
+          }
+        }
+
+        if (merge) {
+          req.offset = sr.mem_req->offset;
+          memory_offset[req.loc] = req.offset;
+          replace_and_fill = true;
+          wgrad_sorted_req[idx].start_end.push_back(
+            std::make_pair(req.start, req.end));
+          reused_grad_size += req.size;
+          reused_grad_cnt++;
+          break;
+        } else {
+          replace_and_fill = false;
+        }
+      }
+      if (replace_and_fill) {
+        continue;
+      }
+
+      size_t offset = last_offset;
+      if (!wgrad_sorted_req.empty())
+        offset = wgrad_sorted_req.back().mem_req->offset +
+                 wgrad_sorted_req.back().mem_req->size;
+
+      req.offset = offset;
+      memory_offset[req.loc] = offset;
+      memory_req = std::max(memory_req, req.offset + req.size);
+      wgrad_sorted_req.push_back(WGradMemoryRequest(&req));
+      wgrad_sorted_req.back().start_end.push_back(
+        std::make_pair(req.start, req.end));
+      new_grad_cnt++;
+      new_grad_size += req.size;
+    }
+
+    ml_logd("Total Requested Memory(OPTV2): %lf MiB >>>>>>>> \n - new mem for "
+            "gradient = %d (%lf MiB) & reused mem for gradient = %d (%lf MiB)\n",
+            memory_req / (1024.0 * 1024.0), new_grad_cnt,
+            new_grad_size / (1024.0 * 1024.0), reused_grad_cnt,
+            reused_grad_size / (1024.0 * 1024.0));
+  }
+
+  // validateIntervalOverlap(memory_validity, memory_size, memory_offset,
+  //                         memory_req);
+
+  return memory_req;
+}
+
+} // namespace nntrainer
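
The wgrad branch above packs all weight-gradient buffers into one region past
last_offset and lets several gradients share an offset when their validity
intervals do not collide. The merge test in the hunk enumerates three overlap
shapes; the standalone sketch below shows the equivalent (and slightly
stricter) closed-open interval predicate it approximates, with made-up
intervals:

    #include <cstdio>
    #include <utility>
    #include <vector>

    using Interval = std::pair<unsigned int, unsigned int>; // [start, end)

    // True when [req.first, req.second) overlaps any interval already mapped
    // to the candidate offset. The standard predicate: a < d && c < b.
    static bool overlapsAny(const Interval &req,
                            const std::vector<Interval> &occupied) {
      for (auto const &iv : occupied) {
        if (req.first < iv.second && iv.first < req.second)
          return true;
      }
      return false;
    }

    int main() {
      std::vector<Interval> slot = {{0, 4}, {6, 9}};
      std::printf("[4,6): %s\n",
                  overlapsAny({4, 6}, slot) ? "collides" : "reusable");
      std::printf("[3,7): %s\n",
                  overlapsAny({3, 7}, slot) ? "collides" : "reusable");
      return 0;
    }
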
diff --git a/nntrainer/tensor/optimized_v2_planner.h b/nntrainer/tensor/optimized_v2_planner.h
new file mode 100644
index 0000000..c1de951
--- /dev/null
+++ b/nntrainer/tensor/optimized_v2_planner.h
@@ -0,0 +1,78 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2022 Jijoong.Moon <jijoong.moon@samsung.com>
+ *
+ * @file   optimized_v2_planner.h
+ * @date   29 December 2022
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Optimized V2 Memory Planner
+ *
+ * @note This planner has been designed to give reduced memory usage for
+ * training and might not perform very well for inference.
+ *
+ * @details The principle for this planner is to give memory to the requests
+ * in the order of the start of their validity.
+ * This takes advantage of the pattern that the outputs of the layer nodes
+ * allocated during forwarding are also used during backwarding.
+ *
+ * If two memory requests have the same start time, the request with the
+ * higher end time is allocated first. This is to minimize fragmentation once
+ * the memory is freed.
+ *
+ * The assigned memories are cached, and once their validity is finished, they
+ * are freed and reused for the next allocations.
+ */
+
+#ifndef __OPTIMIZED_V2_PLANNER_H_
+#define __OPTIMIZED_V2_PLANNER_H_
+
+#include <vector>
+
+#include <memory_planner.h>
+
+namespace nntrainer {
+
+/**
+ * @class   OptimizedV2Planner
+ * @brief   Optimized V2 Memory Planner provides the optimized plan for memory
+ * layout
+ * @details optimized planner performs sharing of overlapping memory up to a
+ * certain extent
+ */
+class OptimizedV2Planner : public MemoryPlanner {
+public:
+  /**
+   * @brief OptimizedV2Planner constructor
+   *
+   */
+  OptimizedV2Planner() = default;
+
+  /**
+   * @copydoc MemoryPlanner::planLayout(
+   * const std::vector<size_t> &memory_size,
+   * const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+   * std::vector<size_t> &memory_offset,
+   * std::vector<bool> &memory_is_wgrad);
+   *
+   */
+  size_t planLayout(
+    const std::vector<size_t> &memory_size,
+    const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+    std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+    size_t n_wgrad = 0) const;
+
+  /**
+   * @copydoc MemoryPlanner::getType() const
+   *
+   */
+  const std::string &getType() const { return type; }
+
+  inline static const std::string type = "optimized_v2_planner";
+};
+
+} // namespace nntrainer
+
+#endif /** __OPTIMIZED_V2_PLANNER_H_ */
diff --git a/nntrainer/tensor/tensor_pool.cpp b/nntrainer/tensor/tensor_pool.cpp
index 9fb9b0f..1447e5b 100644
--- a/nntrainer/tensor/tensor_pool.cpp
+++ b/nntrainer/tensor/tensor_pool.cpp
@@ -109,6 +109,11 @@ void TensorPool::finalize(const MemoryPlanner &planner,
                           unsigned int start_order, unsigned int end_order) {
   mem_pool->clear();
   unsigned int bytes_requested = 0;
+  /** if execution order is PERSIST_END_ORDER, then we think it has another
+   * execution order for gradient clipping.
+   * persist_end_order is for checking if the end order is updated */
+  bool persist_end_order = false;
+  unsigned int old_end_order = end_order;
   for (auto &spec : pool) {
     auto details = std::get_if<SourceDetails>(&spec.details);
     if (!details || details->lifespan == TensorLifespan::UNMANAGED ||
@@ -127,12 +132,26 @@ void TensorPool::finalize(const MemoryPlanner &planner,
     for (unsigned int idx = 0; idx < details->exec_order.size(); idx++) {
       if (details->exec_order[idx] >= start_order)
         validity_start = std::min(validity_start, details->exec_order[idx]);
+      /** This enforces that an execution order greater than the backwarding
+       * end order is not reached.
+       * e.g., for the input layer, backwarding is not reached, but an
+       * execution order is still assigned.
+       * */
+      if (details->exec_order[idx] > old_end_order &&
+          details->exec_order[idx] != PERSIST_END_ORDER) {
+        details->exec_order[idx] = PERSIST_END_ORDER - 1;
+      }
     }

     unsigned int validity_end = validity_start;
     for (unsigned int idx = 0; idx < details->exec_order.size(); idx++) {
       if (details->exec_order[idx] == PERSIST_END_ORDER) {
+        if (!persist_end_order) {
+          end_order = end_order + 1;
+          persist_end_order = true;
+        }
         validity_end = end_order;
+        details->exec_order[idx] = validity_end;
         break;
       }
diff --git a/test/unittest/memory/unittest_memory_planner.cpp b/test/unittest/memory/unittest_memory_planner.cpp
index a571703..69391a7 100644
--- a/test/unittest/memory/unittest_memory_planner.cpp
+++ b/test/unittest/memory/unittest_memory_planner.cpp
@@ -168,7 +168,7 @@ TEST_P(MemoryPlannerValidate, full_overlap) {
   std::vector<bool> memory_is_wgrad;

   size_t pool_size = planner->planLayout(memory_size, memory_validity,
-                                         memory_offset, memory_is_wgrad);
+                                         memory_offset, memory_is_wgrad, 0);

   EXPECT_EQ(pool_size,
             std::accumulate(memory_size.begin(), memory_size.end(), 0u));
@@ -196,7 +196,7 @@ TEST_P(MemoryPlannerValidate, none_overlap) {
   std::vector<size_t> memory_offset;
   std::vector<bool> memory_is_wgrad;
   size_t pool_size = planner->planLayout(memory_size, memory_validity,
-                                         memory_offset, memory_is_wgrad);
+                                         memory_offset, memory_is_wgrad, 0);

   EXPECT_TRUE(validateOverflow(memory_size, memory_offset, pool_size));
   if (planner->getType() == nntrainer::BasicPlanner::type) {
@@ -233,7 +233,7 @@ TEST_P(MemoryPlannerValidate, partial_overlap) {
   std::vector<size_t> memory_offset;
   std::vector<bool> memory_is_wgrad;
   size_t pool_size = planner->planLayout(memory_size, memory_validity,
-                                         memory_offset, memory_is_wgrad);
+                                         memory_offset, memory_is_wgrad, 0);

   EXPECT_TRUE(validateOverflow(memory_size, memory_offset, pool_size));
   if (planner->getType() == nntrainer::BasicPlanner::type) {
-- 
2.7.4
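
To close, the tensor_pool.cpp change can be read as a small remapping pass:
orders beyond the backward end (e.g., input layers that are never backwarded)
are clamped below the sentinel, and the first PERSIST_END_ORDER encountered
extends the pool's end order by one and is rewritten to that concrete value.
A per-call simplification follows; in the real TensorPool::finalize the
persist_end_order flag is shared across all pool entries, and the sentinel
value here is hypothetical:

    #include <cstdio>
    #include <vector>

    constexpr unsigned int PERSIST_END_ORDER = 1u << 30; // hypothetical sentinel

    // Clamp out-of-range orders, then materialize the sentinel as
    // end_order + 1. Returns the (possibly extended) end order.
    unsigned int remapExecOrders(std::vector<unsigned int> &exec,
                                 unsigned int end_order) {
      const unsigned int old_end_order = end_order;
      bool persist_end_order = false;
      for (auto &e : exec)
        if (e > old_end_order && e != PERSIST_END_ORDER)
          e = PERSIST_END_ORDER - 1;
      for (auto &e : exec) {
        if (e == PERSIST_END_ORDER) {
          if (!persist_end_order) {
            end_order = end_order + 1;
            persist_end_order = true;
          }
          e = end_order;
          break;
        }
      }
      return end_order;
    }

    int main() {
      std::vector<unsigned int> exec = {3, PERSIST_END_ORDER};
      unsigned int end = remapExecOrders(exec, 10);
      std::printf("end_order: %u, remapped: %u\n", end, exec[1]); // 11, 11
      return 0;
    }
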