// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
 *
 * @brief  This is NNtrainer manager for all weights, i/o and intermediate
 * tensors
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Parichay Kapoor <pk.kapoor@samsung.com>
 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
 * @bug    No known bugs except for NYI items
 */
#if defined(__ANDROID__)
#include <android/sharedmem.h>
#endif

#include <algorithm>
#include <cassert>
#include <functional>
#include <future>
#include <mutex>
#include <stdexcept>
#include <sys/mman.h>
#include <unistd.h>

#include <activation_layer.h>
#include <basic_planner.h>
#include <bn_layer.h>
#include <graph_node.h>
#include <layer_node.h>
#include <layer_normalization_layer.h>
#include <manager.h>
#include <multiout_layer.h>
#include <nntrainer_log.h>
#include <optimized_v1_planner.h>
#include <tensor_pool.h>
#include <tensor_wrap_specs.h>
#include <util_func.h>

namespace nntrainer {
MMapedMemory::MMapedMemory(size_t size, bool allocate_fd_) :
  fd(-1), buf(nullptr), buf_size(0), allocate_fd(allocate_fd_) {
#ifndef __ANDROID__
  if (allocate_fd) {
    /// @todo create a file in tmpfs and bind to memfs
    /// memfd_create is not available on a number of platforms, so this block
    /// is commented out for now
    // auto fd_ = memfd_create("", 0);
    // throw std::runtime_error("[Manager] creating mem fd failed");
    // if (ftruncate(fd_, size) < 0) {
    // throw std::runtime_error("[Manager] truncating fd failed");
    ml_logi("[MMapedMemory] fd creation is not supported on this platform");
  }
#endif
  int fd_ = -1;
  void *buf_ = nullptr;

  if (allocate_fd) {
#ifdef __ANDROID__
    /// unfortunately, memfd_create is not supported before android level 30
    fd_ = ASharedMemory_create("", size);
    if (fd_ < 0) {
      throw std::runtime_error("[MMapedMemory] creating mem fd failed");
    }
    if (ASharedMemory_setProt(fd_, PROT_READ | PROT_WRITE) < 0) {
      // unlink / close the given fd here
      close(fd_);
      throw std::runtime_error("[MMapedMemory] Setting prot failed");
    }
    buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
#endif
  } else {
    buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS,
                fd_, 0);
  }

  if (buf_ == MAP_FAILED) {
    // unlink / close the given fd here
    if (fd_ != -1)
      close(fd_);
    throw std::runtime_error("[MMapedMemory] mmap failed");
  }

  fd = fd_;
  buf = buf_;
  buf_size = size;
  ml_logd("[MMapedMemory] memory acquired size: %zu, fd: %d, addr: %p",
          buf_size, fd, buf);
}
MMapedMemory::~MMapedMemory() noexcept {
  assert(buf_size > 0 && fd > 0);

  if (fd != -1 && close(fd) < 0) {
    ml_logw("[MMapedMemory] closing fd failed on destruction, please check");
  }

  if (buf != nullptr) {
    if (munmap(buf, buf_size) < 0) {
      ml_logw("[MMapedMemory] munmap failed on destruction, please check");
    }
  }

  /// keeping the invariant although this is not necessary as of now
  fd = -1;
  buf = nullptr;
  buf_size = 0;
  ml_logd("[MMapedMemory] buf released");
}
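
/*
 * Minimal usage sketch for MMapedMemory (illustrative only; the size below is
 * a hypothetical value, and the object is normally managed by the Manager):
 *
 *   MMapedMemory mem(1 << 20, false); // anonymous, fd-less private mapping
 *   // ... use the mapped buffer ...
 *   // the mapping and the fd (if any) are released in the destructor
 */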

void Manager::allocateWeights(unsigned int max_exec_order_) {
  if (!weight_pool.isAllocated()) {
    finalizeTensorPool(weight_pool, 0, max_exec_order_);
    weight_pool.allocate();
  }
}

void Manager::deallocateWeights() { weight_pool.deallocate(); }

static Tensor *requestTensor_(const TensorSpecV2 &spec,
                              const GraphNode::ExecutionOrder &exec_order,
                              const std::string &scope, TensorPool &tp,
                              bool expose) {
  using RT = TensorSpecV2::RequestType;
  using LS = TensorLifespan;
  NNTR_THROW_IF(spec.request_type == RT::MAYBE_MODIFYING_VIEW,
                std::invalid_argument)
    << "Modifying view cannot be requested, the request type has to be "
       "delegated to either view or unique";

  auto [forward, calc_grad, calc_deriv, apply_grad] = exec_order;

  std::vector<unsigned> order = spec.additional_exec_order;
  if (expose) {
    order.push_back(TensorPool::PERSIST_END_ORDER);
  }

  const auto name = scope + ":" + spec.name;

  if (enum_class_or(spec.ls, LS::FORWARD_FUNC_LIFESPAN) == spec.ls) {
    order.push_back(forward);
  }
  if (enum_class_or(spec.ls, LS::CALC_GRAD_LIFESPAN) == spec.ls) {
    order.push_back(calc_grad);
  }
  if (enum_class_or(spec.ls, LS::CALC_DERIV_LIFESPAN) == spec.ls) {
    order.push_back(calc_deriv);
  }
  if (enum_class_or(spec.ls, LS::CALC_AGRAD_LIFESPAN) == spec.ls) {
    order.push_back(apply_grad);
  }

  switch (spec.request_type) {
  case RT::PLACEHOLDER:
    return tp.placeholder(name, spec.dim);
  case RT::UNIQUE:
    return tp.request(name, spec.dim, order, spec.ls, spec.initializer);
  case RT::SHARED:
    return tp.requestOrExtend(name, spec.dim, order, spec.ls, spec.initializer);
  case RT::READ_ONLY_VIEW:
    return tp.view(name, spec.reference_name, spec.dim, order, spec.ls);
  case RT::MAYBE_MODIFYING_VIEW:
  default:
    throw std::logic_error("requestTensor_ should not reach here");
  }
}
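
/*
 * Example of the lifespan-to-order expansion above (a sketch, assuming the
 * composite lifespans have their usual meaning): a spec requested with
 * TensorLifespan::FORWARD_GRAD_LIFESPAN contains both the forward and the
 * calcGradient lifespans, so `order` ends up holding {forward, calc_grad}
 * plus any additional_exec_order entries, while a plain FORWARD_FUNC_LIFESPAN
 * spec only records the forwarding order.
 */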

Var_Grad *Manager::requestTensor(const VarGradSpecV2 &spec,
                                 TensorGroupType identify_as,
                                 const GraphNode::ExecutionOrder &exec_order,
                                 const std::string &scope, bool expose_var,
                                 bool expose_grad) {
  NNTR_THROW_IF(identify_as == TensorGroupType::WEIGHT, std::invalid_argument)
    << "requestTensor with var grad spec cannot be identified as weights, use "
       "requestTensor with weight spec instead";

  NNTR_THROW_IF(identify_as == TensorGroupType::INPUT or
                  identify_as == TensorGroupType::TENSORS,
                nntrainer::exception::not_supported)
    << "Currently, the input and tensors group types are not implemented; use "
       "requestInputs() / requestTensors() instead";

  Tensor *var = requestTensor_(spec.variable_spec, exec_order, scope,
                               tensor_pool, expose_var);
  Tensor *grad = spec.gradient_spec
                   ? requestTensor_(*spec.gradient_spec, exec_order, scope,
                                    tensor_pool, expose_grad)
                   : nullptr;

  /// @note as only identify_as == TensorGroupType::OUTPUT is supported, the
  /// result is only saved to outputs for now
  outputs_v2.push_back(std::make_unique<Var_Grad>(var, grad));

  return outputs_v2.back().get();
}

std::vector<Var_Grad *> Manager::requestTensors(
  const std::vector<VarGradSpecV2> &specs, TensorGroupType identify_as,
  const GraphNode::ExecutionOrder &exec_order, const std::string &scope,
  bool expose_var, bool expose_grad) {
  std::vector<Var_Grad *> ret;
  ret.reserve(specs.size());
  for (auto &spec : specs) {
    ret.push_back(requestTensor(spec, identify_as, exec_order, scope,
                                expose_var, expose_grad));
  }

  return ret;
}

/**
 * @brief Allocate memory for all the managed tensors
 */
void Manager::allocateTensors(unsigned int max_exec_order_) {
  allocateWeights(max_exec_order_);

  if (!tensor_pool.isAllocated()) {
    finalizeTensorPool(tensor_pool, 0, max_exec_order_);
    tensor_pool.allocate();
  }
}

/**
 * @brief Deallocate memory for all the managed tensors
 */
void Manager::deallocateTensors(bool dealloc_weights) {
  if (dealloc_weights) {
    deallocateWeights();
  }

  tensor_pool.deallocate();
}
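
/*
 * Typical allocation lifecycle, sketched with the calls above (illustrative;
 * `max_exec_order` is a hypothetical bound supplied by the caller):
 *
 *   manager.allocateTensors(max_exec_order); // finalizes pools and allocates
 *   // ... run forwarding / backwarding up to max_exec_order ...
 *   manager.deallocateTensors(false);        // keep the weights resident
 *   manager.deallocateWeights();             // release the weights explicitly
 */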

void Manager::initializeTensorsInference(unsigned int max_exec_order_) {
  /**
   * A single buffer (shared_inout) provides memory for the inputs and outputs
   * of a layer. Further, the output of layer i shares memory with the input of
   * layer i+1. So, each alternate layer allocates memory from either the start
   * or the end of the buffer, and use_first_last tracks this.
   *
   * @note Label for the last layer is not initialized in inference.
   * @note Input for the first layer is not initialized in inference.
   */
  // Initialize shared input/output memory for inference
  // @note Memory for the label is not allocated here as inference has no label
  if (enable_inference_inout_memory_opt)
    shared_inout = Tensor(TensorDim({max_shared_inout}), false);

  bool use_first_last = false;
  for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
    auto &l_io = in_outs[idx];
    unsigned int offset = 0;
    bool is_first_layer = idx == 0;

    // For flatten layer, do not assign new memory
    if (idx > 0 && is_flat_type[idx])
      use_first_last = !use_first_last;

    // In inference mode, do not allocate the memory for the input of the
    // first layer. This is the first entry in in_outs. Inference() will
    // override the input tensors of the first layer.
    if (is_first_layer)
      continue;

    for (auto &io : l_io) {
      Tensor shared_inout_cur = Tensor();
      if (enable_inference_inout_memory_opt) {
        if (use_first_last) {
          // Create tensor from the front of the shared tensor
          shared_inout_cur =
            shared_inout.getSharedDataTensor(io->getDim(), offset);
        } else {
          // Create tensor from the back of the shared tensor
          shared_inout_cur = shared_inout.getSharedDataTensor(
            io->getDim(),
            max_shared_inout - io->getDim().getDataLen() - offset);
        }
        offset += io->getDim().getDataLen();
      }
      io->initialize(shared_inout_cur, Tensor(), false);
    }
    use_first_last = !use_first_last;
  }
}
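
/*
 * Offset arithmetic above, illustrated with hypothetical numbers: with
 * max_shared_inout = 1000 elements and an io tensor of 200 elements at
 * offset 0, the "front" view spans [0, 200) while the "back" view spans
 * [800, 1000); subsequent tensors of the same layer advance `offset` so the
 * views never overlap within the shared buffer.
 */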

void Manager::initializeTensorsTrain(unsigned int max_exec_order_) {
  // Initialize gradients
  initializeGradients();

  // Initialize shared derivative memory
  if (max_derivative_size > 0 && enable_activation_memory_opt)
    shared_deriv = Tensor(TensorDim({max_derivative_size}), false);

  for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
    auto &l_io = in_outs[idx];
    unsigned int offset = 0;
    bool is_last_layer = idx == in_outs.size() - 1;

    for (auto &io : l_io) {
      // The last layer requires separate memory allocations for the output
      // and the label
      if (enable_derivative_memory_opt && !is_last_layer) {
        // Training mode with optimizations
        if (enable_activation_memory_opt &&
            (is_rnn_type[idx] || is_act_type[idx])) {
          io->initialize(
            Tensor(), shared_deriv.getSharedDataTensor(io->getDim(), offset));
          offset += io->getDim().getDataLen();
        } else {
          io->initializeShared();
        }
      } else {
        // Training mode without optimizations
        io->initialize(Tensor(), Tensor(), true);
      }
    }
  }
}

/**
 * @brief Create weights with the given spec
 */
std::vector<Weight *> Manager::requestWeights(
  const GraphNode &node, const std::vector<Weight::Spec> &weights_spec,
  bool trainable, const std::vector<std::string> &shared_names) {
  const auto [forwarding_order, calcGradient_order, calcDerivative_order,
              applyGradient_order] = node.getExecutionOrder();

  std::vector<unsigned int> default_var_exec_order(
    {forwarding_order, calcDerivative_order});

  /**
   * TODO: This needs to be fixed. calcDerivative does not need the gradient.
   * However, the current implementation of the loss needs the gradient
   * computation; therefore, if we remove the calcDerivative order, the tests
   * fail.
   */

  TensorLifespan var_ls = TensorLifespan::MAX_LIFESPAN;
  TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN;

  std::vector<Weight *> ret;
  size_t current_size = weights_v2.size();

  for (unsigned int i = 0; i < weights_spec.size(); ++i) {
    auto &[dim, t_initializer, w_reg, w_reg_const, decay, clip_by_global_norm,
           need_gradient, name] = weights_spec.at(i);
    auto var_exec_order = default_var_exec_order;
    std::vector<unsigned int> grad_exec_order;

    if (trainable) {
      var_exec_order.push_back(calcGradient_order);
      var_exec_order.push_back(applyGradient_order);
      grad_exec_order.push_back(calcGradient_order);
      grad_exec_order.push_back(applyGradient_order);
    }

    /**
     * If the weight is to be clipped by the global norm, extend its exec order
     * with the max exec order, where it will be used for clipping and then
     * applied to the weight.
     */
    if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) {
      grad_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
      var_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
    }

    Tensor *var = nullptr, *grad = nullptr;
    bool is_dependent = !shared_names.empty();
    if (is_dependent) {
      /// the shared_name is used and the original name is discarded
      const auto &shared_name = shared_names.at(i);
      /** case when shared names are given */
      var = weight_pool.requestOrExtend(shared_name, dim, var_exec_order,
                                        var_ls, t_initializer);

      if (trainable && need_gradient) {
        /** We cannot use tensor scheduling for the weight gradient if the
         * weight is shared. Weight sharing means the gradient is no longer
         * temporal for each layer and is hard to overwrite.
         */
        grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
                                           dim, grad_exec_order, grad_ls,
                                           Tensor::Initializer::ZEROS);
      }
    } else {
      /** case requesting fresh weights */
      var =
        weight_pool.request(name, dim, var_exec_order, var_ls, t_initializer);

      if (trainable && need_gradient) {
        /** is_wgrad is true when the tensor is the gradient of a weight. If it
         * is true, the memory planner schedules it as a weight gradient to
         * reduce memory usage.
         */
        bool is_wgrad = true;
        if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
          is_wgrad = false;
        grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim,
                                   grad_exec_order, grad_ls,
                                   Tensor::Initializer::ZEROS, is_wgrad);
      }
    }

    weights_v2.emplace_back(std::make_unique<Weight>(
      var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm));
  }

  std::transform(weights_v2.begin() + current_size, weights_v2.end(),
                 std::back_inserter(ret),
                 [](auto const &elem) { return elem.get(); });

  return ret;
}
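
/*
 * Example of the exec-order assembly above: for a trainable weight whose
 * gradient is clipped by global norm, var_exec_order becomes
 * {forwarding, calcDerivative, calcGradient, applyGradient,
 *  TensorPool::PERSIST_END_ORDER} and grad_exec_order becomes
 * {calcGradient, applyGradient, TensorPool::PERSIST_END_ORDER}, so the
 * gradient stays alive until the clipping pass at the end of the iteration.
 */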

/**
 * @brief Create tensors with the given spec
 */
std::vector<Var_Grad *> Manager::requestTensors(
  const GraphNode &node, const std::vector<Var_Grad::Spec> &tensors_spec,
  bool trainable, const std::vector<std::string> &shared_names) {
  const auto [forwarding_order, calcGradient_order, calcDerivative_order,
              applyGradient_order] = node.getExecutionOrder();

  std::vector<Var_Grad *> ret;
  size_t current_size = tensors_v2.size();

  for (unsigned int i = 0; i < tensors_spec.size(); ++i) {
    auto const &[dim, t_init, need_grad, name, tspan] = tensors_spec.at(i);

    std::vector<unsigned int> var_exec_order;
    std::vector<unsigned int> grad_exec_order;

    /** usage for tensors */
    if (enum_class_logical_and(tspan, TensorLifespan::FORWARD_FUNC_LIFESPAN))
      var_exec_order.push_back(forwarding_order);

    /** usage for tensor gradients in backwarding */
    if (trainable &&
        enum_class_logical_and(tspan, TensorLifespan::CALC_GRAD_LIFESPAN)) {
      var_exec_order.push_back(calcGradient_order);
      grad_exec_order.push_back(calcGradient_order);
    }

    if (enum_class_logical_and(tspan, TensorLifespan::CALC_DERIV_LIFESPAN)) {
      var_exec_order.push_back(calcDerivative_order);
      grad_exec_order.push_back(calcDerivative_order);
    }

    if (trainable &&
        enum_class_logical_and(tspan, TensorLifespan::CALC_AGRAD_LIFESPAN)) {
      var_exec_order.push_back(applyGradient_order);
      grad_exec_order.push_back(applyGradient_order);
    }

    bool is_dependent = !shared_names.empty();
    Tensor *var = nullptr, *grad = nullptr;

    if (is_dependent) {
      const auto &shared_name = shared_names.at(i);
      var = tensor_pool.requestOrExtend(shared_name, dim, var_exec_order, tspan,
                                        t_init);
      if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
        grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
                                           dim, grad_exec_order, tspan,
                                           Tensor::Initializer::ZEROS);
      }
    } else {
      var = tensor_pool.request(name, dim, var_exec_order, tspan, t_init);

      if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
        grad =
          tensor_pool.request(name + Var_Grad::grad_suffix, /// name
                              dim, grad_exec_order, tspan,
                              Tensor::Initializer::ZEROS /// tensor initializer
          );
      }
    }

    tensors_v2.emplace_back(std::make_unique<Var_Grad>(var, grad));
  }

  std::transform(tensors_v2.begin() + current_size, tensors_v2.end(),
                 std::back_inserter(ret),
                 [](auto const &elem) { return elem.get(); });

  return ret;
}
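
/*
 * Example of the lifespan handling above (a sketch, assuming
 * BACKWARD_FUNC_LIFESPAN covers calcGrad, calcDeriv and applyGrad): a
 * trainable tensor requested with that lifespan gets var and grad exec orders
 * {calcGradient, calcDerivative, applyGradient}, while a FORWARD_FUNC_LIFESPAN
 * tensor is scheduled for forwarding only and never receives a gradient
 * request.
 */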

/**
 * @brief Create input tensors with the given spec
 */
std::vector<Var_Grad *>
Manager::requestInputs(const GraphNode &node,
                       const std::vector<TensorDim> &inputs_dim,
                       const std::vector<std::string> &outputs_name) {
  using RT = TensorSpecV2::RequestType;

  TensorSpecV2 var_common_spec, grad_common_spec;
  var_common_spec.ls = TensorLifespan::FORWARD_GRAD_LIFESPAN;
  grad_common_spec.ls = TensorLifespan::CALC_DERIV_LIFESPAN;

  /// @todo handle this inside the layer
  if (node.getType() == ActivationLayer::type or
      node.getType() == MultiOutLayer::type or
      node.getType() == BatchNormalizationLayer::type or
      node.getType() == LayerNormalizationLayer::type)
    var_common_spec.ls = TensorLifespan::FORWARD_FUNC_LIFESPAN;

  std::vector<Var_Grad *> ret;
  size_t current_size = inputs_v2.size();

  for (unsigned int idx = 0; idx < inputs_dim.size(); idx++) {
    TensorSpecV2 var_spec = var_common_spec, grad_spec = grad_common_spec;

    var_spec.name = std::string("input") + std::to_string(idx);
    var_spec.dim = inputs_dim[idx];

    grad_spec.name = var_spec.name + Var_Grad::grad_suffix;
    grad_spec.dim = inputs_dim[idx];

    if (!outputs_name.empty()) {
      grad_spec.request_type = var_spec.request_type = RT::READ_ONLY_VIEW;
      var_spec.reference_name = outputs_name[idx];
      grad_spec.reference_name = outputs_name[idx] + Var_Grad::grad_suffix;
    } else if (!node.getInputConnections().empty()) {
      grad_spec.request_type = var_spec.request_type = RT::UNIQUE;
    } else {
      var_spec.request_type = RT::PLACEHOLDER;
#ifdef ENABLE_TEST
      grad_spec.request_type = RT::UNIQUE;
#else
      grad_spec.request_type = RT::PLACEHOLDER;
#endif
    }

    inputs_v2.emplace_back(std::make_unique<Var_Grad>(
      requestTensor_(var_spec, node.getExecutionOrder(), node.getName(),
                     tensor_pool, false),
      requestTensor_(grad_spec, node.getExecutionOrder(), node.getName(),
                     tensor_pool, false)));
  }

  ret.reserve(inputs_dim.size());
  std::transform(inputs_v2.begin() + current_size, inputs_v2.end(),
                 std::back_inserter(ret),
                 [](auto const &elem) { return elem.get(); });

  return ret;
}
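
/*
 * The three request paths above, summarized: when the producing layer's
 * output names are known, the input becomes a READ_ONLY_VIEW onto that output
 * (and its gradient onto the output's gradient); a node with input
 * connections but no resolved output names requests a UNIQUE tensor; and a
 * node without input connections (a model input) gets a PLACEHOLDER to be
 * bound with user-provided data at execution time.
 */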

std::pair<unsigned int, unsigned int>
Manager::getMinMaxTensorExecutionOrder(const std::string &name,
                                       bool is_weight) {

  auto orders = is_weight ? weight_pool.getExecutionOrder(name)
                          : tensor_pool.getExecutionOrder(name);
  auto [min_, max_] = std::minmax_element(orders.begin(), orders.end());
  return {*min_, *max_};
}

unsigned int Manager::getSecondMaxTensorExecutionOrder(const std::string &name,
                                                       bool is_weight) {

  auto orders = is_weight ? weight_pool.getExecutionOrder(name)
                          : tensor_pool.getExecutionOrder(name);
  if (orders.size() < 2)
    throw std::runtime_error(
      "Requesting second last access with less than 2 exec orders");
  /** the tensor pool exec order can contain the same exec order multiple
   * times */
  std::sort(orders.begin(), orders.end());
  orders.erase(std::unique(orders.begin(), orders.end()), orders.end());
  return orders[orders.size() - 2];
}

bool Manager::isFirstAccess(const std::string &name,
                            unsigned current_execution, bool is_weight) {
  /// @todo add a cache mechanism, e.g., sort when finalizing requests
  return getMinMaxTensorExecutionOrder(name, is_weight).first ==
         current_execution;
}

bool Manager::isLastAccess(const std::string &name,
                           unsigned current_execution, bool is_weight) {
  /// @todo add a cache mechanism, e.g., sort when finalizing requests
  return getMinMaxTensorExecutionOrder(name, is_weight).second ==
         current_execution;
}

bool Manager::isSecondLastAccess(const std::string &name,
                                 unsigned current_execution, bool is_weight) {
  /// @todo add a cache mechanism, e.g., sort when finalizing requests
  return getSecondMaxTensorExecutionOrder(name, is_weight) ==
         current_execution;
}
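
/*
 * Worked example for getSecondMaxTensorExecutionOrder (hypothetical orders):
 * recorded exec orders {3, 7, 9, 9} are sorted and deduplicated to {3, 7, 9},
 * so the second-to-last access is 7; duplicates of the maximum therefore do
 * not count as a separate "second last" access.
 */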

/**
 * @brief Create optimizer variable tensors with the given spec
 */
std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
  const std::vector<TensorDim> &dims, const std::string &name,
  const TensorLifespan &lifespan, bool is_grad_clip,
  Tensor::Initializer initializer) {
  auto const exec_order = weight_pool.getExecutionOrder(name);

  std::vector<Tensor *> ret;
  ret.reserve(dims.size());

  std::vector<unsigned int> exec;
  if (is_grad_clip) {
    exec.emplace_back(TensorPool::PERSIST_END_ORDER);
  } else {
    exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second);
  }

  /// @note this assumes the weight optimizer variables are treated as weights;
  /// if not, there is room to optimize the behavior below
  for (unsigned int idx = 0; idx < dims.size(); idx++)
    ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx),
                                      dims[idx], exec, lifespan, initializer));

  return ret;
}
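
/*
 * Naming sketch for the requests above: optimizer variables for a weight
 * named "fc0:weight" (a hypothetical name) are registered in the weight pool
 * as "fc0:weight:opt0", "fc0:weight:opt1", ..., and they are scheduled either
 * at TensorPool::PERSIST_END_ORDER (when gradient clipping is enabled) or at
 * the weight's last access order.
 */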

std::vector<Weight *>
Manager::getWeights(const std::function<bool(const Weight *)> &condition) {
  std::vector<Weight *> conditional_weights;

  for (auto &w : weights_v2) {
    if (!condition || condition(w.get()))
      conditional_weights.push_back(w.get());
  }

  return conditional_weights;
}

void Manager::flushCache() {
  if (!swap_lookahead) {
    weight_pool.flushCache();
    tensor_pool.flushCache();
  }
}

void Manager::flushCacheExcept(unsigned int order) {
  auto loadAsync = [&](TensorPool &pool, unsigned int order) {
    return pool.loadCacheExecAsync(
      order, [&](int id, TaskExecutor::CompleteStatus status) {
        std::scoped_lock<std::mutex> lock(completed_mutex);
        completed[id].set_value(true);
      });
  };

  auto waitComplete = [&](unsigned int o) {
    auto &tasks = async_task_eos[o];

    std::unique_lock<std::mutex> lock(completed_mutex);
    auto w_fut = completed[std::get<0>(tasks)].get_future();
    auto t_fut = completed[std::get<1>(tasks)].get_future();
    lock.unlock();

    w_fut.wait();
    t_fut.wait();

    async_task_eos.erase(o);
  };

  // TODO: lookahead > 1 is required.
  if (swap_lookahead == 1) {
    if (async_task_eos.count(order) == 1)
      waitComplete(order);

    auto load_weight = loadAsync(weight_pool, order + 1);
    auto load_tensor = loadAsync(tensor_pool, order + 1);

    NNTR_THROW_IF(load_weight < 0 || load_tensor < 0, std::runtime_error)
      << "Failed to launch preloading task";
    async_task_eos[order + 1] = std::make_tuple(load_weight, load_tensor);
  }

  weight_pool.flushCacheExcept(order);
  tensor_pool.flushCacheExcept(order);
}
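
/*
 * Preloading sketch for the lookahead path above: with swap_lookahead == 1
 * and current execution order N, any pending preload registered for N is
 * awaited first, asynchronous loads for order N + 1 are then launched on both
 * pools and recorded in async_task_eos, and finally everything except order N
 * is flushed from the caches.
 */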

void Manager::finalizeTensorPool(TensorPool &pool, unsigned int start,
                                 unsigned int end) {
  if (enable_optimizations)
    pool.finalize(OptimizedV1Planner(), start, end);
  else
    pool.finalize(BasicPlanner(), start, end);
}

} // namespace nntrainer