1bb14eaf6f2f8ccf7090f46f245a0e73112f8230
[platform/core/ml/nntrainer.git] / nntrainer / tensor / manager.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4  *
5  * @file   manager.cpp
6  * @date   2 Dec 2020
7  * @brief  This is NNtrainer manager for all weights, i/o and intermediate
8  * tensors
9  * @see    https://github.com/nnstreamer/nntrainer
10  * @author Parichay Kapoor <pk.kapoor@samsung.com>
11  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
12  * @bug    No known bugs except for NYI items
13  *
14  */
15
16 #ifdef __ANDROID__
17 #include <android/sharedmem.h>
18 #endif
19
20 #ifdef DEBUG
21 #include <cassert>
22 #endif
23 #include <fcntl.h>
24 #include <functional>
25 #include <limits>
26 #include <stdexcept>
27 #include <sys/mman.h>
28 #include <sys/stat.h>
29 #include <unistd.h>
30 #include <vector>
31
32 #include <activation_layer.h>
33 #include <basic_planner.h>
34 #include <bn_layer.h>
35 #include <graph_node.h>
36 #include <layer_node.h>
37 #include <layer_normalization_layer.h>
38 #include <manager.h>
39 #include <multiout_layer.h>
40 #include <nntrainer_log.h>
41 #include <optimized_v1_planner.h>
42 #include <tensor_pool.h>
43 #include <tensor_wrap_specs.h>
44 #include <util_func.h>
45 #include <var_grad.h>
46
47 namespace nntrainer {
48 MMapedMemory::MMapedMemory(size_t size, bool allocate_fd_) :
49   fd(-1),
50   buf(nullptr),
51   buf_size(0),
52   allocate_fd(allocate_fd_) {
53
54 #ifndef __ANDROID__
55   if (allocate_fd) {
56     /// @todo create a file in tmpfs and bind to memfs
57     /// memfd_create is not available for number of platforms so this is
58     /// commented
59     // auto fd_ = memfd_create("", 0);
60     // if (fd_ < 0) {
61     //   throw std::runtime_error("[Manager] creating mem fd failed");
62     // }
63     // if (ftruncate(fd_, size) < 0) {
64     //   throw std::runtime_error("[Manager] truncating fd failed");
65     // }
66     ml_logi("[MMapedMemory] fd creation is not supported in this platform");
67     allocate_fd = false;
68   }
69 #endif
70   int fd_ = -1;
71   void *buf_ = nullptr;
72
73   if (allocate_fd) {
74 #ifdef __ANDROID__
75     /// unfortunately, memfd_create is not supported before android level 30
76     fd_ = ASharedMemory_create("", size);
77     if (fd_ < 0) {
78       throw std::runtime_error("[MMapedMemory] creating mem fd failed");
79     }
80
81     if (ASharedMemory_setProt(fd_, PROT_READ | PROT_WRITE) < 0) {
82       // unlink / close the given fd here
83       close(fd_);
84       throw std::runtime_error("[MMapedMemory] Setting prot failed");
85     }
86
87     buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
88 #endif
89   } else {
90     buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS,
91                 fd_, 0);
92   }
93
94   if (buf_ == MAP_FAILED) {
95 #ifdef __ANDROID__
96     if (fd_ != -1) {
97       // unlink / close the given fd here
98       close(fd_);
99     }
100 #endif
101
102     throw std::runtime_error("[MMapedMemory] mmap failed");
103   }
104
105   fd = fd_;
106   buf = buf_;
107   buf_size = size;
108
109   ml_logd("[MMapedMemory] memory acquired size: %zu, fd: %d, addr: %p",
110           buf_size, fd, buf);
111 }
112
113 MMapedMemory::~MMapedMemory() noexcept {
114 #ifdef DEBUG
115   assert(buf_size > 0 && fd > 0);
116 #endif
117
118   if (fd != -1) {
119     if (close(fd) < 0) {
120       ml_logw("[MMapedMemory] closing fd failed on destruction please check");
121     }
122   }
123
124   if (buf != nullptr) {
125     if (munmap(buf, buf_size) < 0) {
126       ml_logw("[MMapedMemory] munmap failed on destruction please check");
127     }
128   }
129
130   /// keeping the invariant although this is not necessary as of now
131   fd = -1;
132   buf = nullptr;
133   buf_size = 0;
134   ml_logd("[MMapedMemory] buf released");
135 }
136
137 void Manager::allocateWeights(unsigned int max_exec_order_) {
138   if (!weight_pool.isAllocated()) {
139     finalizeTensorPool(weight_pool, 0, max_exec_order_);
140     weight_pool.allocate();
141   }
142 }
143
144 void Manager::deallocateWeights() { weight_pool.deallocate(); }
145
146 static Tensor *requestTensor_(const TensorSpecV2 &spec,
147                               const GraphNode::ExecutionOrder &exec_order,
148                               const std::string &scope, TensorPool &tp,
149                               bool expose) {
150   using RT = TensorSpecV2::RequestType;
151   using LS = TensorLifespan;
152   NNTR_THROW_IF(spec.request_type == RT::MAYBE_MODIFYING_VIEW,
153                 std::invalid_argument)
154     << "Modifying view cannot be requested, the request type has to be "
155        "delegated to either view or unique";
156
157   auto [forward, calc_grad, calc_deriv, apply_grad] = exec_order;
158
159   std::vector<unsigned> order = spec.additional_exec_order;
160   if (expose) {
161     order.push_back(TensorPool::PERSIST_END_ORDER);
162   }
163
164   const auto name = scope + ":" + spec.name;
165
166   if (enum_class_or(spec.ls, LS::FORWARD_FUNC_LIFESPAN) == spec.ls) {
167     order.push_back(forward);
168   }
169   if (enum_class_or(spec.ls, LS::CALC_GRAD_LIFESPAN) == spec.ls) {
170     order.push_back(calc_grad);
171   }
172   if (enum_class_or(spec.ls, LS::CALC_DERIV_LIFESPAN) == spec.ls) {
173     order.push_back(calc_deriv);
174   }
175   if (enum_class_or(spec.ls, LS::CALC_AGRAD_LIFESPAN) == spec.ls) {
176     order.push_back(apply_grad);
177   }
178
179   switch (spec.request_type) {
180   case RT::PLACEHOLDER:
181     return tp.placeholder(name, spec.dim);
182   case RT::UNIQUE:
183     return tp.request(name, spec.dim, order, spec.ls, spec.initializer);
184   case RT::SHARED:
185     return tp.requestOrExtend(name, spec.dim, order, spec.ls, spec.initializer);
186   case RT::READ_ONLY_VIEW:
187     return tp.view(name, spec.reference_name, spec.dim, order, spec.ls);
188   case RT::MAYBE_MODIFYING_VIEW:
189   default:
190     throw std::logic_error("requestTensor_ should not reach here");
191   }
192
193   return nullptr;
194 }
195
196 Var_Grad *Manager::requestTensor(const VarGradSpecV2 &spec,
197                                  TensorGroupType identify_as,
198                                  const GraphNode::ExecutionOrder &exec_order,
199                                  const std::string &scope, bool expose_var,
200                                  bool expose_grad) {
201   NNTR_THROW_IF(identify_as == TensorGroupType::WEIGHT, std::invalid_argument)
202     << "requestTensor with var grad spec cannot be identified as weights, use "
203        "requestTensor with weight spec instead";
204
205   NNTR_THROW_IF(identify_as == TensorGroupType::INPUT or
206                   identify_as == TensorGroupType::TENSORS,
207                 nntrainer::exception::not_supported)
208     << "Currently, input and tensors group type is not yet implemented, use "
209        "requestInputs() requestTensors() instead";
210
211   Tensor *var = requestTensor_(spec.variable_spec, exec_order, scope,
212                                tensor_pool, expose_var);
213   Tensor *grad = spec.gradient_spec
214                    ? requestTensor_(*spec.gradient_spec, exec_order, scope,
215                                     tensor_pool, expose_grad)
216                    : nullptr;
217
218   /// @note as only supporting identify_as == TensorGroupType::output, only
219   /// saves to outputs for now
220   outputs_v2.push_back(std::make_unique<Var_Grad>(var, grad));
221
222   return outputs_v2.back().get();
223 }
224
225 std::vector<Var_Grad *> Manager::requestTensors(
226   const std::vector<VarGradSpecV2> &specs, TensorGroupType identify_as,
227   const GraphNode::ExecutionOrder &exec_order, const std::string &scope,
228   bool expose_var, bool expose_grad) {
229   std::vector<Var_Grad *> ret;
230   ret.reserve(specs.size());
231   for (auto &spec : specs) {
232     ret.push_back(requestTensor(spec, identify_as, exec_order, scope,
233                                 expose_var, expose_grad));
234   }
235
236   return ret;
237 }
238
239 /**
240  * @brief Allocate memory for all the managed tensors
241  */
242 void Manager::allocateTensors(unsigned int max_exec_order_) {
243   allocateWeights(max_exec_order_);
244
245   if (!tensor_pool.isAllocated()) {
246     finalizeTensorPool(tensor_pool, 0, max_exec_order_);
247     tensor_pool.allocate();
248   }
249 }
250
251 /**
252  * @brief Deallocate memory for all the managed tensors
253  */
254 void Manager::deallocateTensors(bool dealloc_weights) {
255   if (dealloc_weights)
256     deallocateWeights();
257
258   tensor_pool.deallocate();
259 }
260
261 #ifdef LAYER_V1
262 void Manager::initializeTensorsInference(unsigned int max_exec_order_) {
263   /**
264    * A single buffer (shared_inout) provides memory for inputs and outputs of a
265    * layer. Further, the output of layer i shares memory with input with layer
266    * i+1. So, each alternate layer allocates memory from either the start of the
267    * buffer or the end of the buffer, and use_first_last tracks this
268    *
269    * @note Label for the last layer is not initialized in inference.
270    * @note Input for the first layer is not initialized in inference.
271    */
272   // Initialize shared input/output memory for inference
273   // @note Memory for label is not allocated here as inference doesnt has label
274   if (enable_inference_inout_memory_opt)
275     shared_inout = Tensor(TensorDim({max_shared_inout}), false);
276
277   bool use_first_last = 0;
278   for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
279     auto &l_io = in_outs[idx];
280     unsigned int offset = 0;
281     bool is_first_layer = idx == 0;
282
283     // For flatten layer, do not assign new memory
284     if (idx > 0 && is_flat_type[idx])
285       use_first_last = 1 - use_first_last;
286
287     // In inference mode, do not allocate the memory for the input of the
288     // first layer. These is the first entry in the in_outs. Inference() will
289     // override input tensors of the first layer
290     if (is_first_layer)
291       continue;
292
293     for (auto &io : l_io) {
294       Tensor shared_inout_cur = Tensor();
295       if (enable_inference_inout_memory_opt) {
296         // if optimized
297         if (use_first_last) {
298           // Create tensor with from the front of shared tensor
299           shared_inout_cur =
300             shared_inout.getSharedDataTensor(io->getDim(), offset);
301         } else {
302           // Create tensor with from the back of shared tensor
303           shared_inout_cur = shared_inout.getSharedDataTensor(
304             io->getDim(),
305             max_shared_inout - io->getDim().getDataLen() - offset);
306         }
307         offset += io->getDim().getDataLen();
308       }
309       io->initialize(shared_inout_cur, Tensor(), false);
310     }
311     use_first_last = 1 - use_first_last;
312   }
313 }
314
315 void Manager::initializeTensorsTrain(unsigned int max_exec_order_) {
316   // Initialize gradients
317   initializeGradients();
318
319   // Initialize shared derivative memory
320   if (max_derivative_size > 0 && enable_activation_memory_opt)
321     shared_deriv = Tensor(TensorDim({max_derivative_size}), false);
322   for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
323     auto &l_io = in_outs[idx];
324     unsigned int offset = 0;
325     bool is_last_layer = idx == in_outs.size() - 1;
326
327     for (auto &io : l_io) {
328       // Last layer requires separate memory allocations for output and label
329       // (deriv)
330       if (enable_derivative_memory_opt && !is_last_layer) {
331         // Training Mode with optimizations
332         if (enable_activation_memory_opt &&
333             (is_rnn_type[idx] || is_act_type[idx])) {
334           io->initialize(
335             Tensor(), shared_deriv.getSharedDataTensor(io->getDim(), offset));
336           offset += io->getDim().getDataLen();
337         } else {
338           io->initializeShared();
339         }
340
341       } else {
342         // Training Mode without optimizations
343         io->initialize(Tensor(), Tensor(), true);
344       }
345     }
346   }
347 }
348 #endif
349
350 /**
351  * @brief     Create weights with the given spec
352  *
353  */
354 std::vector<Weight *> Manager::requestWeights(
355   const GraphNode &node, const std::vector<Weight::Spec> &weights_spec,
356   bool trainable, const std::vector<std::string> &shared_names) {
357   const auto [forwarding_order, calcGradient_order, calcDerivative_order, applyGradient_order] =
358     node.getExecutionOrder();
359
360   std::vector<unsigned int> default_var_exec_order(
361     {forwarding_order, calcDerivative_order});
362
363   /**
364    *  TODO: This needs to be fixed. calcDerivative does not needs the gradient.
365    *  However, current implementation of loss needs the gradient computation.
366    *  and therefore, if we remove the calcDerivative order, then tests fails.
367    */
368
369   TensorLifespan var_ls = TensorLifespan::MAX_LIFESPAN;
370   TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN;
371
372   std::vector<Weight *> ret;
373   size_t current_size = weights_v2.size();
374
375   for (unsigned int i = 0; i < weights_spec.size(); ++i) {
376     auto &[dim, t_initializer, w_reg, w_reg_const, decay, clip_by_global_norm,
377            need_gradient, name] = weights_spec.at(i);
378     auto var_exec_order = default_var_exec_order;
379     std::vector<unsigned int> grad_exec_order;
380
381     if (trainable) {
382       var_exec_order.push_back(calcGradient_order);
383       var_exec_order.push_back(applyGradient_order);
384       grad_exec_order.push_back(calcGradient_order);
385       grad_exec_order.push_back(applyGradient_order);
386     }
387
388     /**
389      * If the weight is supposed to be clip by global norm, extend its exec
390      * order with the max exec order where it will be used for clipping and then
391      * applied to the weight.
392      */
393     if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm)) {
394       grad_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
395       var_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
396     }
397
398     Tensor *var = nullptr, *grad = nullptr;
399     bool is_dependent = !shared_names.empty();
400     if (is_dependent) {
401       /// shared_name is used and the orignal name is discarded
402       const auto &shared_name = shared_names.at(i);
403       /** case when shared names are given */
404       var = weight_pool.requestOrExtend(shared_name, dim, var_exec_order,
405                                         var_ls, t_initializer);
406
407       if (trainable && need_gradient) {
408         /** We cannot use the tensor schedulding for weight gradient if the
409          * weight is shared. Weight Sharing means, the gradient is not temporal
410          * for each layer anymore and it is hard to overwritten.
411          */
412         grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
413                                            dim, grad_exec_order, grad_ls,
414                                            Tensor::Initializer::ZEROS);
415       }
416     } else {
417       /** case requesting fresh weights */
418       var =
419         weight_pool.request(name, dim, var_exec_order, var_ls, t_initializer);
420
421       if (trainable && need_gradient) {
422         /** is_wgrad is the index which is true when it is the gradient tensor
423          * of weight. If it is true, memory planner schedule based on it to
424          * reduce the memory.
425          */
426         bool is_wgrad = true;
427         if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
428           is_wgrad = false;
429         grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim,
430                                    grad_exec_order, grad_ls,
431                                    Tensor::Initializer::ZEROS, is_wgrad);
432       }
433     }
434
435     weights_v2.emplace_back(std::make_unique<Weight>(
436       var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm));
437   }
438
439   std::transform(weights_v2.begin() + current_size, weights_v2.end(),
440                  std::back_inserter(ret),
441                  [](auto const &elem) { return elem.get(); });
442
443   return ret;
444 }
445
446 /**
447  * @brief     Create tensors with the given spec
448  *
449  */
450 std::vector<Var_Grad *> Manager::requestTensors(
451   const GraphNode &node, const std::vector<Var_Grad::Spec> &tensors_spec,
452   bool trainable, const std::vector<std::string> &shared_names) {
453   const auto [forwarding_order, calcGradient_order, calcDerivative_order, applyGradient_order] =
454     node.getExecutionOrder();
455
456   std::vector<Var_Grad *> ret;
457   size_t current_size = tensors_v2.size();
458
459   for (unsigned int i = 0; i < tensors_spec.size(); ++i) {
460     auto const &[dim, t_init, need_grad, name, tspan] = tensors_spec.at(i);
461
462     std::vector<unsigned int> var_exec_order;
463     std::vector<unsigned int> grad_exec_order;
464
465     /** usage for tensors */
466     if (enum_class_logical_and(tspan, TensorLifespan::FORWARD_FUNC_LIFESPAN))
467       var_exec_order.push_back(forwarding_order);
468
469     /** usage for tensors gradient in backwarding */
470     if (trainable &&
471         enum_class_logical_and(tspan, TensorLifespan::CALC_GRAD_LIFESPAN)) {
472       var_exec_order.push_back(calcGradient_order);
473       grad_exec_order.push_back(calcGradient_order);
474     }
475
476     if (enum_class_logical_and(tspan, TensorLifespan::CALC_DERIV_LIFESPAN)) {
477       var_exec_order.push_back(calcDerivative_order);
478       grad_exec_order.push_back(calcDerivative_order);
479     }
480
481     if (trainable && enum_class_logical_and(tspan, TensorLifespan::CALC_AGRAD_LIFESPAN)) {
482       var_exec_order.push_back(applyGradient_order);
483       grad_exec_order.push_back(applyGradient_order);
484     }
485
486     bool is_dependent = !shared_names.empty();
487     Tensor *var = nullptr, *grad = nullptr;
488
489     if (is_dependent) {
490       const auto &shared_name = shared_names.at(i);
491       var = tensor_pool.requestOrExtend(shared_name, dim, var_exec_order, tspan,
492                                         t_init);
493       if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
494         grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
495                                            dim, grad_exec_order, tspan,
496                                            Tensor::Initializer::ZEROS);
497       }
498     } else {
499       var = tensor_pool.request(name, dim, var_exec_order, tspan, t_init);
500
501       if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
502         grad =
503           tensor_pool.request(name + Var_Grad::grad_suffix, /// name
504                               dim, grad_exec_order, tspan,
505                               Tensor::Initializer::ZEROS /// tensor initializer
506           );
507       }
508     }
509
510     tensors_v2.emplace_back(std::make_unique<Var_Grad>(var, grad));
511   }
512
513   std::transform(tensors_v2.begin() + current_size, tensors_v2.end(),
514                  std::back_inserter(ret),
515                  [](auto const &elem) { return elem.get(); });
516
517   return ret;
518 }
519
520 /**
521  * @brief     Create tensors with the given spec
522  */
523 std::vector<Var_Grad *>
524 Manager::requestInputs(const GraphNode &node,
525                        const std::vector<TensorDim> &inputs_dim,
526                        const std::vector<std::string> &outputs_name) {
527   using RT = TensorSpecV2::RequestType;
528
529   TensorSpecV2 var_common_spec, grad_common_spec;
530   var_common_spec.ls = TensorLifespan::FORWARD_GRAD_LIFESPAN;
531   grad_common_spec.ls = TensorLifespan::CALC_DERIV_LIFESPAN;
532
533   /// @todo handle this inside layer
534   if (node.getType() == ActivationLayer::type or
535       node.getType() == MultiOutLayer::type or
536       node.getType() == BatchNormalizationLayer::type or
537       node.getType() == LayerNormalizationLayer::type)
538     var_common_spec.ls = TensorLifespan::FORWARD_FUNC_LIFESPAN;
539
540   std::vector<Var_Grad *> ret;
541   size_t current_size = inputs_v2.size();
542
543   for (unsigned int idx = 0; idx < inputs_dim.size(); idx++) {
544     TensorSpecV2 var_spec = var_common_spec, grad_spec = grad_common_spec;
545
546     var_spec.name = std::string("input") + std::to_string(idx);
547     var_spec.dim = inputs_dim[idx];
548
549     grad_spec.name = var_spec.name + Var_Grad::grad_suffix;
550     grad_spec.dim = inputs_dim[idx];
551
552     if (!outputs_name.empty()) {
553       grad_spec.request_type = var_spec.request_type = RT::READ_ONLY_VIEW;
554       var_spec.reference_name = outputs_name[idx];
555       grad_spec.reference_name = outputs_name[idx] + Var_Grad::grad_suffix;
556     } else if (!node.getInputConnections().empty()) {
557       grad_spec.request_type = var_spec.request_type = RT::UNIQUE;
558     } else {
559       var_spec.request_type = RT::PLACEHOLDER;
560
561 #ifdef ENABLE_TEST
562       grad_spec.request_type = RT::UNIQUE;
563 #else
564       grad_spec.request_type = RT::PLACEHOLDER;
565 #endif
566     }
567
568     inputs_v2.emplace_back(std::make_unique<Var_Grad>(
569       requestTensor_(var_spec, node.getExecutionOrder(), node.getName(),
570                      tensor_pool, false),
571       requestTensor_(grad_spec, node.getExecutionOrder(), node.getName(),
572                      tensor_pool, false)));
573   }
574
575   ret.reserve(inputs_dim.size());
576   std::transform(inputs_v2.begin() + current_size, inputs_v2.end(),
577                  std::back_inserter(ret),
578                  [](auto const &elem) { return elem.get(); });
579
580   return ret;
581 }
582
583 std::pair<unsigned int, unsigned int>
584 Manager::getMinMaxTensorExecutionOrder(const std::string &name,
585                                        bool is_weight) {
586
587   auto orders = is_weight ? weight_pool.getExecutionOrder(name)
588                           : tensor_pool.getExecutionOrder(name);
589   auto [min_, max_] = std::minmax_element(orders.begin(), orders.end());
590   return {*min_, *max_};
591 }
592
593 unsigned int Manager::getSecondMaxTensorExecutionOrder(const std::string &name,
594                                                        bool is_weight) {
595
596   auto orders = is_weight ? weight_pool.getExecutionOrder(name)
597                           : tensor_pool.getExecutionOrder(name);
598   if (orders.size() < 2)
599     throw std::runtime_error(
600       "Requesting second last access with less than 2 exec orders");
601   /** tensor pool exec order can have same exec order multiple times */
602   std::sort(orders.begin(), orders.end());
603   orders.erase(std::unique(orders.begin(), orders.end()), orders.end());
604   return orders[orders.size() - 2];
605 }
606
607 bool Manager::isFirstAccess(const std::string &name, unsigned current_execution,
608                             bool is_weight) {
609   /// @todo add cache machanism, eg) sort at finalizing requesting
610   return getMinMaxTensorExecutionOrder(name, is_weight).first ==
611          current_execution;
612 }
613
614 bool Manager::isLastAccess(const std::string &name, unsigned current_execution,
615                            bool is_weight) {
616   /// @todo add cache machanism, eg) sort at finalizing requesting
617   return getMinMaxTensorExecutionOrder(name, is_weight).second ==
618          current_execution;
619 }
620
621 bool Manager::isSecondLastAccess(const std::string &name,
622                                  unsigned current_execution, bool is_weight) {
623   /// @todo add cache machanism, eg) sort at finalizing requesting
624   return getSecondMaxTensorExecutionOrder(name, is_weight) == current_execution;
625 }
626
627 /**
628  * @brief     Create tensors with the given spec
629  *
630  */
631 std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
632   const std::vector<TensorDim> &dims, const std::string &name,
633   const TensorLifespan &lifespan, bool is_grad_clip,
634   Tensor::Initializer initializer) {
635   auto const exec_order = weight_pool.getExecutionOrder(name);
636
637   std::vector<Tensor *> ret;
638   ret.reserve(dims.size());
639
640   std::vector<unsigned int> exec;
641   exec.reserve(1);
642   if (is_grad_clip) {
643     exec.emplace_back(TensorPool::PERSIST_END_ORDER);
644   } else {
645     exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second);
646   }
647
648   /// @note this is assuming weight optimizer variables is treated as weight, if
649   /// not, there is room to optimize below behavior
650   for (unsigned int idx = 0; idx < dims.size(); idx++)
651     ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx),
652                                       dims[idx], exec, lifespan, initializer));
653
654   return ret;
655 }
656
657 std::vector<Weight *>
658 Manager::getWeights(const std::function<bool(const Weight *)> &condition) {
659   std::vector<Weight *> conditional_weights;
660
661   for (auto &w : weights_v2) {
662     if (!condition || condition(w.get()))
663       conditional_weights.push_back(w.get());
664   }
665
666   return conditional_weights;
667 }
668
669 void Manager::flushCache() {
670   if (!swap_lookahead) {
671     weight_pool.flushCache();
672     tensor_pool.flushCache();
673   }
674 }
675
676 void Manager::flushCacheExcept(unsigned int order) {
677   auto loadAsync = [&](TensorPool &pool, unsigned int order) {
678     return pool.loadCacheExecAsync(
679       order, [&](int id, TaskExecutor::CompleteStatus status) {
680         std::scoped_lock<std::mutex> lock(completed_mutex);
681         completed[id].set_value(true);
682       });
683   };
684
685   auto waitComplete = [&](unsigned int o) {
686     auto &tasks = async_task_eos[o];
687
688     std::unique_lock<std::mutex> lock(completed_mutex);
689     auto w_fut = completed[std::get<0>(tasks)].get_future();
690     auto t_fut = completed[std::get<1>(tasks)].get_future();
691     lock.unlock();
692
693     w_fut.wait();
694     t_fut.wait();
695
696     async_task_eos.erase(o);
697   };
698
699   // TODO: lookahead > 1 is required.
700   if (swap_lookahead == 1) {
701     if (async_task_eos.count(order) == 1)
702       waitComplete(order);
703
704     auto load_weight = loadAsync(weight_pool, order + 1);
705     auto load_tensor = loadAsync(tensor_pool, order + 1);
706
707     NNTR_THROW_IF(load_weight < 0 || load_tensor < 0, std::runtime_error)
708       << "Failed to launch preloading task";
709     async_task_eos[order + 1] = std::make_tuple(load_weight, load_tensor);
710   } else {
711     weight_pool.flushCacheExcept(order);
712     tensor_pool.flushCacheExcept(order);
713   }
714 }
715
716 void Manager::finalizeTensorPool(TensorPool &pool, unsigned int start,
717                                  unsigned int end) {
718   if (enable_optimizations)
719     pool.finalize(OptimizedV1Planner(), start, end);
720   else
721     pool.finalize(BasicPlanner(), start, end);
722 }
723
724 } // namespace nntrainer