[Tensor] Remove calcGrad step for trainable layer
[platform/core/ml/nntrainer.git] / nntrainer / tensor / manager.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4  *
5  * @file   manager.cpp
6  * @date   2 Dec 2020
7  * @brief  This is NNtrainer manager for all weights, i/o and intermediate
8  * tensors
9  * @see    https://github.com/nnstreamer/nntrainer
10  * @author Parichay Kapoor <pk.kapoor@samsung.com>
11  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
12  * @bug    No known bugs except for NYI items
13  *
14  */
15
16 #ifdef __ANDROID__
17 #include <android/sharedmem.h>
18 #endif
19
20 #ifdef DEBUG
21 #include <cassert>
22 #endif
23 #include <fcntl.h>
24 #include <functional>
25 #include <limits>
26 #include <stdexcept>
27 #include <sys/mman.h>
28 #include <sys/stat.h>
29 #include <unistd.h>
30 #include <vector>
31
32 #include <activation_layer.h>
33 #include <basic_planner.h>
34 #include <bn_layer.h>
35 #include <graph_node.h>
36 #include <layer_node.h>
37 #include <layer_normalization_layer.h>
38 #include <manager.h>
39 #include <multiout_layer.h>
40 #include <nntrainer_log.h>
41 #include <optimized_v1_planner.h>
42 #include <tensor_pool.h>
43 #include <tensor_wrap_specs.h>
44 #include <util_func.h>
45 #include <var_grad.h>
46
47 namespace nntrainer {
48 MMapedMemory::MMapedMemory(size_t size, bool allocate_fd_) :
49   fd(-1),
50   buf(nullptr),
51   buf_size(0),
52   allocate_fd(allocate_fd_) {
53
54 #ifndef __ANDROID__
55   if (allocate_fd) {
56     /// @todo create a file in tmpfs and bind to memfs
57     /// memfd_create is not available for number of platforms so this is
58     /// commented
59     // auto fd_ = memfd_create("", 0);
60     // if (fd_ < 0) {
61     //   throw std::runtime_error("[Manager] creating mem fd failed");
62     // }
63     // if (ftruncate(fd_, size) < 0) {
64     //   throw std::runtime_error("[Manager] truncating fd failed");
65     // }
66     ml_logi("[MMapedMemory] fd creation is not supported in this platform");
67     allocate_fd = false;
68   }
69 #endif
70   int fd_ = -1;
71   void *buf_ = nullptr;
72
73   if (allocate_fd) {
74 #ifdef __ANDROID__
75     /// unfortunately, memfd_create is not supported before android level 30
76     fd_ = ASharedMemory_create("", size);
77     if (fd_ < 0) {
78       throw std::runtime_error("[MMapedMemory] creating mem fd failed");
79     }
80
81     if (ASharedMemory_setProt(fd_, PROT_READ | PROT_WRITE) < 0) {
82       // unlink / close the given fd here
83       close(fd_);
84       throw std::runtime_error("[MMapedMemory] Setting prot failed");
85     }
86
87     buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0);
88 #endif
89   } else {
90     buf_ = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS,
91                 fd_, 0);
92   }
93
94   if (buf_ == MAP_FAILED) {
95 #ifdef __ANDROID__
96     if (fd_ != -1) {
97       // unlink / close the given fd here
98       close(fd_);
99     }
100 #endif
101
102     throw std::runtime_error("[MMapedMemory] mmap failed");
103   }
104
105   fd = fd_;
106   buf = buf_;
107   buf_size = size;
108
109   ml_logd("[MMapedMemory] memory acquired size: %zu, fd: %d, addr: %p",
110           buf_size, fd, buf);
111 }
112
113 MMapedMemory::~MMapedMemory() noexcept {
114 #ifdef DEBUG
115   assert(buf_size > 0 && fd > 0);
116 #endif
117
118   if (fd != -1) {
119     if (close(fd) < 0) {
120       ml_logw("[MMapedMemory] closing fd failed on destruction please check");
121     }
122   }
123
124   if (buf != nullptr) {
125     if (munmap(buf, buf_size) < 0) {
126       ml_logw("[MMapedMemory] munmap failed on destruction please check");
127     }
128   }
129
130   /// keeping the invariant although this is not necessary as of now
131   fd = -1;
132   buf = nullptr;
133   buf_size = 0;
134   ml_logd("[MMapedMemory] buf released");
135 }
136
137 void Manager::allocateWeights(unsigned int max_exec_order_) {
138   if (!weight_pool.isAllocated()) {
139     finalizeTensorPool(weight_pool, 0, max_exec_order_);
140     weight_pool.allocate();
141   }
142 }
143
144 void Manager::deallocateWeights() { weight_pool.deallocate(); }
145
146 static Tensor *requestTensor_(const TensorSpecV2 &spec,
147                               const GraphNode::ExecutionOrder &exec_order,
148                               const std::string &scope, TensorPool &tp,
149                               bool expose) {
150   using RT = TensorSpecV2::RequestType;
151   using LS = TensorLifespan;
152   NNTR_THROW_IF(spec.request_type == RT::MAYBE_MODIFYING_VIEW,
153                 std::invalid_argument)
154     << "Modifying view cannot be requested, the request type has to be "
155        "delegated to either view or unique";
156
157   auto [forward, calc_grad, calc_deriv] = exec_order;
158
159   std::vector<unsigned> order = spec.additional_exec_order;
160   if (expose) {
161     order.push_back(TensorPool::PERSIST_END_ORDER);
162   }
163
164   const auto name = scope + ":" + spec.name;
165
166   if (enum_class_or(spec.ls, LS::FORWARD_FUNC_LIFESPAN) == spec.ls) {
167     order.push_back(forward);
168   }
169   if (enum_class_or(spec.ls, LS::CALC_GRAD_LIFESPAN) == spec.ls) {
170     order.push_back(calc_grad);
171   }
172   if (enum_class_or(spec.ls, LS::CALC_DERIV_LIFESPAN) == spec.ls) {
173     order.push_back(calc_deriv);
174   }
175
176   switch (spec.request_type) {
177   case RT::PLACEHOLDER:
178     return tp.placeholder(name, spec.dim);
179   case RT::UNIQUE:
180     return tp.request(name, spec.dim, order, spec.ls, spec.initializer);
181   case RT::SHARED:
182     return tp.requestOrExtend(name, spec.dim, order, spec.ls, spec.initializer);
183   case RT::READ_ONLY_VIEW:
184     return tp.view(name, spec.reference_name, spec.dim, order, spec.ls);
185   case RT::MAYBE_MODIFYING_VIEW:
186   default:
187     throw std::logic_error("requestTensor_ should not reach here");
188   }
189
190   return nullptr;
191 }
192
193 Var_Grad *Manager::requestTensor(const VarGradSpecV2 &spec,
194                                  TensorGroupType identify_as,
195                                  const GraphNode::ExecutionOrder &exec_order,
196                                  const std::string &scope, bool expose_var,
197                                  bool expose_grad) {
198   NNTR_THROW_IF(identify_as == TensorGroupType::WEIGHT, std::invalid_argument)
199     << "requestTensor with var grad spec cannot be identified as weights, use "
200        "requestTensor with weight spec instead";
201
202   NNTR_THROW_IF(identify_as == TensorGroupType::INPUT or
203                   identify_as == TensorGroupType::TENSORS,
204                 nntrainer::exception::not_supported)
205     << "Currently, input and tensors group type is not yet implemented, use "
206        "requestInputs() requestTensors() instead";
207
208   Tensor *var = requestTensor_(spec.variable_spec, exec_order, scope,
209                                tensor_pool, expose_var);
210   Tensor *grad = spec.gradient_spec
211                    ? requestTensor_(*spec.gradient_spec, exec_order, scope,
212                                     tensor_pool, expose_grad)
213                    : nullptr;
214
215   /// @note as only supporting identify_as == TensorGroupType::output, only
216   /// saves to outputs for now
217   outputs_v2.push_back(std::make_unique<Var_Grad>(var, grad));
218
219   return outputs_v2.back().get();
220 }
221
222 std::vector<Var_Grad *> Manager::requestTensors(
223   const std::vector<VarGradSpecV2> &specs, TensorGroupType identify_as,
224   const GraphNode::ExecutionOrder &exec_order, const std::string &scope,
225   bool expose_var, bool expose_grad) {
226   std::vector<Var_Grad *> ret;
227   ret.reserve(specs.size());
228   for (auto &spec : specs) {
229     ret.push_back(requestTensor(spec, identify_as, exec_order, scope,
230                                 expose_var, expose_grad));
231   }
232
233   return ret;
234 }
235
236 /**
237  * @brief Allocate memory for all the managed tensors
238  */
239 void Manager::allocateTensors(unsigned int max_exec_order_) {
240   allocateWeights(max_exec_order_);
241
242   if (!tensor_pool.isAllocated()) {
243     finalizeTensorPool(tensor_pool, 0, max_exec_order_);
244     tensor_pool.allocate();
245   }
246 }
247
248 /**
249  * @brief Deallocate memory for all the managed tensors
250  */
251 void Manager::deallocateTensors(bool dealloc_weights) {
252   if (dealloc_weights)
253     deallocateWeights();
254
255   tensor_pool.deallocate();
256 }
257
258 #ifdef LAYER_V1
259 void Manager::initializeTensorsInference(unsigned int max_exec_order_) {
260   /**
261    * A single buffer (shared_inout) provides memory for inputs and outputs of a
262    * layer. Further, the output of layer i shares memory with input with layer
263    * i+1. So, each alternate layer allocates memory from either the start of the
264    * buffer or the end of the buffer, and use_first_last tracks this
265    *
266    * @note Label for the last layer is not initialized in inference.
267    * @note Input for the first layer is not initialized in inference.
268    */
269   // Initialize shared input/output memory for inference
270   // @note Memory for label is not allocated here as inference doesnt has label
271   if (enable_inference_inout_memory_opt)
272     shared_inout = Tensor(TensorDim({max_shared_inout}), false);
273
274   bool use_first_last = 0;
275   for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
276     auto &l_io = in_outs[idx];
277     unsigned int offset = 0;
278     bool is_first_layer = idx == 0;
279
280     // For flatten layer, do not assign new memory
281     if (idx > 0 && is_flat_type[idx])
282       use_first_last = 1 - use_first_last;
283
284     // In inference mode, do not allocate the memory for the input of the
285     // first layer. These is the first entry in the in_outs. Inference() will
286     // override input tensors of the first layer
287     if (is_first_layer)
288       continue;
289
290     for (auto &io : l_io) {
291       Tensor shared_inout_cur = Tensor();
292       if (enable_inference_inout_memory_opt) {
293         // if optimized
294         if (use_first_last) {
295           // Create tensor with from the front of shared tensor
296           shared_inout_cur =
297             shared_inout.getSharedDataTensor(io->getDim(), offset);
298         } else {
299           // Create tensor with from the back of shared tensor
300           shared_inout_cur = shared_inout.getSharedDataTensor(
301             io->getDim(),
302             max_shared_inout - io->getDim().getDataLen() - offset);
303         }
304         offset += io->getDim().getDataLen();
305       }
306       io->initialize(shared_inout_cur, Tensor(), false);
307     }
308     use_first_last = 1 - use_first_last;
309   }
310 }
311
312 void Manager::initializeTensorsTrain(unsigned int max_exec_order_) {
313   // Initialize gradients
314   initializeGradients();
315
316   // Initialize shared derivative memory
317   if (max_derivative_size > 0 && enable_activation_memory_opt)
318     shared_deriv = Tensor(TensorDim({max_derivative_size}), false);
319   for (unsigned int idx = 0; idx < in_outs.size(); idx++) {
320     auto &l_io = in_outs[idx];
321     unsigned int offset = 0;
322     bool is_last_layer = idx == in_outs.size() - 1;
323
324     for (auto &io : l_io) {
325       // Last layer requires separate memory allocations for output and label
326       // (deriv)
327       if (enable_derivative_memory_opt && !is_last_layer) {
328         // Training Mode with optimizations
329         if (enable_activation_memory_opt &&
330             (is_rnn_type[idx] || is_act_type[idx])) {
331           io->initialize(
332             Tensor(), shared_deriv.getSharedDataTensor(io->getDim(), offset));
333           offset += io->getDim().getDataLen();
334         } else {
335           io->initializeShared();
336         }
337
338       } else {
339         // Training Mode without optimizations
340         io->initialize(Tensor(), Tensor(), true);
341       }
342     }
343   }
344 }
345 #endif
346
347 /**
348  * @brief     Create weights with the given spec
349  *
350  */
351 std::vector<Weight *> Manager::requestWeights(
352   const GraphNode &node, const std::vector<Weight::Spec> &weights_spec,
353   bool trainable, const std::vector<std::string> &shared_names) {
354   const auto [forwarding_order, calcGradient_order, calcDerivative_order] =
355     node.getExecutionOrder();
356
357   std::vector<unsigned int> default_var_exec_order(
358     {forwarding_order, calcDerivative_order});
359   std::vector<unsigned int> default_grad_exec_order({calcDerivative_order});
360
361   TensorLifespan var_ls = TensorLifespan::MAX_LIFESPAN;
362   TensorLifespan grad_ls = TensorLifespan::BACKWARD_FUNC_LIFESPAN;
363
364   std::vector<Weight *> ret;
365   size_t current_size = weights_v2.size();
366
367   for (unsigned int i = 0; i < weights_spec.size(); ++i) {
368     auto &[dim, t_initializer, w_reg, w_reg_const, decay, clip_by_global_norm,
369            need_gradient, name] = weights_spec.at(i);
370     auto var_exec_order = default_var_exec_order;
371     auto grad_exec_order = default_grad_exec_order;
372
373     if (trainable) {
374       var_exec_order.insert(var_exec_order.begin(), calcGradient_order);
375       grad_exec_order.insert(grad_exec_order.begin(), calcGradient_order);
376     }
377
378     /**
379      * If the weight is supposed to be clip by global norm, extend its exec
380      * order with the max exec order where it will be used for clipping and then
381      * applied to the weight.
382      */
383     if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
384       grad_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
385
386     Tensor *var = nullptr, *grad = nullptr;
387     bool is_dependent = !shared_names.empty();
388     if (is_dependent) {
389       /// shared_name is used and the orignal name is discarded
390       const auto &shared_name = shared_names.at(i);
391       /** case when shared names are given */
392       var = weight_pool.requestOrExtend(shared_name, dim, var_exec_order,
393                                         var_ls, t_initializer);
394
395       if (trainable && need_gradient) {
396         grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
397                                            dim, grad_exec_order, grad_ls,
398                                            Tensor::Initializer::ZEROS);
399       }
400     } else {
401       /** case requesting fresh weights */
402       var =
403         weight_pool.request(name, dim, var_exec_order, var_ls, t_initializer);
404
405       if (trainable && need_gradient)
406         grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim,
407                                    grad_exec_order, grad_ls,
408                                    Tensor::Initializer::ZEROS);
409     }
410
411     weights_v2.emplace_back(std::make_unique<Weight>(
412       var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm));
413   }
414
415   std::transform(weights_v2.begin() + current_size, weights_v2.end(),
416                  std::back_inserter(ret),
417                  [](auto const &elem) { return elem.get(); });
418
419   return ret;
420 }
421
422 /**
423  * @brief     Create weights with the given spec
424  *
425  */
426 std::vector<Var_Grad *> Manager::requestTensors(
427   const GraphNode &node, const std::vector<Var_Grad::Spec> &tensors_spec,
428   bool trainable, const std::vector<std::string> &shared_names) {
429   const auto [forwarding_order, calcGradient_order, calcDerivative_order] =
430     node.getExecutionOrder();
431
432   std::vector<Var_Grad *> ret;
433   size_t current_size = tensors_v2.size();
434
435   for (unsigned int i = 0; i < tensors_spec.size(); ++i) {
436     auto const &[dim, t_init, need_grad, name, tspan] = tensors_spec.at(i);
437
438     std::vector<unsigned int> var_exec_order;
439     std::vector<unsigned int> grad_exec_order;
440
441     /** usage for tensors */
442     if (enum_class_logical_and(tspan, TensorLifespan::FORWARD_FUNC_LIFESPAN))
443       var_exec_order.push_back(forwarding_order);
444
445     /** usage for tensors gradient in backwarding */
446     if (trainable &&
447         enum_class_logical_and(tspan, TensorLifespan::CALC_GRAD_LIFESPAN)) {
448       var_exec_order.push_back(calcGradient_order);
449       grad_exec_order.push_back(calcGradient_order);
450     }
451
452     if (enum_class_logical_and(tspan, TensorLifespan::CALC_DERIV_LIFESPAN)) {
453       var_exec_order.push_back(calcDerivative_order);
454       grad_exec_order.push_back(calcDerivative_order);
455     }
456
457     bool is_dependent = !shared_names.empty();
458     Tensor *var = nullptr, *grad = nullptr;
459
460     if (is_dependent) {
461       const auto &shared_name = shared_names.at(i);
462       var = tensor_pool.requestOrExtend(shared_name, dim, var_exec_order, tspan,
463                                         t_init);
464       if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
465         grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
466                                            dim, grad_exec_order, tspan,
467                                            Tensor::Initializer::ZEROS);
468       }
469     } else {
470       var = tensor_pool.request(name, dim, var_exec_order, tspan, t_init);
471
472       if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
473         grad =
474           tensor_pool.request(name + Var_Grad::grad_suffix, /// name
475                               dim, grad_exec_order, tspan,
476                               Tensor::Initializer::ZEROS /// tensor initializer
477           );
478       }
479     }
480
481     tensors_v2.emplace_back(std::make_unique<Var_Grad>(var, grad));
482   }
483
484   std::transform(tensors_v2.begin() + current_size, tensors_v2.end(),
485                  std::back_inserter(ret),
486                  [](auto const &elem) { return elem.get(); });
487
488   return ret;
489 }
490
491 /**
492  * @brief     Create tensors with the given spec
493  */
494 std::vector<Var_Grad *>
495 Manager::requestInputs(const GraphNode &node,
496                        const std::vector<TensorDim> &inputs_dim,
497                        const std::vector<std::string> &outputs_name) {
498   using RT = TensorSpecV2::RequestType;
499
500   TensorSpecV2 var_common_spec, grad_common_spec;
501   var_common_spec.ls = TensorLifespan::FORWARD_GRAD_LIFESPAN;
502   grad_common_spec.ls = TensorLifespan::CALC_DERIV_LIFESPAN;
503
504   /// @todo handle this inside layer
505   if (node.getType() == ActivationLayer::type or
506       node.getType() == MultiOutLayer::type or
507       node.getType() == BatchNormalizationLayer::type or
508       node.getType() == LayerNormalizationLayer::type)
509     var_common_spec.ls = TensorLifespan::FORWARD_FUNC_LIFESPAN;
510
511   std::vector<Var_Grad *> ret;
512   size_t current_size = inputs_v2.size();
513
514   for (unsigned int idx = 0; idx < inputs_dim.size(); idx++) {
515     TensorSpecV2 var_spec = var_common_spec, grad_spec = grad_common_spec;
516
517     var_spec.name = std::string("input") + std::to_string(idx);
518     var_spec.dim = inputs_dim[idx];
519
520     grad_spec.name = var_spec.name + Var_Grad::grad_suffix;
521     grad_spec.dim = inputs_dim[idx];
522
523     if (!outputs_name.empty()) {
524       grad_spec.request_type = var_spec.request_type = RT::READ_ONLY_VIEW;
525       var_spec.reference_name = outputs_name[idx];
526       grad_spec.reference_name = outputs_name[idx] + Var_Grad::grad_suffix;
527     } else if (!node.getInputConnections().empty()) {
528       grad_spec.request_type = var_spec.request_type = RT::UNIQUE;
529     } else {
530       var_spec.request_type = RT::PLACEHOLDER;
531
532 #ifdef ENABLE_TEST
533       grad_spec.request_type = RT::UNIQUE;
534 #else
535       grad_spec.request_type = RT::PLACEHOLDER;
536 #endif
537     }
538
539     inputs_v2.emplace_back(std::make_unique<Var_Grad>(
540       requestTensor_(var_spec, node.getExecutionOrder(), node.getName(),
541                      tensor_pool, false),
542       requestTensor_(grad_spec, node.getExecutionOrder(), node.getName(),
543                      tensor_pool, false)));
544   }
545
546   ret.reserve(inputs_dim.size());
547   std::transform(inputs_v2.begin() + current_size, inputs_v2.end(),
548                  std::back_inserter(ret),
549                  [](auto const &elem) { return elem.get(); });
550
551   return ret;
552 }
553
554 std::pair<unsigned int, unsigned int>
555 Manager::getMinMaxTensorExecutionOrder(const std::string &name,
556                                        bool is_weight) {
557
558   auto orders = is_weight ? weight_pool.getExecutionOrder(name)
559                           : tensor_pool.getExecutionOrder(name);
560   auto [min_, max_] = std::minmax_element(orders.begin(), orders.end());
561   return {*min_, *max_};
562 }
563
564 unsigned int Manager::getSecondMaxTensorExecutionOrder(const std::string &name,
565                                                        bool is_weight) {
566
567   auto orders = is_weight ? weight_pool.getExecutionOrder(name)
568                           : tensor_pool.getExecutionOrder(name);
569   if (orders.size() < 2)
570     throw std::runtime_error(
571       "Requesting second last access with less than 2 exec orders");
572   /** tensor pool exec order can have same exec order multiple times */
573   std::sort(orders.begin(), orders.end());
574   orders.erase(std::unique(orders.begin(), orders.end()), orders.end());
575   return orders[orders.size() - 2];
576 }
577
578 bool Manager::isFirstAccess(const std::string &name, unsigned current_execution,
579                             bool is_weight) {
580   /// @todo add cache machanism, eg) sort at finalizing requesting
581   return getMinMaxTensorExecutionOrder(name, is_weight).first ==
582          current_execution;
583 }
584
585 bool Manager::isLastAccess(const std::string &name, unsigned current_execution,
586                            bool is_weight) {
587   /// @todo add cache machanism, eg) sort at finalizing requesting
588   return getMinMaxTensorExecutionOrder(name, is_weight).second ==
589          current_execution;
590 }
591
592 bool Manager::isSecondLastAccess(const std::string &name,
593                                  unsigned current_execution, bool is_weight) {
594   /// @todo add cache machanism, eg) sort at finalizing requesting
595   return getSecondMaxTensorExecutionOrder(name, is_weight) == current_execution;
596 }
597
598 /**
599  * @brief     Create tensors with the given spec
600  *
601  */
602 std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
603   const std::vector<TensorDim> &dims, const std::string &name,
604   const TensorLifespan &lifespan, Tensor::Initializer initializer) {
605   auto const exec_order = weight_pool.getExecutionOrder(name);
606
607   std::vector<Tensor *> ret;
608   ret.reserve(dims.size());
609
610   /// @note this is assuming weight optimizer variables is treated as weight, if
611   /// not, there is room to optimize below behavior
612   for (unsigned int idx = 0; idx < dims.size(); idx++)
613     ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx),
614                                       dims[idx], exec_order, lifespan,
615                                       initializer));
616
617   return ret;
618 }
619
620 std::vector<Weight *>
621 Manager::getWeights(const std::function<bool(const Weight *)> &condition) {
622   std::vector<Weight *> conditional_weights;
623
624   for (auto &w : weights_v2) {
625     if (!condition || condition(w.get()))
626       conditional_weights.push_back(w.get());
627   }
628
629   return conditional_weights;
630 }
631
632 void Manager::flushCache() {
633   weight_pool.flushCache();
634   tensor_pool.flushCache();
635 }
636
637 void Manager::flushCacheExcept(unsigned int order) {
638   weight_pool.flushCacheExcept(order);
639   tensor_pool.flushCacheExcept(order);
640 }
641
642 void Manager::finalizeTensorPool(TensorPool &pool, unsigned int start,
643                                  unsigned int end) {
644   if (enable_optimizations)
645     pool.finalize(OptimizedV1Planner(), start, end);
646   else
647     pool.finalize(BasicPlanner(), start, end);
648 }
649
650 } // namespace nntrainer