From: jijoong.moon
Date: Thu, 29 Dec 2022 11:16:07 +0000 (+0900)
Subject: [Memory Planner] Update the Memory Planner
X-Git-Tag: accepted/tizen/unified/20230425.130129~21
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fc8bb08c2e913e964dfac47abbfc0f32b4ddd514;p=platform%2Fcore%2Fml%2Fnntrainer.git

[Memory Planner] Update the Memory Planner

This PR includes:

1. Assigning the right execution order depending on the layer type.
   (These fixes should eventually move into each layer.)
2. Updating the memory planner so that it uses less memory.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon
---

diff --git a/nntrainer/graph/graph_node.h b/nntrainer/graph/graph_node.h
index 22cee1f..4ca150a 100644
--- a/nntrainer/graph/graph_node.h
+++ b/nntrainer/graph/graph_node.h
@@ -39,7 +39,8 @@ public:
    * This ensures that the operations are executed in the order of their
    * listing.
    */
-  typedef std::tuple<unsigned int, unsigned int, unsigned int> ExecutionOrder;
+  typedef std::tuple<unsigned int, unsigned int, unsigned int, unsigned int>
+    ExecutionOrder;
 
   /**
    * @brief Destructor of Layer Class
@@ -71,6 +72,13 @@ public:
   virtual const std::string getType() const = 0;
 
   /**
+   * @brief Get the trainable property of the node
+   *
+   * @return bool true / false
+   */
+  virtual bool getTrainable() const = 0;
+
+  /**
    * @brief Get the input connections for this node
    *
    * @return list of name of the nodes which form input connections
diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 565c118..6dab352 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -26,16 +26,19 @@
 #include
 #include
 #include
+#include <grucell.h>
 #include
 #include
 #include
 #include
+#include <lstmcell.h>
 #include
 #include
 #include
 #include
 #include
 #include
+#include <rnncell.h>
 #include
 #include
 #include
@@ -475,7 +478,7 @@ void NetworkGraph::allocateTensors(ExecutionMode exec_mode_) {
      * usage less than the max_exec_order are allocated.
      */
     tensor_manager->allocateTensors(
-      std::get<2>(backward_iter_end->getExecutionOrder()));
+      std::get<3>(backward_iter_end->getExecutionOrder()));
   }
 }
@@ -778,6 +781,15 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
     });
   }
 
+  if (lnode->getType() == RNNCellLayer::type or
+      lnode->getType() == LSTMCellLayer::type or
+      lnode->getType() == GRUCellLayer::type) {
+    std::for_each(
+      out_specs.begin(), out_specs.end(), [this](VarGradSpecV2 &spec) {
+        spec.variable_spec.ls = TensorLifespan::FORWARD_GRAD_LIFESPAN;
+      });
+  }
+
   const std::vector<Var_Grad *> &outputs = tensor_manager->requestTensors(
     out_specs, Manager::TensorGroupType::OUTPUT, lnode->getExecutionOrder(),
     lnode->getName());
diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h
index 9f97446..c7bd293 100644
--- a/nntrainer/graph/network_graph.h
+++ b/nntrainer/graph/network_graph.h
@@ -312,7 +312,7 @@ public:
    */
   void allocateWeights() {
     tensor_manager->allocateWeights(
-      std::get<2>(backward_iter_end->getExecutionOrder()));
+      std::get<3>(backward_iter_end->getExecutionOrder()));
   }
 
   /**
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 701bac2..3d8f617 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -234,8 +234,8 @@ public:
    */
   static VarGradSpecV2
   outSpec(const TensorDim &dim, const std::string &name = "out",
-          TensorLifespan ls = TensorLifespan::FORWARD_GRAD_LIFESPAN,
-          TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN);
+          TensorLifespan ls = TensorLifespan::FORWARD_FUNC_LIFESPAN,
+          TensorLifespan grad_ls = TensorLifespan::CALC_GRAD_DERIV_LIFESPAN);
 
   /**
    * @brief request outputs
diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index 1bb14ea..b3a961b 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -33,12 +33,18 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -146,7 +152,7 @@ void Manager::deallocateWeights() { weight_pool.deallocate(); }
 
 static Tensor *requestTensor_(const TensorSpecV2 &spec,
                               const GraphNode::ExecutionOrder &exec_order,
                               const std::string &scope, TensorPool &tp,
-                              bool expose) {
+                              bool expose, bool trainable) {
   using RT = TensorSpecV2::RequestType;
   using LS = TensorLifespan;
   NNTR_THROW_IF(spec.request_type == RT::MAYBE_MODIFYING_VIEW,
@@ -209,10 +215,10 @@ Var_Grad *Manager::requestTensor(const VarGradSpecV2 &spec,
                 "requestInputs() requestTensors() instead";
 
   Tensor *var = requestTensor_(spec.variable_spec, exec_order, scope,
-                               tensor_pool, expose_var);
+                               tensor_pool, expose_var, false);
   Tensor *grad = spec.gradient_spec ?
                    requestTensor_(*spec.gradient_spec, exec_order, scope,
-                                  tensor_pool, expose_grad)
+                                  tensor_pool, expose_grad, false)
                    : nullptr;
 
   /// @note as only supporting identify_as == TensorGroupType::output, only
@@ -354,8 +360,8 @@ void Manager::initializeTensorsTrain(unsigned int max_exec_order_) {
 std::vector<Weight *> Manager::requestWeights(
   const GraphNode &node, const std::vector<Weight::Spec> &weights_spec,
   bool trainable, const std::vector<std::string> &shared_names) {
-  const auto [forwarding_order, calcGradient_order, calcDerivative_order, applyGradient_order] =
-    node.getExecutionOrder();
+  const auto [forwarding_order, calcGradient_order, calcDerivative_order,
+              applyGradient_order] = node.getExecutionOrder();
 
   std::vector<unsigned int> default_var_exec_order(
     {forwarding_order, calcDerivative_order});
@@ -450,8 +456,8 @@ std::vector<Weight *> Manager::requestWeights(
 std::vector<Var_Grad *> Manager::requestTensors(
   const GraphNode &node, const std::vector<Var_Grad::Spec> &tensors_spec,
   bool trainable, const std::vector<std::string> &shared_names) {
-  const auto [forwarding_order, calcGradient_order, calcDerivative_order, applyGradient_order] =
-    node.getExecutionOrder();
+  const auto [forwarding_order, calcGradient_order, calcDerivative_order,
+              applyGradient_order] = node.getExecutionOrder();
 
   std::vector<Var_Grad *> ret;
   size_t current_size = tensors_v2.size();
@@ -478,7 +484,8 @@ std::vector<Var_Grad *> Manager::requestTensors(
       grad_exec_order.push_back(calcDerivative_order);
     }
 
-    if (trainable && enum_class_logical_and(tspan, TensorLifespan::CALC_AGRAD_LIFESPAN)) {
+    if (trainable &&
+        enum_class_logical_and(tspan, TensorLifespan::CALC_AGRAD_LIFESPAN)) {
       var_exec_order.push_back(applyGradient_order);
       grad_exec_order.push_back(applyGradient_order);
     }
@@ -534,9 +541,18 @@ Manager::requestInputs(const GraphNode &node,
   if (node.getType() == ActivationLayer::type or
       node.getType() == MultiOutLayer::type or
       node.getType() == BatchNormalizationLayer::type or
-      node.getType() == LayerNormalizationLayer::type)
+      node.getType() == LayerNormalizationLayer::type or !node.getTrainable())
     var_common_spec.ls = TensorLifespan::FORWARD_FUNC_LIFESPAN;
 
+  if (node.getType() == MSELossLayer::type or
+      node.getType() == CrossEntropySoftmaxLossLayer::type or
+      node.getType() == CrossEntropySigmoidLossLayer::type)
+    var_common_spec.ls = TensorLifespan::FORWARD_DERIV_LIFESPAN;
+
+  if (node.getType() == GRUCellLayer::type) {
+    grad_common_spec.ls = TensorLifespan::CALC_GRAD_DERIV_LIFESPAN;
+  }
+
   std::vector<Var_Grad *> ret;
   size_t current_size = inputs_v2.size();
@@ -567,9 +583,9 @@ Manager::requestInputs(const GraphNode &node,
 
     inputs_v2.emplace_back(std::make_unique<Var_Grad>(
       requestTensor_(var_spec, node.getExecutionOrder(), node.getName(),
-                     tensor_pool, false),
+                     tensor_pool, false, node.getTrainable()),
       requestTensor_(grad_spec, node.getExecutionOrder(), node.getName(),
-                     tensor_pool, false)));
+                     tensor_pool, false, node.getTrainable())));
   }
 
   ret.reserve(inputs_dim.size());
@@ -662,7 +678,6 @@ Manager::getWeights(const std::function<bool(const Weight &)> &condition) {
     if (!condition || condition(w.get()))
       conditional_weights.push_back(w.get());
   }
-
   return conditional_weights;
 }
diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build
index 0e97efe..40205f4 100644
--- a/nntrainer/tensor/meson.build
+++ b/nntrainer/tensor/meson.build
@@ -14,6 +14,8 @@ tensor_sources = [
   'swap_device.cpp',
   'tensor_pool.cpp',
   'optimized_v1_planner.cpp',
+  'optimized_v2_planner.cpp',
+  'optimized_v3_planner.cpp',
   'task_executor.cpp',
 ]
diff --git a/nntrainer/tensor/optimized_v1_planner.cpp b/nntrainer/tensor/optimized_v1_planner.cpp
index d9ad486..a6b1f7b 100644
--- a/nntrainer/tensor/optimized_v1_planner.cpp
+++ b/nntrainer/tensor/optimized_v1_planner.cpp
@@ -42,8 +42,7 @@ struct MemoryRequest {
     end(valid.second),
     loc(idx),
     size(s),
-    offset(0),
-    size_from_offset(0) {}
+    offset(0) {}
 };
 
 /**
@@ -193,14 +192,6 @@ size_t OptimizedV1Planner::planLayout(
     memory_offset[req.loc] = offset;
     memory_req = std::max(memory_req, req.offset + req.size);
     sorted_req.push_back(&req);
-
-#ifdef DEBUG
-    if (n_wgrad && memory_is_wgrad[req.loc]) {
-      new_grad_cnt++;
-      new_grad_size += req.size;
-    }
-#endif
-
   }
 
   // validateIntervalOverlap(memory_validity, memory_size, memory_offset,
diff --git a/nntrainer/tensor/optimized_v2_planner.cpp b/nntrainer/tensor/optimized_v2_planner.cpp
index e686f14..bc733fa 100644
--- a/nntrainer/tensor/optimized_v2_planner.cpp
+++ b/nntrainer/tensor/optimized_v2_planner.cpp
@@ -233,10 +233,12 @@ size_t OptimizedV2Planner::planLayout(
   std::vector<WGradMemoryRequest> wgrad_sorted_req;
 
   bool replace_and_fill = false;
+#ifdef DEBUG
   unsigned int new_grad_cnt = 0;
   unsigned int reused_grad_cnt = 0;
   size_t new_grad_size = 0;
   size_t reused_grad_size = 0;
+#endif
   for (auto &req : wgrad_requests) {
     for (unsigned int idx = 0; idx < wgrad_sorted_req.size(); idx++) {
       auto const sr = wgrad_sorted_req[idx];
@@ -260,8 +262,10 @@ size_t OptimizedV2Planner::planLayout(
           replace_and_fill = true;
           wgrad_sorted_req[idx].start_end.push_back(
             std::make_pair(req.start, req.end));
+#ifdef DEBUG
           reused_grad_size += req.size;
           reused_grad_cnt++;
+#endif
           break;
         } else {
           replace_and_fill = false;
@@ -282,15 +286,11 @@ size_t OptimizedV2Planner::planLayout(
       wgrad_sorted_req.push_back(WGradMemoryRequest(&req));
       wgrad_sorted_req.back().start_end.push_back(
         std::make_pair(req.start, req.end));
+#ifdef DEBUG
       new_grad_cnt++;
       new_grad_size += req.size;
+#endif
     }
-
-    ml_logd("Total Requested Memory(OPTV2): %lf MiB>>>>>>>> \n - new mem for "
-            "gradient = %d, "
-            "(%lf MiB) & reused mem for gradient = %d (%lf MiB)\n",
-            memory_req / 1024, new_grad_cnt, new_grad_size / 1024,
-            reused_grad_cnt, reused_grad_size / 1024);
   }
 
   // validateIntervalOverlap(memory_validity, memory_size, memory_offset,
diff --git a/nntrainer/tensor/optimized_v3_planner.cpp b/nntrainer/tensor/optimized_v3_planner.cpp
new file mode 100644
index 0000000..e26e4a9
--- /dev/null
+++ b/nntrainer/tensor/optimized_v3_planner.cpp
@@ -0,0 +1,228 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Jijoong Moon
+ *
+ * @file optimized_v3_planner.cpp
+ * @date 2 January 2023
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon
+ * @bug No known bugs except for NYI items
+ * @brief This is Optimized V3 Memory Planner
+ *
+ */
+
+#include <algorithm>
+#include <memory>
+#include <nntrainer_error.h>
+#include <stdexcept>
+#include <vector>
+
+#include <optimized_v3_planner.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Memory Request data structure clubbing all the requests
+ *
+ */
+struct MemoryRequest {
+  unsigned int start; /**< start of the validity (inclusive) */
+  unsigned int end;   /**< end of the validity (exclusive) */
+  unsigned int loc;   /**< index/location of this request */
+  size_t size;        /**< size of the request */
+  size_t offset;      /**< offset for this request */
+
+  /**
+   * @brief Constructor for the Memory Request
+   *
+   */
+  MemoryRequest(size_t s, const std::pair<unsigned int, unsigned int> &valid,
+                unsigned int idx) :
+    start(valid.first),
+    end(valid.second),
+    loc(idx),
+    size(s),
+    offset(0) {}
+};
+
+static size_t computeSpace(unsigned int exec_order,
+                           std::vector<MemoryRequest *> &sorted_req,
+                           std::vector<std::pair<size_t, size_t>> &vacant) {
+  size_t bottom = 0;
+  size_t max_offset = 0;
+
+  std::sort(sorted_req.begin(), sorted_req.end(),
+            [](auto const &v1, auto const &v2) -> bool {
+              return v1->offset < v2->offset;
+              /** TODO: try this */
+              // if (v1.end == v2.end)
+              //   return v1.start < v2.start;
+              // return v1.end > v2.end;
+            });
+
+  for (unsigned idx = 0; idx < sorted_req.size(); idx++) {
+    auto const &sr = sorted_req[idx];
+    size_t top = sr->offset + sr->size;
+
+    if (max_offset < top)
+      max_offset = top;
+
+    if (sr->offset > bottom) {
+      vacant.push_back(std::make_pair(bottom, sr->offset));
+    }
+    bottom = top;
+  }
+
+  return max_offset;
+}
+
+/**
+ * @brief check whether the validity intervals overlap, in a very naive way.
+ *
+ * @param memory_validity validity
+ * @param memory_size size
+ * @param memory_offset offset
+ * @param memory_req request
+ */
+[[maybe_unused]] static void validateIntervalOverlap(
+  const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+  const std::vector<size_t> &memory_size,
+  const std::vector<size_t> &memory_offset, size_t memory_req) {
+  auto bits = std::make_unique<bool[]>(memory_req);
+
+  for (size_t i = 0; i < memory_req; ++i) {
+    bits[i] = 0;
+  }
+
+  auto exec_start =
+    std::min_element(memory_validity.begin(), memory_validity.end(),
+                     [](auto &a, auto &b) { return a.first < b.first; });
+
+  auto exec_end =
+    std::max_element(memory_validity.begin(), memory_validity.end(),
+                     [](auto &a, auto &b) { return a.second < b.second; });
+
+  auto set = [&](int offset, size_t size, int idx) {
+    for (unsigned int i = offset; i < offset + size; ++i) {
+      NNTR_THROW_IF(bits[i], std::invalid_argument)
+        << " bits taken at i: " << i << " offset: " << offset
+        << " size: " << size << " idx: " << idx;
+      bits[i] = 1;
+    }
+  };
+
+  auto unset = [&](int offset, size_t size, int idx) {
+    for (unsigned int i = offset; i < offset + size; ++i) {
+      NNTR_THROW_IF(!bits[i], std::invalid_argument)
+        << "double freeing bits at i: " << i << " offset: " << offset
+        << " size: " << size << " idx: " << idx;
+      bits[i] = 0;
+    }
+  };
+
+  for (unsigned int exec = exec_start->first; exec <= exec_end->second;
+       ++exec) {
+
+    for (unsigned int idx = 0; idx < memory_validity.size(); ++idx) {
+      auto &validity = memory_validity.at(idx);
+      auto &sz = memory_size.at(idx);
+      auto &offset = memory_offset.at(idx);
+      if (validity.first == exec) {
+        set(offset, sz, idx);
+      }
+      if (validity.second == exec) {
+        unset(offset, sz, idx);
+      }
+    }
+  }
+  // check if there is any dangling memory
+  set(0, memory_req, memory_validity.size());
+}
+
+/**
+ * @copydoc MemoryPlanner::planLayout(
+ * const std::vector<size_t> &memory_size,
+ * const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+ * std::vector<size_t> &memory_offset,
+ * std::vector<bool> &memory_is_wgrad);
+ *
+ * @details The optimized v3 memory planner assigns memory to the requests
+ * whose validity starts first.
+ * The requested memories are sorted in ascending order of their start
+ * timestamps, breaking ties with ascending end timestamps. The sorted
+ * memories are given increasing offsets based on the memory size. At each
+ * timestamp, memories whose validity has expired are freed, and the vacant
+ * regions they leave behind are reused for later requests that fit. This
+ * planner therefore shares memory between requests whose validities do not
+ * overlap.
+ *
+ */
+size_t OptimizedV3Planner::planLayout(
+  const std::vector<size_t> &memory_size,
+  const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+  std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+  size_t n_wgrad) const {
+
+  /** create memory requests structure array for easier management */
+  std::vector<MemoryRequest> requests;
+  requests.reserve(memory_size.size());
+  for (unsigned int idx = 0; idx < memory_size.size(); idx++) {
+    requests.emplace_back(memory_size[idx], memory_validity[idx], idx);
+  }
+
+  /**
+   * sort the memory requests with ascending order of start time first, and
+   * then end time
+   */
+  std::sort(requests.begin(), requests.end(),
+            [](auto const &v1, auto const &v2) -> bool {
+              if (v1.start == v2.start)
+                return v1.end < v2.end;
+              return v1.start < v2.start;
+              /** TODO: try this */
+              // if (v1.end == v2.end)
+              //   return v1.start < v2.start;
+              // return v1.end > v2.end;
+            });
+
+  /** all the memories in use sorted by their assigned offset and size */
+  std::vector<MemoryRequest *> sorted_req;
+
+  /** iterate over the sorted requests and start allocation of the requests */
+  memory_offset.resize(memory_size.size());
+  size_t memory_req = 0;
+  for (auto &req : requests) {
+    sorted_req.erase(
+      std::remove_if(sorted_req.begin(), sorted_req.end(),
+                     [req](auto elem) { return elem->end <= req.start; }),
+      sorted_req.end());
+
+    bool replace_and_fill = false;
+    std::vector<std::pair<size_t, size_t>> vacant;
+
+    size_t max_offset = computeSpace(req.start, sorted_req, vacant);
+
+    for (unsigned int idx = 0; idx < vacant.size(); idx++) {
+      if (vacant[idx].second - vacant[idx].first >= req.size) {
+        req.offset = vacant[idx].first;
+        memory_offset[req.loc] = req.offset;
+        sorted_req.push_back(&req);
+        replace_and_fill = true;
+        break;
+      }
+    }
+    vacant.clear();
+
+    if (replace_and_fill) {
+      continue;
+    }
+
+    req.offset = max_offset;
+    memory_offset[req.loc] = max_offset;
+    memory_req = std::max(memory_req, req.offset + req.size);
+    sorted_req.push_back(&req);
+  }
+
+  return memory_req;
+}
+
+} // namespace nntrainer
diff --git a/nntrainer/tensor/optimized_v3_planner.h b/nntrainer/tensor/optimized_v3_planner.h
new file mode 100644
index 0000000..bbbc843
--- /dev/null
+++ b/nntrainer/tensor/optimized_v3_planner.h
@@ -0,0 +1,64 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Jijoong Moon
+ *
+ * @file optimized_v3_planner.h
+ * @date 2 January 2023
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon
+ * @bug No known bugs except for NYI items
+ * @brief This is Optimized V3 Memory Planner
+ *
+ *
+ */
+
+#ifndef __OPTIMIZED_V3_PLANNER_H_
+#define __OPTIMIZED_V3_PLANNER_H_
+
+#include <vector>
+
+#include <memory_planner.h>
+
+namespace nntrainer {
+
+/**
+ * @class OptimizedV3Planner
+ * @brief Optimized V3 Memory Planner provides the optimized plan for memory
+ * layout
+ * @details optimized planner performs sharing of overlapping memory up to a
+ * certain extent
+ */
+class OptimizedV3Planner : public MemoryPlanner {
+public:
+  /**
+   * @brief OptimizedV3Planner constructor
+   *
+   */
+  OptimizedV3Planner() = default;
+
+  /**
+   * @copydoc MemoryPlanner::planLayout(
+   * const std::vector<size_t> &memory_size,
+   * const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+   * std::vector<size_t> &memory_offset,
+   * std::vector<bool> &memory_is_wgrad);
+   *
+   */
+  size_t planLayout(
+    const std::vector<size_t> &memory_size,
+    const std::vector<std::pair<unsigned int, unsigned int>> &memory_validity,
+    std::vector<size_t> &memory_offset, std::vector<bool> &memory_is_wgrad,
+    size_t n_wgrad = 0) const;
+
+  /**
+   * @copydoc MemoryPlanner::getType() const
+   *
+   */
+  const std::string &getType() const { return type; }
+
+  inline static const std::string type = "optimized_v3_planner";
+};
+
+} // namespace nntrainer
+
+#endif /** __OPTIMIZED_V3_PLANNER_H_ */
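
The snippet below is an editorial addition, not part of the commit: a minimal sketch of how the new V3 planner could be exercised, assuming a translation unit built inside nntrainer so that <optimized_v3_planner.h> resolves. It feeds three toy requests to planLayout(); the third request becomes valid exactly when the first expires, so the planner should place it into the vacant gap at offset 0 instead of growing the pool.

// Editorial sketch (not part of the commit): exercises the V3 planner on
// three toy requests. The toy sizes and validities are made up for
// illustration.
#include <cstdio>
#include <utility>
#include <vector>

#include <optimized_v3_planner.h>

int main() {
  nntrainer::OptimizedV3Planner planner;

  // Sizes in bytes and [start, end) execution-order validity per tensor.
  std::vector<size_t> memory_size = {1024, 512, 256};
  std::vector<std::pair<unsigned int, unsigned int>> memory_validity = {
    {0, 2}, // expires at order 2; its bytes become a vacant gap
    {1, 4},
    {2, 4}, // becomes valid at order 2; fits into the gap left by tensor 0
  };
  std::vector<size_t> memory_offset;
  std::vector<bool> memory_is_wgrad(memory_size.size(), false);

  size_t peak = planner.planLayout(memory_size, memory_validity, memory_offset,
                                   memory_is_wgrad, /*n_wgrad=*/0);

  for (size_t i = 0; i < memory_offset.size(); ++i)
    std::printf("tensor %zu -> offset %zu\n", i, memory_offset[i]);
  std::printf("peak bytes: %zu\n", peak);
  return 0;
}

Tracing planLayout() above by hand, tensor 2 should land at offset 0 (reusing tensor 0's region) and the reported peak should be 1536 bytes, 256 bytes less than the 1792 bytes a layout without gap reuse would need for these requests.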