[Tensor] Remove calcGrad step for trainable layer
[platform/core/ml/nntrainer.git] / nntrainer / graph / network_graph.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
4  *
5  * @file    network_graph.h
6  * @date    19 Oct 2020
7  * @see     https://github.com/nnstreamer/nntrainer
8  * @author  Jijoong Moon <jijoong.moon@samsung.com>
9  * @bug     No known bugs except for NYI items
10  * @brief   This is Network Graph Class for Neural Network
11  *
12  * @todo    Support multi-input graph.
13  */
14
15 #include "tensor.h"
16 #include <cmath>
17 #include <stdexcept>
18 #include <string>
19
20 #include <activation_layer.h>
21 #include <addition_layer.h>
22 #include <bn_layer.h>
23 #include <concat_layer.h>
24 #include <connection.h>
25 #include <cross_entropy_loss_layer.h>
26 #include <cross_entropy_sigmoid_loss_layer.h>
27 #include <cross_entropy_softmax_loss_layer.h>
28 #include <flatten_layer.h>
29 #include <identity_layer.h>
30 #include <input_layer.h>
31 #include <layer_node.h>
32 #include <layer_normalization_layer.h>
33 #include <multiout_layer.h>
34 #include <network_graph.h>
35 #include <nntrainer_error.h>
36 #include <nntrainer_log.h>
37 #include <profiler.h>
38 #include <rnn.h>
39 #include <split_layer.h>
40 #include <time_dist.h>
41 #include <util_func.h>
42
43 #define LNODE(x) std::static_pointer_cast<LayerNode>(x)
44
45 namespace nntrainer {
46
47 int NetworkGraph::compile(const std::string &loss_type) {
48   int status = ML_ERROR_NONE;
49
50   status = isCompilable();
51   NN_RETURN_STATUS();
52
53   try {
54     setOutputConnections();
55   } catch (std::exception &e) {
56     ml_loge("setting output layer failed, reason: %s", e.what());
57     return ML_ERROR_INVALID_PARAMETER;
58   }
59
60   graph.realizeInputOutputNode();
61
62   try {
63     /// @todo realize loss beforehand
64     status = addLossLayer(loss_type);
65     NN_RETURN_STATUS();
66   } catch (const std::exception &e) {
67     ml_loge("%s", e.what());
68     status = ML_ERROR_INVALID_PARAMETER;
69     NN_RETURN_STATUS();
70   }
71
72   graph.topologicalSort();
73
74   setExecutionOrder();
75   forward_iter_end = (*(cend() - 1)).get();
76
77   inPlaceOptimize();
78
79   status = checkCompiledGraph();
80   NN_RETURN_STATUS();
81
82   compiled = true;
83
84   return status;
85 }
86
87 void NetworkGraph::setExecutionOrder() {
88   auto max_count = graph.size() * 3;
89   /** @todo: remove backwarding count for non-trainble layers */
90   for (auto iter = cbegin(); iter != cend(); iter++) {
91     auto &node = *iter;
92     auto order_idx = iter - cbegin();
93     auto forward_order = order_idx;
94     auto calc_gradient_order = max_count - ((order_idx + 1) * 2);
95     /** calc derivative is called right after calc_gradient */
96     auto calc_derivative_order = calc_gradient_order + 1;
97     node->setExecutionOrder(
98       {forward_order, calc_gradient_order, calc_derivative_order});
99   }
100
101   /**
102    * This sets max execution order temporarily till model is initialized.
103    * This set max execution order is used to extend gradient exec orders for
104    * clipping.
105    */
106   graph_exec_end = std::get<2>((*(cbegin()))->getExecutionOrder());
107 }
108
109 void NetworkGraph::addLayerNode(std::unique_ptr<Layer> layer) {
110   graph.addNode(std::make_unique<LayerNode>(std::move(layer)));
111 }
112
113 int NetworkGraph::addLossLayer(const std::string &loss_type_) {
114   for (unsigned int i = 0; i < graph.getNumOutputNodes(); ++i) {
115     auto output_layer_node = LNODE(graph.getOutputNode(i));
116     std::string loss_type = loss_type_;
117
118     if (output_layer_node->requireLabel())
119       continue;
120
121     if (loss_type.empty())
122       continue;
123
124     auto second_to_last_layer_node = output_layer_node;
125     bool is_cross_entropy_loss =
126       istrequal(loss_type, CrossEntropyLossLayer::type);
127     if (is_cross_entropy_loss) {
128       auto type = output_layer_node->getType();
129
130       if (type != ActivationLayer::type) {
131         throw exception::not_supported(
132           "Error: Cross Entropy need last layer to have softmax or sigmoid"
133           "activation.");
134       }
135
136       switch (output_layer_node->getActivationType()) {
137       case ActivationType::ACT_SIGMOID:
138         loss_type = CrossEntropySigmoidLossLayer::type;
139         break;
140       case ActivationType::ACT_SOFTMAX:
141         loss_type = CrossEntropySoftmaxLossLayer::type;
142         break;
143       default:
144         throw exception::not_supported(
145           "Error: Cross Entropy not supported without softmax or sigmoid.");
146       }
147
148       second_to_last_layer_node =
149         LNODE(graph.getNode(output_layer_node->getInputConnectionName(0)));
150     }
151
152     std::shared_ptr<LayerNode> lnode = createLayerNode(loss_type);
153     graph.ensureName(*lnode);
154
155     if (second_to_last_layer_node->getDistribute()) {
156       lnode->setProperty({"distribute=true"});
157     }
158
159     /// @todo remove this by add loss at realization
160     second_to_last_layer_node->setOutputLayers({lnode->getName()});
161     lnode->setProperty(
162       {"input_layers=" + second_to_last_layer_node->getName()});
163
164     if (is_cross_entropy_loss) {
165       graph.replaceNode(output_layer_node, lnode);
166     } else {
167       graph.addNode(lnode, false);
168     }
169     graph.replaceOutputNode(i, lnode);
170   }
171
172   return ML_ERROR_NONE;
173 }
174
175 void NetworkGraph::setOutputConnections() {
176   for (auto layer_iter = cbegin(); layer_iter != cend(); layer_iter++) {
177     const auto &node = *layer_iter;
178     for (auto i = 0u, num_inode = node->getNumInputConnections(); i < num_inode;
179          ++i) {
180       const auto &name = node->getInputConnectionName(i);
181       const auto &idx = node->getInputConnectionIndex(i);
182
183       auto node_setting_output = getLayerNode(name);
184       node_setting_output->setOutputConnection(idx, node->getName(), i);
185     }
186   }
187 }
188
189 int NetworkGraph::isCompilable() {
190   if (compiled) {
191     ml_loge("Graph is already compiled");
192     return ML_ERROR_NOT_SUPPORTED;
193   }
194
195   if (graph.empty()) {
196     ml_loge("Graph is empty");
197     return ML_ERROR_INVALID_PARAMETER;
198   }
199
200   return ML_ERROR_NONE;
201 }
202
203 int NetworkGraph::checkCompiledGraph() {
204   /** Dimension of input layers must be known */
205   for (auto iter = cbegin(); iter != cend(); iter++) {
206     auto lnode = (*iter);
207     if (lnode->getNumInputConnections() == 0) {
208       if (!lnode->hasInputShapeProperty()) {
209         ml_loge("Layer with no inbound connection need input_shape property");
210         return ML_ERROR_INVALID_PARAMETER;
211       }
212     }
213   }
214
215   return ML_ERROR_NONE;
216 }
217
218 void NetworkGraph::markNodesForBackwarding() {
219   /** accumulate all the nodes which must support backwarding */
220   std::unordered_set<std::string> must_support_backwarding;
221
222   /**
223    * if a node is trainable, then all the nodes ahead of it must support
224    * backwarding operation
225    */
226   for (auto iter = cbegin(); iter != cend(); iter++) {
227     auto lnode = (*iter);
228     if (lnode->getTrainable() ||
229         must_support_backwarding.find(lnode->getName()) !=
230           must_support_backwarding.end()) {
231       if (lnode->getTrainable()) {
232         lnode->needsCalcGradient(true);
233       }
234 #ifdef ENABLE_TEST
235       if (lnode->supportBackwarding() && !optimize_memory) {
236         lnode->needsCalcDerivative(true);
237       }
238 #endif
239
240       for (auto i = 0u, num_node = lnode->getNumOutputConnections();
241            i < num_node; ++i) {
242         auto conn = lnode->getOutputConnection(i);
243         if (!conn) {
244           continue;
245         }
246
247         must_support_backwarding.insert(conn->getName());
248       }
249     }
250   }
251
252   /** mark all the required nodes support backwarding */
253   for (auto const &node_name : must_support_backwarding) {
254     auto ln = LNODE(graph.getNode(node_name)).get();
255     ln->needsCalcDerivative(true);
256   }
257 }
258
259 void NetworkGraph::setBatchSize(unsigned int batch_size) {
260   if (batch_size == this->batch_size)
261     return;
262
263   this->batch_size = batch_size;
264   if (!input_list.empty() && getInputDimension()[0].batch() == batch_size)
265     return;
266
267   auto allocated = tensor_manager->isAllocated();
268
269   if (allocated)
270     deallocateTensors();
271
272   for (auto iter = cbegin(); iter != cend(); iter++) {
273     if ((*iter)->isFinalized()) {
274       /// resize tensors spec
275       /// @todo remove below, if cutsom tensor needs to change dimension
276       /// according to the tensor, it must be done explicitly, or at least have
277       /// a property to control the behavior
278       const RunLayerContext &context = (*iter)->getRunContext();
279       for (unsigned int idx = 0; idx < context.getNumTensors(); idx++) {
280         auto const &ts = context.getTensor(idx);
281         tensor_manager->setBatchSize(ts.getName(), ts.getDim().batch());
282         if (context.tensorHasGradient(idx)) {
283           auto const &ts_grad = context.getTensorGrad(idx);
284           tensor_manager->setBatchSize(ts_grad.getName(),
285                                        ts_grad.getDim().batch());
286         }
287       }
288       /// override setting batch as per request
289       (*iter)->setBatch(batch_size);
290     }
291   }
292   /// resize input and output spec
293   tensor_manager->setBatchSize(batch_size);
294
295   if (allocated)
296     allocateTensors(exec_mode);
297
298   /** update input and label dimensions */
299   for (unsigned int idx = 0; idx < input_list.size(); idx++)
300     input_dims[idx] = tensor_manager->getTensor(input_list[idx])->getDim();
301   for (unsigned int idx = 0; idx < label_list.size(); idx++)
302     label_dims[idx] = tensor_manager->getTensor(label_list[idx])->getDim();
303 }
304
305 void NetworkGraph::applyGradients(
306   LayerNode *node, const std::function<void(Weight &)> &apply_func) {
307   auto &rc = node->getRunContext();
308   auto num_weight = rc.getNumWeights();
309   for (unsigned i = 0; i < num_weight; ++i) {
310     if (!rc.weightHasGradient(i)) {
311       continue;
312     }
313
314     if (!rc.isGradientLastAccess(i)) {
315       /// @note instead of checking the last access of the weight, checking
316       /// if weights are dependent to others to minimize overhead.
317       /// this logic assums that the source of the dependent weight must be
318       /// prior to the dependent.
319       continue;
320     }
321
322     if (rc.isGradientClipByGlobalNorm(i)) {
323       /**
324        * @note the weights whose gradient are to be clipped by global norm will
325        * be clipped at once at the end of iteration and applied then.
326        */
327       continue;
328     }
329
330     apply_func(rc.getWeightObject(i));
331   }
332 }
333
334 sharedConstTensors
335 NetworkGraph::forwarding(bool training,
336                          std::function<bool(void *userdata)> stop_cb) {
337   for (auto iter = cbegin(); iter != cend() && !stop_cb(nullptr); iter++) {
338     auto const &ln = *iter;
339     PROFILE_TIME_START(profile_keys.at(ln->getType()));
340     PROFILE_MEM_ANNOTATE("Forwarding for layer: " + ln->getName());
341
342     auto f = std::get<0>(ln->getExecutionOrder());
343     flushCacheExcept(f);
344
345     ln->forwarding(training);
346
347     PROFILE_TIME_END(profile_keys.at(ln->getType()));
348   }
349
350   sharedConstTensors out;
351   for (unsigned int i = 0; i < graph.getNumOutputNodes(); ++i) {
352     auto const &output_layer_node = LNODE(graph.getOutputNode(i));
353     for (unsigned int j = 0; j < output_layer_node->getNumOutputs(); ++j) {
354       out.push_back(MAKE_SHARED_TENSOR(output_layer_node->getOutput(j)));
355     }
356   }
357
358   return out;
359 }
360
361 void NetworkGraph::backwarding(
362   int iteration,
363   std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
364   std::function<void(Weight &, int)> &apply_grad_clip_op,
365   std::function<bool(void *userdata)> stop_cb) const {
366   /**
367    * last layer backwarding is run out of this loop
368    */
369   auto iter_begin = getBackwardingBeginIter();
370   auto iter_end = getBackwardingEndIter();
371
372   /// there is no layer to train, so backwarding is essentially noop
373   if (iter_begin == iter_end) {
374     return;
375   }
376
377   auto const &lptr_begin = (*iter_begin);
378
379   if (lptr_begin->requireLabel() == false)
380     throw std::runtime_error(
381       "Error: last layer does not accept label, we can't train");
382
383   for (auto iter = iter_begin; iter != iter_end && !stop_cb(nullptr); iter++) {
384     auto &ln = *iter;
385     PROFILE_TIME_START(profile_keys.at(ln->getType()));
386     backwarding_op(ln, iteration);
387     PROFILE_TIME_END(profile_keys.at(ln->getType()));
388   }
389
390   /** perform clipping of the gradients by global norm if any */
391   if (clip_weights.empty())
392     return;
393
394   /** calculate the global norm */
395   Tensor global_norm_t(
396     TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()}));
397   float *global_norm_data = global_norm_t.getData();
398   for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
399     auto const &w = clip_weights[idx];
400     global_norm_data[idx] = w->getGradientNorm();
401   }
402   float global_norm = global_norm_t.l2norm();
403   /** apply the gradient with the above global norm */
404   for (auto w : clip_weights) {
405     w->clipGradientByGlobalNorm(global_norm);
406   }
407   /** apply the gradient with the above global norm */
408   for (auto w : clip_weights) {
409     apply_grad_clip_op(*w, iteration);
410   }
411 }
412
413 LayerNode *NetworkGraph::computeBackwardEnd() {
414   int max_exec_order = -1;
415   LayerNode *node = nullptr;
416
417   if (!optimize_memory) {
418     return (*cbegin()).get();
419   }
420
421   for (auto iter = getBackwardingBeginIter(); iter != getBackwardingEndIter();
422        iter++) {
423     auto &ln = *iter;
424     const auto &exec_order = ln->getExecutionOrder();
425     int cur_order = std::get<0>(exec_order);
426     if (ln->needsCalcDerivative() || ln->needsCalcGradient()) {
427 #ifdef ENABLE_TEST
428       cur_order = std::get<2>(exec_order);
429 #else
430       cur_order = std::get<1>(exec_order);
431 #endif
432     }
433
434     NNTR_THROW_IF(max_exec_order == cur_order, std::invalid_argument)
435       << "layer node: " << ln->getName()
436       << " has duplicated max_exec_order, this should not happen, current "
437          "execution order: "
438       << max_exec_order;
439
440     if (max_exec_order < cur_order) {
441       max_exec_order = cur_order;
442       node = ln.get();
443     }
444   }
445
446   return node;
447 }
448
449 /**
450  * @brief Allocate memory for all the managed tensors
451  */
452 void NetworkGraph::allocateTensors(ExecutionMode exec_mode_) {
453   exec_mode = exec_mode_;
454   if (exec_mode == ExecutionMode::INFERENCE)
455     /**
456      * get the order of execution/usage order for the forwarding of the last
457      * layer and pass that as the max_exec_order ensuring that all tensors
458      * with usage less than the max_exec_order are allocated.
459      */
460     tensor_manager->allocateTensors(
461       std::get<0>((*(cend() - 1))->getExecutionOrder()));
462   else {
463     /**
464      * get the order of execution/usage order for the backwarding of the first
465      * layer (as that will be the last layer to executed in the backwarding)
466      * and pass that as the max_exec_order ensuring that all tensors with
467      * usage less than the max_exec_order are allocated.
468      */
469     tensor_manager->allocateTensors(
470       std::get<2>(backward_iter_end->getExecutionOrder()));
471   }
472 }
473
474 std::vector<TensorDim> NetworkGraph::getInputDimension() const {
475   NNTR_THROW_IF(input_dims.empty(), std::invalid_argument)
476     << "[NetworkGraph] the graph has no node identified as input!";
477   return input_dims;
478 }
479
480 unsigned int NetworkGraph::getBatchSize() const { return batch_size; }
481
482 std::vector<TensorDim> NetworkGraph::getOutputDimension() const {
483   NNTR_THROW_IF(label_dims.empty(), std::invalid_argument)
484     << "[NetworkGraph] the graph has no node identified as output!";
485   /// for now, outputting label_dims works, later label dim will be different
486   /// from output dimension
487   return label_dims;
488 }
489
490 std::vector<std::shared_ptr<LayerNode>>
491 NetworkGraph::getUnsortedLayers(const std::string &input_layer,
492                                 const std::string &output_layer) const {
493   /// @fixme: this won't work if input, output layers are not in order
494   /// Further, this function must be removed. There should be rather
495   /// getAllNames and getLayerByName instead of getUnsortedLayers.
496
497   /** count layers after output layer */
498   unsigned int num_layers_remove_end = 0;
499   if (!output_layer.empty()) {
500     for (auto iter = graph.crbegin(); iter != graph.crend(); iter++) {
501       if ((*iter)->getName() != output_layer)
502         num_layers_remove_end++;
503       else
504         break;
505     }
506   }
507
508   if (num_layers_remove_end == graph.size())
509     return {};
510
511   /** count layers before input layer */
512   unsigned int num_layers_remove_start = 0;
513   if (!input_layer.empty()) {
514     for (auto iter = graph.cbegin();
515          iter != graph.cend() - num_layers_remove_end; iter++) {
516       if ((*iter)->getName() != input_layer)
517         num_layers_remove_start++;
518       else
519         break;
520     }
521   }
522
523   /** copy the graph and return */
524   std::vector<std::shared_ptr<LayerNode>> ret;
525   std::transform(graph.cbegin() + num_layers_remove_start,
526                  graph.cend() - num_layers_remove_end, std::back_inserter(ret),
527                  [](auto const &elem) { return LNODE(elem); });
528
529   return ret;
530 }
531
532 std::vector<std::shared_ptr<LayerNode>> NetworkGraph::getLayerNodes() const {
533   return std::vector<std::shared_ptr<LayerNode>>(cbegin(), cend());
534 }
535
536 void NetworkGraph::addLayer(std::shared_ptr<LayerNode> layer) {
537   if (compiled)
538     throw std::runtime_error("Cannot modify graph after compile");
539
540   /** Insert the layer to the graph */
541   graph.addNode(layer);
542 }
543
544 InPlace
545 NetworkGraph::canExecuteInPlace(const std::shared_ptr<LayerNode> &lnode) {
546   if (!lnode->supportInPlace())
547     return InPlace::NONE;
548
549   /** layers which behave as a no-op - flatten */
550   auto no_op = [](const std::shared_ptr<LayerNode> &lnode) {
551     return lnode->getType() == FlattenLayer::type ||
552            lnode->getType() == IdentityLayer::type;
553   };
554
555   /** layers which behave as a no-op but shares memory among parallel nodes -
556    * multiout */
557   auto no_op_shared = [](const std::shared_ptr<LayerNode> &lnode) {
558     return lnode->getType() == MultiOutLayer::type;
559   };
560
561   /**
562    * layers whose backwarding is not dependent on input/output but only its
563    * derivatives and weights, if any - batch normalization
564    */
565   auto io_independent_backwarding =
566     [](const std::shared_ptr<LayerNode> &lnode) {
567       return (lnode->getType() == BatchNormalizationLayer::type) ||
568              (lnode->getType() == LayerNormalizationLayer::type);
569     };
570
571   /**
572    * @note Conditions to decide if this layer node can be in-place:
573    * 1. if the layer is a no-op, then it can operate in-place as it is not
574    * modifying its input/output tensors and does not need to check its
575    * neighboring nodes for dependency.
576    * 2. if the layer is not supporting backwarding, there is no dependency
577    * requirement with other nodes for backwarding.
578    *
579    * @note Conditions to decide the type of inplace for this layer:
580    * 1. if the previous layers were restricting, then this layer will also be
581    * restricting.
582    * 2. if the previous layer were non_restricting or not inplace, then this
583    * layer will be non-restricting.
584    */
585   if (no_op(lnode) || !lnode->supportBackwarding()) {
586     for (auto i = 0u, num_node = lnode->getNumInputConnections(); i < num_node;
587          ++i) {
588       const auto &input_name = lnode->getInputConnectionName(i);
589       if (getLayerNode(input_name)->executeInPlace() == InPlace::RESTRICTING)
590         return InPlace::RESTRICTING;
591     }
592     return InPlace::NON_RESTRICTING;
593   }
594
595   /**
596    * @note Conditions to decide if this layer node can be in-place:
597    * if the layer is a no-op-shared, then it can operate in-place as it is not
598    * modifying its input/output tensors and does not need to check its
599    * neighboring nodes for dependency.
600    *
601    * @note Conditions to decide the type of inplace for this layer:
602    * As all the output nodes are sharing memory, the output nodes cant execute
603    * inplace, and then its restricting mode.
604    */
605   if (no_op_shared(lnode))
606     return InPlace::RESTRICTING;
607
608   /**
609    * @note Conditions to decide if this layer node can be in-place:
610    * This is a generic case where the layer can support in-place but will
611    * modify its input in-place. This includes layers like activation, etc.
612    * Apply checks below to ensure that the layers can work in-place:
613    * - if any of the input layer are restriction, then this layer cannot work
614    *   as layers behind this layer have added restrictions.
615    * - if all of the input layers are either not inplace or have no
616    * restrictions, then this layer can operate in-place.
617    *
618    * @note Conditions to decide the type of inplace for this layer:
619    * This is a generic case, and always restrictions on the next nodes to be
620    * not inplace.
621    *
622    * @note This logic is prone to change as more layers are allowed to
623    * work in-place such as concat layer, split layer, addition layer, dropout
624    * layer, etc.
625    *
626    * @todo This logic sets layers to in-place one-by-one as they arrive. However
627    * setting some layers to in-place can save more memory than others (like
628    * multiout layer vs activation layer). The layers need to sorted based on the
629    * memory save they provide and then make them in-place in that order.
630    */
631   if (lnode->getType() == ActivationLayer::type ||
632       lnode->getType() == BatchNormalizationLayer::type ||
633       lnode->getType() == LayerNormalizationLayer::type) {
634     for (auto i = 0u, num_node = lnode->getNumInputConnections(); i < num_node;
635          ++i) {
636       if (getLayerNode(lnode->getInputConnectionName(i))->executeInPlace() ==
637           InPlace::RESTRICTING)
638         return InPlace::NONE;
639     }
640
641     /**
642      * if the layer does io_independent_backwarding where the input and output
643      * is not required during backwarding, then it is a non-restricting in-place
644      * layer.
645      */
646     if (io_independent_backwarding(lnode))
647       return InPlace::NON_RESTRICTING;
648
649     return InPlace::RESTRICTING;
650   }
651
652   return InPlace::NONE;
653 }
654
655 void NetworkGraph::inPlaceOptimize() {
656   if (optimize_memory) {
657     for (unsigned int idx = 0; idx < graph.size(); ++idx) {
658       auto const &lnode = getSortedLayerNode(idx);
659       lnode->executeInPlace(canExecuteInPlace(lnode));
660     }
661   }
662 }
663
664 /**
665  * @brief Set the Inplace Shared Memory Config By Layer object
666  *
667  * @param lnode layer node object
668  * @param shared_var if the variable should be shared
669  * @param shared_grad if the gradient should be shared
670  */
671 static void
672 setInplaceSharedMemoryConfigByLayer(const std::shared_ptr<LayerNode> &lnode,
673                                     bool &shared_var, bool &shared_grad) {
674   /** for multiout layer, variables are shared but gradients are not */
675   if (lnode->getType() == MultiOutLayer::type) {
676     shared_var = true;
677     shared_grad = false;
678   } else {
679     shared_var = true;
680     shared_grad = true;
681   }
682   /** @todo for addition layer, variables are not shared but gradients are */
683   /**
684    * @todo for layers which support in-place, both variables and gradients
685    * will be shared.
686    *
687    * @todo add a check here is the layer being checked here can support
688    * in-place or not
689    */
690 }
691
692 std::vector<Var_Grad *>
693 NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
694                               const std::vector<Var_Grad *> &prev_inputs) {
695   const GraphNode &gnode = *lnode.get();
696   std::vector<TensorDim> input_dims;
697   input_dims.reserve(prev_inputs.size());
698   std::transform(prev_inputs.begin(), prev_inputs.end(),
699                  std::back_inserter(input_dims),
700                  [](const Var_Grad *vg) { return vg->getDim(); });
701
702   /** finalize the layer and get the final context */
703   auto init_context = lnode->finalize(input_dims);
704
705   /**
706    * Request manager for either a pre-allocated output as input or a newly
707    * allocated output. This is necessary for manager to know when this output
708    * node is going to be used.
709    */
710   std::vector<std::string> input_names;
711   input_names.reserve(prev_inputs.size());
712   std::transform(prev_inputs.begin(), prev_inputs.end(),
713                  std::back_inserter(input_names),
714                  [](auto const &vg) { return vg->getName(); });
715   const std::vector<Var_Grad *> &inputs = tensor_manager->requestInputs(
716     gnode, init_context.getInputDimensions(), input_names);
717
718   /** In-Place optimizations */
719   /**
720    * Request manager for either a pre-allocated input as output or a newly
721    * allocated output. This is necessary for manager to know when this output
722    * node is going to be used with in-place optimizations.
723    */
724   auto out_specs = init_context.getOutSpecs();
725   /// @note try move inplace control to finalize
726   bool shared_var = false, shared_grad = false;
727   if (lnode->executeInPlace() != InPlace::NONE) {
728     setInplaceSharedMemoryConfigByLayer(lnode, shared_var, shared_grad);
729     for (unsigned int i = 0; i < out_specs.size(); ++i) {
730       auto &s = out_specs.at(i);
731       if (shared_var) {
732         s.variable_spec.request_type =
733           TensorSpecV2::RequestType::READ_ONLY_VIEW;
734         if (lnode->getType() == IdentityLayer::type) {
735           s.variable_spec.reference_name = inputs[i]->getName();
736         } else {
737           s.variable_spec.reference_name = inputs[0]->getName();
738         }
739       }
740       if (shared_grad && s.gradient_spec) {
741         s.gradient_spec->request_type =
742           TensorSpecV2::RequestType::READ_ONLY_VIEW;
743         if (lnode->getType() == IdentityLayer::type) {
744           s.gradient_spec->reference_name = inputs[i]->getGradientName();
745         } else {
746           s.gradient_spec->reference_name = inputs[0]->getGradientName();
747         }
748       }
749     }
750   }
751   if (lnode->requireLabel()) {
752     NNTR_THROW_IF(out_specs.size() != 1, std::invalid_argument)
753       << "out specification size must be 1 for label layer for now, "
754       << lnode->getName() << " out spec size: " << out_specs.size();
755     NNTR_THROW_IF(out_specs[0].gradient_spec == nullptr, std::invalid_argument)
756       << "label space does not exist for " << lnode->getName();
757     out_specs[0].gradient_spec->request_type =
758       TensorSpecV2::RequestType::PLACEHOLDER;
759   }
760
761   /// @note below needs to be enabled only for inference mode, but need decision
762   /// if we are going to separate inference initialization from train
763   /// initialization this might not worth optimize because in general output of
764   /// a neuralnet is very small
765   if (lnode->getOutputConnections().size() == 0u) {
766     std::for_each(out_specs.begin(), out_specs.end(),
767                   [this](VarGradSpecV2 &spec) {
768                     spec.variable_spec.additional_exec_order.push_back(
769                       std::get<0>(forward_iter_end->getExecutionOrder()));
770                   });
771   }
772
773   const std::vector<Var_Grad *> &outputs = tensor_manager->requestTensors(
774     out_specs, Manager::TensorGroupType::OUTPUT, lnode->getExecutionOrder(),
775     lnode->getName());
776
777   /** create shared weight names if requested */
778   std::vector<std::string> shared_weight_names;
779   std::vector<std::string> shared_tensor_names;
780   if (auto shared_node_str = lnode->getSharedFrom(); !shared_node_str.empty()) {
781     /// @note below is commented but kept from quick fix to be referenced for
782     /// later(#1707)
783     // auto shared_node = getLayerNode(shared_node_str).get();
784     // NNTR_THROW_IF(shared_node == nullptr, std::invalid_argument)
785     //   << "shared_node requested but it is not registered in the graph,
786     //   name:
787     //   "
788     //   << shared_node_str << " requested from " << lnode->getName();
789     // NNTR_THROW_IF(shared_node->getType() != lnode->getType(),
790     //               std::invalid_argument)
791     //   << " shared_node and lnode type mismatch, source node type: "
792     //   << shared_node->getType() << " depedent node type: " <<
793     //   lnode->getType()
794     //   << " depedent node name: " << lnode->getName();
795     // NNTR_THROW_IF(!shared_node->isFinalized(), std::invalid_argument)
796     //   << "shared node must be prior to the dependent node and it should be
797     //   "
798     //      "finalized beforehand, shared node name: "
799     //   << shared_node_str << " dependent node name: " << lnode->getName();
800     // auto num_weight = shared_node->getNumWeights();
801     // shared_weight_names.reserve(num_weight);
802     // for (auto i = 0u; i < num_weight; ++i) {
803     //   shared_weight_names.emplace_back(shared_node->getWeightName(i));
804     // }
805     // auto &rc = node->getRunContext();
806
807     /// @fixme tensor should be only shared if context explicitly requested to
808     /// do so. This has to be added to the part of tensor spec, other wise it
809     /// will break many things
810     const auto &t_specs = init_context.getTensorsSpec();
811     for (auto i = 0u; i < t_specs.size(); ++i) {
812       shared_tensor_names.emplace_back(std::get<3>(t_specs.at(i)));
813     }
814
815     const auto &w_specs = init_context.getWeightsSpec();
816     for (auto i = 0u; i < w_specs.size(); ++i) {
817       shared_weight_names.emplace_back(std::get<7>(w_specs.at(i)));
818     }
819   }
820
821   lnode->configureRunContext(
822     // TODO: update weights spec for trainable based on layer trainable prop
823     tensor_manager->requestWeights(gnode, init_context.getWeightsSpec(),
824                                    lnode->getTrainable(), shared_weight_names),
825     inputs, outputs,
826     tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(),
827                                    lnode->getTrainable(), shared_tensor_names));
828
829   return outputs;
830 }
831
832 int NetworkGraph::initialize(const std::vector<Connection> &model_input_names,
833                              const std::vector<Connection> &model_label_names) {
834
835   /**
836    * this contains the map from node name to its input tensor names
837    * @note: these input tensors have already been allocated
838    */
839   std::unordered_map<std::string, std::vector<Var_Grad *>> input_map;
840
841   /** check if the given config of node is of input node */
842   auto is_input_node = [](const LayerNode *node) -> bool {
843     return node->getInputConnections().empty();
844   };
845
846   for (unsigned int idx = 0; idx < graph.size(); ++idx) {
847     std::vector<Var_Grad *> inputs = {};
848     auto const &lnode = getSortedLayerNode(idx);
849     ml_logd("layer name : %s", lnode->getName().c_str());
850
851     if (profile_keys.find(lnode->getType()) == profile_keys.end()) {
852       int event_key = 0;
853       PROFILE_TIME_REGISTER_EVENT(event_key, lnode->getType());
854       profile_keys[lnode->getType()] = event_key;
855     }
856
857     /**
858      * Set input dimension for all the layers.
859      * For input layer, as input dimension is known, set input tensor.
860      */
861     if (!is_input_node(lnode.get())) {
862       if (input_map.find(lnode->getName()) == input_map.end())
863         throw std::runtime_error("Cannot find input buffers for the node");
864       inputs = input_map.at(lnode->getName());
865     }
866
867     /**
868      * Initialize all the layers, allocate output tensors for each layer
869      * init2and add optimizer related weights for the layer
870      */
871     const std::vector<Var_Grad *> &outputs = finalizeContext(lnode, inputs);
872
873     /** no need to update input_map for the last layer */
874     if (idx == graph.size() - 1)
875       break;
876
877     for (auto i = 0u, num_node = lnode->getNumOutputConnections(); i < num_node;
878          ++i) {
879       auto conn = lnode->getOutputConnection(i);
880       if (!conn) {
881         ml_logi("out connection not defined for  %s, %u",
882                 lnode->getName().c_str(), i);
883         continue;
884       }
885
886       auto sink_node = getLayerNode(conn->getName());
887       [[maybe_unused]] auto [it, b] =
888         input_map.try_emplace({sink_node->getName(), {}});
889
890       NNTR_THROW_IF(sink_node->getInputConnectionName(conn->getIndex()) !=
891                       lnode->getName(),
892                     std::invalid_argument)
893         << "node pair does not match between " << lnode->getName() << ' '
894         << sink_node->getName();
895
896       auto &sink_tensors = it->second;
897       sink_tensors.resize(sink_node->getNumInputConnections());
898       sink_tensors[conn->getIndex()] = outputs[i];
899     }
900   }
901
902   for (unsigned int idx = 0; idx < graph.size(); ++idx) {
903     auto const &lnode = getSortedLayerNode(idx);
904     auto &rc = lnode->getRunContext();
905     auto first_grad_access = std::get<1>(lnode->getExecutionOrder());
906     auto last_grad_access = std::get<2>(lnode->getExecutionOrder());
907     for (unsigned i = 0; i < rc.getNumWeights(); ++i) {
908       if (!rc.weightHasGradient(i)) {
909         /// @todo this is duck taping that MUST BE REMOVED. We will need to
910         /// have, is weight first access kind of concept.
911         if (tensor_manager->isFirstAccess(
912               rc.getWeight(i).getName(),
913               std::get<0>(lnode->getExecutionOrder()), true)) {
914           rc.getWeightObject(i).setAsGradientFirstAccess();
915         }
916         if (tensor_manager->isLastAccess(rc.getWeight(i).getName(),
917                                          last_grad_access, true)) {
918           rc.getWeightObject(i).setAsGradientLastAccess();
919         }
920       } else {
921         if (tensor_manager->isFirstAccess(rc.getWeightGrad(i).getName(),
922                                           first_grad_access)) {
923           rc.getWeightObject(i).setAsGradientFirstAccess();
924         }
925         /**
926          * if the gradient is to be clipped by global norm, then the last access
927          * is by clipping itself. However, as clipping is not a layer and does
928          * not contain any weights, such weights never get assigned
929          * gradient_last_access. This is a quick hotfix.
930          * TODO: make an independent clipping layer which will execute at the
931          * end, and will share ownership of weights which it will clip. This
932          * will remove this hot fix, and also remove the checks of if weights
933          * require clipping.
934          */
935         if (tensor_manager->isLastAccess(rc.getWeightGrad(i).getName(),
936                                          last_grad_access) ||
937             (rc.isGradientClipByGlobalNorm(i) &&
938              tensor_manager->isSecondLastAccess(rc.getWeightGrad(i).getName(),
939                                                 last_grad_access))) {
940           rc.getWeightObject(i).setAsGradientLastAccess();
941         }
942       }
943     }
944   }
945   /**** identify model input / output to be set externally later ****/
946   auto identify_as_model_input = [this](LayerNode *node) {
947     auto num_input = node->getNumInputs();
948     NNTR_THROW_IF(num_input != 1, std::invalid_argument)
949       << "Input layer is supposed to have exactly one input, but more then "
950          "one input detected, num inputs: "
951       << num_input;
952
953     input_list.push_back(node->getInput(0).getName());
954     input_dims.push_back(node->getInputDimensions()[0]);
955   };
956
957   auto is_label_node = [](LayerNode *node) { return node->requireLabel(); };
958
959   auto identify_as_model_label = [this](LayerNode *node) {
960     /// @todo change this as lnode->getNumLabels of sorts
961     auto num_label = node->getNumOutputs();
962     NNTR_THROW_IF(!node->getOutputConnections().empty(), std::invalid_argument)
963       << "label layer is supposed to be a leaf for now";
964     NNTR_THROW_IF(num_label != 1, std::invalid_argument)
965       << "label layer is supposed to have exactly one label, but more then "
966          "one label detected, num labels: "
967       << num_label;
968
969     /// @todo implement and use getLabel(0) instead.
970     output_list.push_back(node->getOutput(0).getName());
971     label_list.push_back(node->getOutputGrad(0).getName());
972     label_dims.push_back(node->getOutputDimensions()[0]);
973   };
974
975   auto identify_external_tensors = [this](const std::vector<Connection> &conns,
976                                           auto &&pred, auto &&identify) {
977     if (conns.empty()) {
978       for (unsigned int i = 0; i < graph.size(); ++i) {
979         auto lnode = getSortedLayerNode(i).get();
980         if (!pred(lnode)) {
981           continue;
982         }
983         /// when name is empty, we identify everything as the node, all of
984         /// them must be having identical dimensions
985         identify(lnode);
986       }
987     } else {
988       for (auto &conn : conns) {
989         auto lnode = getLayerNode(conn.getName()).get();
990         NNTR_THROW_IF(!pred(lnode), std::invalid_argument)
991           << "given node is not of that kind, name: " << conn.getName();
992         identify(lnode);
993       }
994       unsigned int num_node_of_kind = 0;
995       for (unsigned int i = 0; i < graph.size(); ++i) {
996         auto lnode = getSortedLayerNode(i).get();
997         if (!pred(lnode)) {
998           continue;
999         }
1000         num_node_of_kind++;
1001       }
1002       NNTR_THROW_IF(num_node_of_kind != conns.size(), std::invalid_argument)
1003         << "conns given but there are not identified node of the kind, num "
1004            "node of kind: "
1005         << num_node_of_kind << " identifier size: " << conns.size();
1006     }
1007   };
1008
1009   identify_external_tensors(model_input_names, is_input_node,
1010                             identify_as_model_input);
1011   identify_external_tensors(model_label_names, is_label_node,
1012                             identify_as_model_label);
1013
1014   /** mark the nodes which will be backwarded during the graph operation */
1015   try {
1016     markNodesForBackwarding();
1017     backward_iter_end = computeBackwardEnd();
1018   } catch (std::exception &e) {
1019     ml_loge(
1020       "Backwarding required from layer which doesn't support backwarding: %s",
1021       e.what());
1022     return ML_ERROR_INVALID_PARAMETER;
1023   }
1024
1025   /** select weights which would require clipping of the gradients by global
1026    * norm if any */
1027   clip_weights = tensor_manager->getWeights([](const Weight *w) {
1028     return w->hasGradient() && w->isGradientLastAccess() &&
1029            w->isGradientClipByGlobalNorm();
1030   });
1031
1032   return ML_ERROR_NONE;
1033 }
1034
1035 void NetworkGraph::setExternalTensors(const std::vector<Tensor> &data,
1036                                       const std::vector<std::string> names) {
1037
1038   /// feed or clear label
1039   for (unsigned int idx = 0; idx < names.size(); idx++) {
1040     if (data.empty())
1041       tensor_manager->fillPlaceholder(names[idx], Tensor());
1042     else if (data.size() == 1)
1043       tensor_manager->fillPlaceholder(names[idx], data[0]);
1044     else
1045       tensor_manager->fillPlaceholder(names[idx], data[idx]);
1046   }
1047 }
1048
1049 void NetworkGraph::setInputsLabels(const std::vector<Tensor> &inputs,
1050                                    const std::vector<Tensor> &labels) {
1051
1052   NNTR_THROW_IF(labels.size() > 1 && labels.size() != label_list.size(),
1053                 std::invalid_argument)
1054     << "label size does not match with the network requirements"
1055     << " label size: " << labels.size()
1056     << " requirements size: " << label_list.size();
1057
1058   NNTR_THROW_IF(inputs.size() > 1 && inputs.size() != input_list.size(),
1059                 std::invalid_argument)
1060     << "input size does not match with the network requirements"
1061     << " input size: " << inputs.size()
1062     << " requirements size: " << input_list.size();
1063
1064   setExternalTensors(inputs, input_list);
1065   setExternalTensors(labels, label_list);
1066 }
1067
1068 void NetworkGraph::setInputsLabels(sharedConstTensors &inputs,
1069                                    sharedConstTensors &labels) {
1070
1071   std::vector<Tensor> ins;
1072   std::transform(inputs.begin(), inputs.end(), std::back_inserter(ins),
1073                  [](auto const &val) { return *val.get(); });
1074
1075   std::vector<Tensor> labs;
1076   std::transform(labels.begin(), labels.end(), std::back_inserter(labs),
1077                  [](auto const &val) { return *val.get(); });
1078
1079   setInputsLabels(ins, labs);
1080 }
1081
1082 std::vector<Tensor> NetworkGraph::getOutputTensors() const {
1083   std::vector<Tensor> output_tensors;
1084   output_tensors.reserve(output_list.size());
1085
1086   for (auto const &name : output_list)
1087     output_tensors.push_back(*tensor_manager->getTensor(name));
1088
1089   return output_tensors;
1090 }
1091
1092 void NetworkGraph::flushCache() { tensor_manager->flushCache(); }
1093
1094 void NetworkGraph::flushCacheExcept(unsigned int order) {
1095   tensor_manager->flushCacheExcept(order);
1096 }
1097
1098 void NetworkGraph::requestOptimizerVariable(
1099   std::function<std::vector<TensorDim>(const TensorDim &)> cb,
1100   bool request_only_trainable) {
1101   for (auto const &w : tensor_manager->getWeights()) {
1102     if (w->isGradientLastAccess() && w->hasGradient()) {
1103       const TensorDim &dim = w->getDim();
1104       std::vector<TensorDim> dims = cb(dim);
1105       w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
1106         dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
1107         Tensor::Initializer::ZEROS));
1108     }
1109   }
1110 }
1111
1112 } /* namespace nntrainer */