[Tensor] Remove calcGrad step for trainable layer
[platform/core/ml/nntrainer.git] / nntrainer / models / neuralnet.cpp
1 /**
2  * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *   http://www.apache.org/licenses/LICENSE-2.0
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  *
15  * @file        neuralnet.cpp
16  * @date        04 December 2019
17  * @brief       This is Neural Network Class
18  * @see         https://github.com/nnstreamer/nntrainer
19  * @author      Jijoong Moon <jijoong.moon@samsung.com>
20  * @bug         No known bugs except for NYI items
21  *
22  */
23
24 #include "layer_context.h"
25 #include "model_common_properties.h"
26 #include <cmath>
27 #include <cstring>
28 #include <fstream>
29 #include <iomanip>
30 #include <iostream>
31 #include <sstream>
32
33 #include <activation_realizer.h>
34 #include <common_properties.h>
35 #include <databuffer.h>
36 #include <flatten_realizer.h>
37 #include <ini_interpreter.h>
38 #include <ini_wrapper.h>
39 #include <input_realizer.h>
40 #include <model_loader.h>
41 #include <multiout_realizer.h>
42 #include <neuralnet.h>
43 #include <nntrainer_error.h>
44 #include <nntrainer_log.h>
45 #include <node_exporter.h>
46 #include <optimizer_context.h>
47 #include <previous_input_realizer.h>
48 #include <profiler.h>
49 #include <recurrent_realizer.h>
50 #include <remap_realizer.h>
51 #include <slice_realizer.h>
52 #include <util_func.h>
53
54 #ifdef ENABLE_TFLITE_INTERPRETER
55 #include <tflite_interpreter.h>
56 #endif
57
58 /**
59  * @brief Internal enum values for nntrainer to summarize model accuracy & loss
60  */
61 #define ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS 101
62 #define ML_TRAIN_SUMMARY_MODEL_VALID_LOSS 102
63 #define ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY 103
64
65 namespace nntrainer {
66
67 NeuralNetwork::NeuralNetwork() :
68   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
69   model_flex_props(props::Epochs(), props::TrainingBatchSize(),
70                    props::SavePath(), props::ContinueTrain(),
71                    props::SaveBestPath(), props::MemoryOptimization(),
72                    props::MemorySwap(), props::MemorySwapPath()),
73   load_path(std::string()),
74   epoch_idx(0),
75   iter(0),
76   loss(0.0f),
77   data_buffers({nullptr, nullptr, nullptr}),
78   initialized(false),
79   compiled(false),
80   loadedFromConfig(false) {
81   app_context = AppContext(AppContext::Global());
82 }
83
84 NeuralNetwork::NeuralNetwork(AppContext app_context_) :
85   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
86   model_flex_props(props::Epochs(), props::TrainingBatchSize(),
87                    props::SavePath(), props::ContinueTrain(),
88                    props::SaveBestPath(), props::MemoryOptimization(),
89                    props::MemorySwap(), props::MemorySwapPath()),
90   load_path(std::string()),
91   epoch_idx(0),
92   iter(0),
93   loss(0.0f),
94   data_buffers({nullptr, nullptr, nullptr}),
95   initialized(false),
96   compiled(false),
97   loadedFromConfig(false),
98   app_context(app_context_) {}
99
100 int NeuralNetwork::loadFromConfig(const std::string &config) {
101   if (loadedFromConfig == true) {
102     ml_loge("cannnot do loadFromConfig twice");
103     return ML_ERROR_INVALID_PARAMETER;
104   }
105
106   ModelLoader loader(app_context);
107   NeuralNetwork tempNet(*this);
108
109   int status = loader.loadFromContext(tempNet);
110   if (status != ML_ERROR_NONE) {
111     return status;
112   }
113
114   status = loader.loadFromConfig(config, tempNet);
115   if (status != ML_ERROR_NONE) {
116     return status;
117   }
118
119   tempNet.loadedFromConfig = true;
120   swap(tempNet, *this);
121
122   return ML_ERROR_NONE;
123 }
124
125 unsigned int NeuralNetwork::getCurrentEpoch() {
126 #ifdef DEBUG
127   ml_logd("[NNTrainer] Current epoch: %d", epoch_idx);
128 #endif
129   return epoch_idx;
130 };
131
132 void NeuralNetwork::setProperty(const std::vector<std::string> &values) {
133   auto left_props = loadProperties(values, model_props);
134   setTrainConfig(left_props);
135 }
136
137 void NeuralNetwork::setTrainConfig(const std::vector<std::string> &values) {
138   auto left_props = loadProperties(values, model_flex_props);
139   NNTR_THROW_IF(left_props.size(), std::invalid_argument)
140     << "Model has unparsed properties, size: " << left_props.size()
141     << " of first element: " << left_props.front();
142 }
143
144 int NeuralNetwork::compile() {
145   std::string loss_type = std::get<props::LossType>(model_props).empty()
146                             ? std::string()
147                             : std::get<props::LossType>(model_props);
148
149   auto &input_conn = std::get<std::vector<props::InputConnection>>(model_props);
150   /// @note label layer might need to be treated in the similar way as well
151
152   /// @todo make NetworkGraph compiled at the construction instead of having
153   /// graph.compile(), neuralnetwork have ownership of list of layer nodes,
154   /// which will be passed at compile time.
155
156   std::vector<std::unique_ptr<GraphRealizer>> realizers;
157
158   realizers.emplace_back(new PreviousInputRealizer(
159     std::vector<Connection>(input_conn.begin(), input_conn.end())));
160   realizers.emplace_back(new MultioutRealizer());
161   realizers.emplace_back(new FlattenRealizer());
162   realizers.emplace_back(new ActivationRealizer());
163
164   for (auto &realizer : realizers) {
165     graph_representation = realizer->realize(graph_representation);
166   }
167
168   bool memory_swap = std::get<props::MemorySwap>(model_flex_props);
169   const std::string memory_swap_path =
170     std::get<props::MemorySwapPath>(model_flex_props);
171   model_graph = NetworkGraph(memory_swap, memory_swap_path);
172
173   model_graph.setMemoryOptimizations(
174     std::get<props::MemoryOptimization>(model_flex_props));
175   for (auto &node : graph_representation) {
176     if (auto &prop = std::get<props::ClipGradByGlobalNorm>(model_props);
177         !prop.empty()) {
178       node->setProperty({"clip_grad_by_norm=" + to_string(prop)});
179     }
180     model_graph.addLayer(node);
181   }
182
183   int status = model_graph.compile(loss_type);
184   NN_RETURN_STATUS();
185
186   compiled = true;
187
188   return status;
189 }
190
191 int NeuralNetwork::initialize() {
192   int status = ML_ERROR_NONE;
193
194   if (initialized) {
195     ml_loge("Error: Initializing the model again");
196     return ML_ERROR_NOT_SUPPORTED;
197   }
198
199   if (!compiled) {
200     ml_loge("Error: Need to compile first");
201     return ML_ERROR_NOT_SUPPORTED;
202   }
203
204   unsigned int n_layers = (unsigned int)model_graph.size();
205
206   ml_logd("initializing neural network, layer size: %d", n_layers);
207   PROFILE_MEM_ANNOTATE("Initialize");
208
209   auto &input_conn_prop =
210     std::get<std::vector<props::InputConnection>>(model_props);
211   auto &label_layer_prop =
212     std::get<std::vector<props::LabelLayer>>(model_props);
213
214   std::vector<Connection> input_conn(input_conn_prop.begin(),
215                                      input_conn_prop.end());
216   std::vector<std::string> label_layers;
217
218   if (!label_layer_prop.empty()) {
219     label_layers = std::vector<std::string>(label_layer_prop.begin(),
220                                             label_layer_prop.end());
221   }
222
223   status = model_graph.initialize(
224     input_conn,
225     std::vector<Connection>(label_layers.begin(), label_layers.end()));
226   NN_RETURN_STATUS();
227
228   model_graph.setBatchSize(
229     std::get<props::TrainingBatchSize>(model_flex_props));
230
231   // initialize optimizer and related variables
232   /// @todo: initialize should take a mode and check if mode is train but
233   /// optimizer is not given, make it as a hard error
234   if (opt) {
235     /** TODO: update request of optimizer to be of same format as
236      * Layer::requestTensor */
237     opt->finalize();
238     std::function<std::vector<TensorDim>(const TensorDim &)> cb =
239       [this](const TensorDim &dim) {
240         return opt->getOptimizerVariableDim(dim);
241       };
242     model_graph.requestOptimizerVariable(cb, true);
243   }
244
245   // Allocate weights
246   model_graph.allocateWeights();
247
248   initialized = true;
249
250   if (!load_path.empty()) {
251     load(load_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
252   }
253
254   return status;
255 }
256
257 /**
258  * @brief     free layers
259  */
260 NeuralNetwork::~NeuralNetwork() = default;
261
262 /**
263  * @brief     forward propagation using layers object which has layer
264  */
265 sharedConstTensors
266 NeuralNetwork::forwarding(bool training,
267                           std::function<bool(void *userdata)> stop_cb) {
268   return model_graph.forwarding(training, stop_cb);
269 }
270
271 /**
272  * @brief     forward propagation using layers object which has layer
273  */
274 sharedConstTensors NeuralNetwork::forwarding(sharedConstTensors input,
275                                              sharedConstTensors label,
276                                              bool training) {
277   auto current_batch = model_graph.getBatchSize();
278   NNTR_THROW_IF(input[0]->batch() != current_batch ||
279                   (!label.empty() && label[0]->batch() != current_batch),
280                 std::logic_error)
281     << "Error: mismatch in batchsize for data and model."
282     << " input_batch: " << input[0]->batch()
283     << " label_batch: " << label[0]->batch()
284     << " target_batch: " << current_batch;
285
286   model_graph.setInputsLabels(input, label);
287
288   return forwarding(training);
289 }
290
291 /**
292  * @brief     back propagation
293  *            Call backwarding function of layer in reverse order
294  *            No need to call at first Input Layer (No data to be updated)
295  */
296 void NeuralNetwork::backwarding(int iteration,
297                                 std::function<bool(void *userdata)> stop_cb) {
298
299 #ifdef DEBUG
300   NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!";
301 #endif
302
303   std::function<void(std::shared_ptr<LayerNode>, int)> backwarding_op =
304     [this, stop_cb](std::shared_ptr<LayerNode> node, int iteration) -> void {
305     /**
306      * Do not change this order:
307      * 1. calcGradient
308      * 2. calcDerivative
309      * 3. applyGradientsOnLastAccess
310      */
311
312     model_graph.flushCacheExcept(std::get<1>(node->getExecutionOrder()));
313     PROFILE_MEM_ANNOTATE("CalcGradient: " + node->getName());
314
315     bool apply_gradient = true;
316
317     /** If gradient optimization mode, then calculate gradient first */
318     if (dynamic_training_opt.isGradientMode())
319       node->calcGradient();
320
321     /**
322      * If optimization off, or gradient must be applied, then this will be
323      * true
324      * @todo This apply gradient should be passed to the each weight and later
325      * be queried when updating gradient at once. (after moving apply_gradient
326      * out of this function)
327      *
328      */
329     // auto &layer = node->getObject();
330     // apply_gradient = dynamic_training_opt.checkIfApply(
331     //   layer->getWeightsRef(), layer->net_input[0], layer->net_hidden[0],
332     //   opt, iteration);
333
334     /** If gradient must be applied and its not gradient mode, calculate
335      * gradient
336      */
337     if (!dynamic_training_opt.isGradientMode() && apply_gradient)
338       node->calcGradient();
339
340     model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder()));
341     PROFILE_MEM_ANNOTATE("CalcDerivative: " + node->getName());
342
343     if (stop_cb(nullptr)) {
344       return;
345     }
346
347     if (node->needsCalcDerivative())
348       node->calcDerivative();
349
350     if (apply_gradient) {
351       /// Apply gradient only at the end of the last shared weight access
352       model_graph.applyGradients(
353         node.get(), [iteration, opt_ = opt.get()](Weight &w) {
354           w.calcRegularizationGradient();
355           w.calcWeightDecayGradient();
356           RunOptimizerContext opt_context(&w, iteration,
357                                           opt_->getLearningRate(iteration));
358           opt_->applyGradient(opt_context);
359         });
360     }
361   };
362
363   std::function<void(Weight &, int)> apply_grad_clip_op =
364     [opt_ = opt.get()](Weight &w, int iteration) -> void {
365     w.calcRegularizationGradient();
366     w.calcWeightDecayGradient();
367     RunOptimizerContext opt_context(&w, iteration,
368                                     opt_->getLearningRate(iteration));
369     opt_->applyGradient(opt_context);
370   };
371
372   model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op,
373                           stop_cb);
374 }
375
376 void NeuralNetwork::save(const std::string &file_path,
377                          ml::train::ModelFormat format) {
378   NNTR_THROW_IF(!initialized, std::runtime_error)
379     << "Cannot save model if not initialized yet, path: " << file_path
380     << " format: " << static_cast<unsigned>(format);
381
382   /// @todo this switch case should be delegating the function call only. It's
383   /// not delegating for now as required logics are managable for now.
384   switch (format) {
385   case ml::train::ModelFormat::MODEL_FORMAT_BIN: {
386     auto model_file = checkedOpenStream<std::ofstream>(
387       file_path, std::ios::out | std::ios::binary | std::ios::trunc);
388     for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
389       (*iter)->save(model_file);
390     }
391     if (opt && istrequal(opt->getType(), "adam")) {
392       std::string adam = "adam";
393       model_file.write(adam.c_str(), adam.size());
394       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
395            iter++) {
396         (*iter)->save(model_file, true);
397       }
398     }
399
400     model_file.write((char *)&epoch_idx, sizeof(epoch_idx));
401     model_file.write((char *)&iter, sizeof(iter));
402
403     model_file.close();
404     break;
405   }
406   case ml::train::ModelFormat::MODEL_FORMAT_INI:
407     saveModelIni(file_path);
408     break;
409
410   case ml::train::ModelFormat::MODEL_FORMAT_INI_WITH_BIN: {
411     auto old_save_path = std::get<props::SavePath>(model_flex_props);
412     auto bin_file_name =
413       file_path.substr(0, file_path.find_last_of('.')) + ".bin";
414
415     std::get<props::SavePath>(model_flex_props).set(bin_file_name);
416     save(file_path, ml::train::ModelFormat::MODEL_FORMAT_INI);
417     save(bin_file_name, ml::train::ModelFormat::MODEL_FORMAT_BIN);
418     std::get<props::SavePath>(model_flex_props) = old_save_path;
419     break;
420   }
421   default:
422     throw nntrainer::exception::not_supported(
423       "saving with given format is not supported yet");
424   }
425 }
426
427 void NeuralNetwork::load(const std::string &file_path,
428                          ml::train::ModelFormat format) {
429   /// @todo this switch case should be delegating the function call only. It's
430   /// not delegating for now as required logics are managable for now.
431   switch (format) {
432   case ml::train::ModelFormat::MODEL_FORMAT_BIN: {
433     NNTR_THROW_IF(!initialized, std::runtime_error)
434       << "Cannot load if not initialized yet, path: " << file_path
435       << " format: " << static_cast<unsigned>(format);
436
437     auto model_file = checkedOpenStream<std::ifstream>(
438       file_path, std::ios::in | std::ios::binary);
439     for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
440       (*iter)->read(model_file);
441     }
442     try {
443       /// this is assuming that the failure is allowed at the end of the file
444       /// read. so, after this line, additional read shouldn't be called
445       if (opt && istrequal(opt->getType(), "adam")) {
446         std::string opt_type;
447         opt_type.resize(4);
448         model_file.read((char *)&opt_type[0], 4);
449         if (istrequal(opt_type, "adam")) {
450           for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
451                iter++) {
452             (*iter)->read(model_file, true);
453           }
454         }
455       }
456
457       checkedRead(model_file, (char *)&epoch_idx, sizeof(epoch_idx),
458                   "[NeuralNetwork::readModel] failed to read epoch_idx");
459       checkedRead(model_file, (char *)&iter, sizeof(iter),
460                   "[NeuralNetwork::readModel] failed to read iteration");
461     } catch (...) {
462       std::cerr << "failed to read additional data like optimizer variable, "
463                    "iteration, proceeding with default\n";
464     }
465
466     ml_logi("read modelfile: %s", file_path.c_str());
467     break;
468   }
469   case ml::train::ModelFormat::MODEL_FORMAT_INI_WITH_BIN: {
470     int ret = loadFromConfig(file_path);
471     throw_status(ret);
472     auto &save_path = std::get<props::SavePath>(model_flex_props);
473     if (!save_path.empty()) {
474       checkedOpenStream<std::ifstream>(save_path,
475                                        std::ios::in | std::ios::binary);
476       load_path = save_path;
477     }
478     break;
479   }
480   case ml::train::ModelFormat::MODEL_FORMAT_INI: {
481     int ret = loadFromConfig(file_path);
482     throw_status(ret);
483     break;
484   }
485   default:
486     throw nntrainer::exception::not_supported(
487       "loading with given format is not supported yet");
488   }
489 }
490
491 float NeuralNetwork::getLoss() {
492   loss = 0.0f;
493
494   for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
495     loss += (*iter)->getLoss();
496   }
497   return loss;
498 }
499
500 void NeuralNetwork::setLoss(float l) { loss = l; }
501
502 NeuralNetwork &NeuralNetwork::copy(NeuralNetwork &from) {
503   if (this != &from) {
504     model_props = from.model_props;
505     model_flex_props = from.model_flex_props;
506     loss = from.loss;
507     opt = from.opt;
508
509     model_graph.copy(from.model_graph);
510   }
511   return *this;
512 }
513
514 void NeuralNetwork::saveModelIni(const std::string &file_path) {
515   NNTR_THROW_IF(isFileExist(file_path), std::invalid_argument)
516     << "There is already a file, overriding to the exisiting file is not "
517        "permitted, path: "
518     << file_path;
519
520   std::vector<IniSection> sections;
521
522   IniSection model_section = IniSection::FromExportable("model", *this);
523   model_section.setEntry("type", "NeuralNetwork");
524   sections.push_back(model_section);
525
526   auto add_section_if_any = [&sections](const std::string &section_name,
527                                         auto obj_ptr, auto pred) {
528     if (pred(obj_ptr)) {
529       IniSection s = IniSection::FromExportable(section_name, *obj_ptr);
530       s.setEntry("type", obj_ptr->getType());
531       sections.push_back(s);
532     }
533   };
534
535   add_section_if_any("optimizer", opt,
536                      [](const auto &obj) { return static_cast<bool>(obj); });
537
538   auto &[train_buffer, valid_buffer, test_buffer] = data_buffers;
539   auto data_buffer_valid = [](const auto &buffer) {
540     return buffer && buffer->isSerializable(
541                        ml::train::ExportMethods::METHOD_STRINGVECTOR);
542   };
543
544   add_section_if_any("train_set", train_buffer, data_buffer_valid);
545   add_section_if_any("valid_set", valid_buffer, data_buffer_valid);
546   add_section_if_any("test_set", test_buffer, data_buffer_valid);
547
548   IniWrapper wrapper("model_saver", sections);
549   wrapper.save_ini(file_path);
550
551   IniGraphInterpreter interpreter;
552   interpreter.serialize(graph_representation, file_path);
553 }
554
555 bool NeuralNetwork::validateInput(sharedConstTensors X) {
556   auto input_dim = getInputDimension();
557   if (X.size() != input_dim.size()) {
558     ml_loge("Error: provided number of inputs %d, required %d", (int)X.size(),
559             (int)input_dim.size());
560     return false;
561   }
562
563   for (unsigned int dim = 0; dim < input_dim.size(); dim++) {
564     if (input_dim[dim] != X[dim]->getDim()) {
565       ml_loge("Error: provided input shape does not match required shape");
566       std::stringstream ss;
567       ss << X[dim]->getDim();
568       ml_loge("Provided tensor summary : %s", ss.str().c_str());
569
570       ss.str(std::string());
571       ss << input_dim[dim];
572       ml_loge("Required tensor summary : %s", ss.str().c_str());
573       return false;
574     }
575   }
576
577   return true;
578 }
579
580 sharedConstTensors NeuralNetwork::inference(sharedConstTensors X,
581                                             bool free_mem) {
582   return inference(X, {}, free_mem);
583 }
584
585 sharedConstTensors NeuralNetwork::inference(sharedConstTensors X,
586                                             sharedConstTensors label,
587                                             bool free_mem) {
588   if (model_graph.getBatchSize() != X[0]->batch()) {
589     model_graph.setBatchSize(X[0]->batch());
590   }
591
592   sharedConstTensors out;
593   if (!validateInput(X))
594     throw std::invalid_argument("Input validation failed.");
595
596   allocate(ExecutionMode::INFERENCE);
597
598   int nn_foward;
599   PROFILE_TIME_REGISTER_EVENT(nn_foward, "nn_forward");
600   PROFILE_TIME_START(nn_foward);
601   out = forwarding(X, label, false);
602   PROFILE_TIME_END(nn_foward);
603
604   if (free_mem)
605     /**
606      * Free the memory needed for training before exiting.
607      * Note that this does not free the weights for the model.
608      * Weights of the model will be freed when the model is destroyed.
609      */
610     model_graph.deallocateTensors(false);
611
612   /** Clear the set inputs and labels */
613   model_graph.setInputsLabels({}, {});
614
615   return out;
616 }
617
618 std::vector<float *>
619 NeuralNetwork::inference(unsigned int batch_size,
620                          const std::vector<float *> &input,
621                          const std::vector<float *> &label) {
622   sharedConstTensors input_tensors, output_tensors;
623   auto in_dim = getInputDimension();
624
625   input_tensors.reserve(input.size());
626   for (unsigned int idx = 0; idx < in_dim.size(); idx++) {
627     in_dim[idx].batch(batch_size);
628     input_tensors.emplace_back(MAKE_SHARED_TENSOR(Tensor::Map(
629       input[idx], in_dim[idx].getDataLen() * sizeof(float), in_dim[idx], 0)));
630   }
631
632   if (!label.empty()) {
633     sharedConstTensors label_tensors;
634     auto label_dim = getOutputDimension();
635     label_tensors.reserve(label.size());
636     for (unsigned int idx = 0; idx < label_dim.size(); idx++) {
637       label_dim[idx].batch(batch_size);
638       label_tensors.emplace_back(MAKE_SHARED_TENSOR(
639         Tensor::Map(label[idx], label_dim[idx].getDataLen() * sizeof(float),
640                     label_dim[idx], 0)));
641     }
642     output_tensors = inference(input_tensors, label_tensors, false);
643   } else {
644     output_tensors = inference(input_tensors, false);
645   }
646
647   std::vector<float *> output;
648   output.reserve(output_tensors.size());
649
650   for (auto &out : output_tensors) {
651     auto out_t = *out.get();
652     output.push_back(out_t.getData());
653   }
654
655   return output;
656 }
657
658 int NeuralNetwork::setDataset(const DatasetModeType &mode,
659                               std::shared_ptr<ml::train::Dataset> dataset) {
660   return setDataBuffer(mode, std::static_pointer_cast<DataBuffer>(dataset));
661 }
662
663 int NeuralNetwork::allocate(ExecutionMode mode) {
664   model_graph.deallocateTensors();
665   model_graph.allocateTensors(mode);
666
667   return ML_ERROR_NONE;
668 }
669
670 int NeuralNetwork::deallocate() {
671   model_graph.deallocateTensors(true);
672
673   return ML_ERROR_NONE;
674 }
675
676 int NeuralNetwork::train(const std::vector<std::string> &values,
677                          std::function<bool(void *)> stop_cb) {
678   int status = ML_ERROR_NONE;
679
680   if (data_buffers[static_cast<int>(DatasetModeType::MODE_TRAIN)] == nullptr) {
681     ml_loge("Cannot initialize the model without the train data buffer.");
682     return ML_ERROR_INVALID_PARAMETER;
683   }
684
685   if (!opt) {
686     ml_loge("Cannot train network without optimizer.");
687     return ML_ERROR_INVALID_PARAMETER;
688   }
689
690   setTrainConfig(values);
691
692   /** set batch size just before training */
693   model_graph.setBatchSize(
694     std::get<props::TrainingBatchSize>(model_flex_props));
695
696   status = allocate(ExecutionMode::TRAIN);
697   NN_RETURN_STATUS();
698
699   status = train_run(stop_cb);
700   NN_RETURN_STATUS();
701
702   /**
703    * Free the memory needed for training before exiting.
704    * Note that this does not free the weights for the model.
705    * Weights of the model will be freed when the model is destroyed.
706    */
707   model_graph.deallocateTensors(false);
708   return status;
709 }
710
711 /**
712  * @brief     Run NeuralNetwork train with callback function by user
713  */
714 int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
715   int status = ML_ERROR_NONE;
716
717   if (!std::get<props::ContinueTrain>(model_flex_props)) {
718     epoch_idx = 0;
719     iter = 0;
720     for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
721       (*iter)->clearOptVar();
722     }
723   }
724
725   auto batch_size = std::get<props::TrainingBatchSize>(model_flex_props);
726
727   auto const &outputs = model_graph.getOutputTensors();
728   auto in_dims = model_graph.getInputDimension();
729   auto label_dims = model_graph.getOutputDimension();
730
731   auto &[train_buffer, valid_buffer, test_buffer] = data_buffers;
732
733   if (train_buffer == nullptr) {
734     ml_loge("[NeuralNetworks] there is no train dataset!");
735     return ML_ERROR_INVALID_PARAMETER;
736   }
737
738   /**
739    * @brief run a single epoch with given callback, @a auto is used instead of
740    * std::function for performance measure
741    * @param buffer buffer to run
742    * @param shuffle whether to shuffle or not
743    * @param on_iteration_fetch function that will recieve reference to stat,
744    * buffer which will be called every time data is fetched and set
745    * @param on_epoch_end function that will recieve reference to stat,
746    * buffer which will be called on the epoch end
747    */
748   auto run_epoch = [this, &in_dims, &label_dims, &outputs, batch_size](
749                      DataBuffer *buffer, bool shuffle,
750                      auto &&on_iteration_fetch, auto &&on_iteration_update_stat,
751                      auto &&on_epoch_end) {
752     /// @todo managing metrics must be handled here as well!! for now it is
753     /// handled in individual callbacks
754     RunStats stat;
755     std::future<std::shared_ptr<IterationQueue>> future_iq =
756       buffer->startFetchWorker(in_dims, label_dims, shuffle);
757     while (true) {
758       ScopedView<Iteration> iter_view = buffer->fetch();
759       if (iter_view.isEmpty()) {
760         break;
761       }
762       auto &iteration = iter_view.get();
763       if (iteration.batch() != batch_size) {
764         /// @todo support partial batch
765         continue;
766       }
767
768       auto const &labels = iteration.getLabelsRef();
769       auto const &inputs = iteration.getInputsRef();
770       model_graph.setInputsLabels(inputs, labels);
771
772       on_iteration_fetch(stat, *buffer);
773       on_iteration_update_stat(stat, outputs, labels);
774     }
775     future_iq.get();
776     on_epoch_end(stat, *buffer);
777
778     if (stat.num_iterations == 0) {
779       throw std::runtime_error("No data came while buffer ran");
780     }
781
782     return stat;
783   };
784
785   auto train_for_iteration = [this, stop_cb](RunStats &stat,
786                                              DataBuffer &buffer) {
787     model_graph.flushCache();
788
789     forwarding(true, stop_cb);
790     backwarding(iter++, stop_cb);
791
792     if (!stop_cb(nullptr)) {
793       std::cout << "#" << epoch_idx << "/" << getEpochs();
794       ml_logi("# %d / %d", epoch_idx, getEpochs());
795       auto loss = getLoss();
796       buffer.displayProgress(stat.num_iterations, loss);
797     }
798   };
799
800   auto update_train_stat = [this](RunStats &stat,
801                                   const std::vector<Tensor> &outputs,
802                                   const std::vector<Tensor> &labels) {
803     stat.loss += getLoss();
804     stat.num_iterations++;
805   };
806
807   auto train_epoch_end = [this, stop_cb](RunStats &stat, DataBuffer &buffer) {
808     stat.loss /= static_cast<float>(stat.num_iterations);
809     auto &save_path = std::get<props::SavePath>(model_flex_props);
810     if (!stop_cb(nullptr)) {
811       if (!save_path.empty()) {
812         save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
813       }
814
815       std::cout << "#" << epoch_idx << "/" << getEpochs()
816                 << " - Training Loss: " << stat.loss;
817       ml_logi("# %d / %d - Training Loss: %f", epoch_idx, getEpochs(),
818               stat.loss);
819       ml_logd("[NNTrainer] Training epoch %d / %d finished successfully.",
820               epoch_idx, getEpochs());
821     } else {
822       ml_logd("[NNTrainer] Training stopped by stop callback function during "
823               "epoch %d.",
824               epoch_idx);
825     }
826   };
827
828   auto eval_for_iteration = [this, batch_size](RunStats &stat,
829                                                DataBuffer &buffer) {
830     forwarding(false);
831   };
832
833   auto update_eval_stat = [batch_size, &update_train_stat](
834                             RunStats &stat, const std::vector<Tensor> &outputs,
835                             const std::vector<Tensor> &labels) {
836     auto model_out = outputs[0].argmax();
837     auto label_out = labels[0].argmax();
838
839     for (unsigned int b = 0; b < batch_size; b++) {
840       if (model_out[b] == label_out[b])
841         stat.num_correct_predictions++;
842     }
843
844     update_train_stat(stat, outputs, labels);
845   };
846
847   auto eval_epoch_end = [this, batch_size, max_acc = 0.0f,
848                          min_loss = std::numeric_limits<float>::max()](
849                           RunStats &stat, DataBuffer &buffer) mutable {
850     stat.loss /= static_cast<float>(stat.num_iterations);
851     stat.accuracy = stat.num_correct_predictions /
852                     static_cast<float>(stat.num_iterations * batch_size) *
853                     100.0f;
854
855     if (stat.accuracy > max_acc ||
856         (stat.accuracy == max_acc && stat.loss < min_loss)) {
857       max_acc = stat.accuracy;
858       /// @note this is not actually 'the' min loss for whole time but records
859       /// when data change
860       min_loss = stat.loss;
861       auto &save_best_path = std::get<props::SaveBestPath>(model_flex_props);
862       if (!save_best_path.empty()) {
863         save(save_best_path);
864       }
865     }
866     std::cout << " >> [ Accuracy: " << stat.accuracy
867               << "% - Validation Loss : " << stat.loss << " ]";
868     ml_logi("[ Accuracy: %.2f %% - Validataion Loss: %.5f", stat.accuracy,
869             stat.loss);
870   };
871
872   PROFILE_MEM_ANNOTATE("TRAIN START");
873   auto epochs = getEpochs();
874   ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
875           epoch_idx + 1, getEpochs());
876   for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
877     if (stop_cb(nullptr)) {
878       --epoch_idx;
879       break;
880     }
881     training = run_epoch(train_buffer.get(), true, train_for_iteration,
882                          update_train_stat, train_epoch_end);
883     if (valid_buffer) {
884       validation = run_epoch(valid_buffer.get(), false, eval_for_iteration,
885                              update_eval_stat, eval_epoch_end);
886     }
887     std::cout << '\n';
888   }
889   PROFILE_MEM_ANNOTATE("TRAIN END");
890
891   if (test_buffer) {
892     std::cout << "Evaluation with test data...\n";
893     testing = run_epoch(test_buffer.get(), false, eval_for_iteration,
894                         update_eval_stat, eval_epoch_end);
895   }
896
897   /** Clear the set inputs and labels */
898   model_graph.setInputsLabels({}, {});
899
900   return status;
901 }
902
903 void swap(NeuralNetwork &lhs, NeuralNetwork &rhs) {
904   {
905     using std::swap;
906
907     swap(lhs.model_props, rhs.model_props);
908     swap(lhs.model_flex_props, rhs.model_flex_props);
909     swap(lhs.load_path, rhs.load_path);
910     swap(lhs.epoch_idx, rhs.epoch_idx);
911     swap(lhs.iter, rhs.iter);
912     swap(lhs.loss, rhs.loss);
913     swap(lhs.opt, rhs.opt);
914     swap(lhs.data_buffers, rhs.data_buffers);
915     swap(lhs.initialized, rhs.initialized);
916     swap(lhs.model_graph, rhs.model_graph);
917     swap(lhs.graph_representation, rhs.graph_representation);
918     swap(lhs.compiled, rhs.compiled);
919     swap(lhs.loadedFromConfig, rhs.loadedFromConfig);
920   }
921 }
922
923 int NeuralNetwork::addLayer(NodeType layer) {
924   int status = ML_ERROR_NONE;
925
926   if (initialized) {
927     return ML_ERROR_NOT_SUPPORTED;
928   }
929
930   /** Insert the layer to the graph */
931   model_graph.addLayer(layer);
932   graph_representation.push_back(layer);
933
934   return status;
935 }
936
937 NeuralNetwork &NeuralNetwork::copyConfiguration(NeuralNetwork &from) {
938   if (this != &from) {
939     model_props = from.model_props;
940     model_flex_props = from.model_flex_props;
941     loss = from.loss;
942     opt = from.opt;
943
944     NetworkGraph f_graph = from.getNetworkGraph();
945     for (auto &l_node : f_graph.getLayerNodes()) {
946       addLayer(static_cast<std::shared_ptr<ml::train::Layer>>(
947         l_node->cloneConfiguration()));
948     }
949   }
950   return *this;
951 }
952
953 NeuralNetwork::GraphType
954 NeuralNetwork::getUnsortedLayers(const std::string &input_layer,
955                                  const std::string &output_layer) {
956   return model_graph.getUnsortedLayers(input_layer, output_layer);
957 }
958
959 int NeuralNetwork::setOptimizer(
960   std::shared_ptr<ml::train::Optimizer> optimizer) {
961   if (initialized) {
962     return ML_ERROR_NOT_SUPPORTED;
963   }
964
965   opt = std::static_pointer_cast<OptimizerWrapped>(optimizer);
966
967   return ML_ERROR_NONE;
968 }
969
970 int NeuralNetwork::setDataBuffer(const DatasetModeType &mode,
971                                  std::shared_ptr<DataBuffer> data_buffer) {
972   if (data_buffer == nullptr) {
973     return ML_ERROR_INVALID_PARAMETER;
974   }
975
976   this->data_buffers[static_cast<int>(mode)] = data_buffer;
977
978   return ML_ERROR_NONE;
979 }
980
981 int NeuralNetwork::getLayer(const char *name,
982                             std::shared_ptr<ml::train::Layer> *layer) {
983   // We provide the layer change through the api with user's responsibility.
984   //
985   // if (compiled) {
986   //   ml_loge("Cannot get compiled layer.");
987   //   return ML_ERROR_NOT_SUPPORTED;
988   // }
989
990   *layer = std::static_pointer_cast<ml::train::Layer>(
991     model_graph.getLayerNode(std::string(name)));
992   return ML_ERROR_NONE;
993 }
994
995 void NeuralNetwork::printMetrics(std::ostream &out, unsigned int flags) {
996   switch (flags) {
997   case ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS:
998     out << training.loss << std::endl;
999     break;
1000
1001   case ML_TRAIN_SUMMARY_MODEL_VALID_LOSS:
1002     out << validation.loss << std::endl;
1003     break;
1004
1005   case ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY:
1006     out << validation.accuracy << std::endl;
1007     break;
1008
1009   default:
1010     break;
1011   }
1012 }
1013
1014 void NeuralNetwork::printPreset(std::ostream &out, unsigned int preset) {
1015   /** print neuralnet metrics */
1016   printMetrics(out, preset);
1017   if (preset > ML_TRAIN_SUMMARY_TENSOR)
1018     return;
1019
1020   LayerNode::PrintPreset layer_preset = LayerNode::PrintPreset::PRINT_NONE;
1021
1022   ///@todo match flags with preset
1023   unsigned int flags = PRINT_INST_INFO | PRINT_GRAPH_INFO | PRINT_PROP |
1024                        PRINT_OPTIMIZER | PRINT_METRIC;
1025
1026   switch (preset) {
1027   case ML_TRAIN_SUMMARY_TENSOR:
1028     layer_preset = LayerNode::PrintPreset::PRINT_ALL;
1029     break;
1030   case ML_TRAIN_SUMMARY_LAYER:
1031     layer_preset = initialized ? LayerNode::PrintPreset::PRINT_SUMMARY
1032                                : LayerNode::PrintPreset::PRINT_SUMMARY_META;
1033     break;
1034   case ML_TRAIN_SUMMARY_MODEL:
1035     break;
1036   default:
1037     throw std::invalid_argument("given verbosity is invalid");
1038   }
1039
1040   print(out, flags, layer_preset);
1041 }
1042
1043 void NeuralNetwork::addWithReferenceLayers(
1044   const std::vector<std::shared_ptr<ml::train::Layer>> &reference,
1045   const std::string &scope, const std::vector<std::string> &input_layers,
1046   const std::vector<std::string> &start_layers,
1047   const std::vector<std::string> &end_layers,
1048   ml::train::ReferenceLayersType type,
1049   const std::vector<std::string> &type_properties) {
1050   std::vector<NodeType> casted_reference;
1051   casted_reference.reserve(reference.size());
1052   for (auto &node : reference) {
1053     casted_reference.emplace_back(std::static_pointer_cast<LayerNode>(node));
1054   }
1055
1056   addWithReferenceLayers(casted_reference, scope, input_layers, start_layers,
1057                          end_layers, type, type_properties);
1058 }
1059 void NeuralNetwork::addWithReferenceLayers(
1060   const std::vector<std::shared_ptr<LayerNode>> &reference,
1061   const std::string &scope, const std::vector<std::string> &input_layers,
1062   const std::vector<std::string> &start_layers,
1063   const std::vector<std::string> &end_layers,
1064   ml::train::ReferenceLayersType type,
1065   const std::vector<std::string> &type_properties) {
1066   /// @todo below configuration should be extracted as a free function to make
1067   /// it more testable, and reused inside graph interpreter
1068
1069   /// @note we can exploit connection to connection more fine grained, for now
1070   /// it is not supported but we can easily make this supported
1071   std::vector<std::shared_ptr<LayerNode>> nodes;
1072   nodes.reserve(reference.size());
1073   for (auto &node : reference) {
1074     nodes.push_back(node->cloneConfiguration());
1075   }
1076
1077   auto start_conns =
1078     std::vector<Connection>(start_layers.begin(), start_layers.end());
1079   auto input_conns =
1080     std::vector<Connection>(input_layers.begin(), input_layers.end());
1081   auto end_conns =
1082     std::vector<Connection>(end_layers.begin(), end_layers.end());
1083
1084   std::vector<std::unique_ptr<GraphRealizer>> realizers;
1085
1086   realizers.emplace_back(new PreviousInputRealizer(start_conns));
1087   realizers.emplace_back(new SliceRealizer(start_conns, end_conns));
1088
1089   if (!input_conns.empty()) {
1090     realizers.emplace_back(new InputRealizer(start_conns, input_conns));
1091   }
1092
1093   if (type == ml::train::ReferenceLayersType::RECURRENT) {
1094     realizers.emplace_back(
1095       new RecurrentRealizer(type_properties, input_conns, end_conns));
1096   }
1097
1098   if (!scope.empty()) {
1099     realizers.emplace_back(
1100       new RemapRealizer([&scope, &input_conns](std::string &name) {
1101         for (auto &i : input_conns) {
1102           if (i.getName() == name) {
1103             return;
1104           }
1105         }
1106         name = scope + "/" + name;
1107       }));
1108   }
1109
1110   for (auto &realizer : realizers) {
1111     nodes = realizer->realize(nodes);
1112   }
1113
1114   for (auto &node : nodes) {
1115     addLayer(node);
1116   }
1117 }
1118
1119 void NeuralNetwork::exportTo(Exporter &exporter,
1120                              const ml::train::ExportMethods &method) const {
1121   exporter.saveResult(model_props, method, this);
1122   exporter.saveResult(model_flex_props, method, this);
1123 }
1124
1125 void NeuralNetwork::print(std::ostream &out, unsigned int flags,
1126                           LayerNode::PrintPreset layerPrintPreset) {
1127   if (flags & PRINT_INST_INFO) {
1128     /// @todo uncomment this after implement getProperty (#1875)
1129     // out << "===================";
1130     // printInstance(out, this);
1131   }
1132
1133   if (flags & PRINT_GRAPH_INFO) {
1134     unsigned int total_col_size = 80;
1135     std::vector<unsigned int> column_size = {20, 20, 20, 20};
1136     auto print_graph_layer_info =
1137       [column_size](std::ostream &out, std::vector<std::string> layer_info) {
1138         auto trim_string = [](std::string str, unsigned int column_width) {
1139           return str.size() < column_width ? str
1140                                            : str.substr(0, column_width - 1);
1141         };
1142
1143         for (unsigned int i = 0; i < column_size.size(); ++i) {
1144           out << std::setw(column_size[i])
1145               << trim_string(layer_info[i], column_size[i]);
1146         }
1147         out << "\n";
1148       };
1149
1150     out << std::string(total_col_size, '=') << '\n';
1151     print_graph_layer_info(
1152       out, {"Layer name", "Layer type", "Input dimension", "Input layer"});
1153     out << std::string(total_col_size, '=') << '\n';
1154     if (compiled) {
1155       props::GenericShape dim_property;
1156
1157       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
1158            iter++) {
1159         std::string first_dim;
1160         if (iter->getInputDimensions().empty()) {
1161           first_dim = "";
1162         } else {
1163           dim_property.set(iter->getInputDimensions()[0]);
1164           first_dim = to_string(dim_property);
1165         }
1166         const std::vector<std::string> &input_layer_names =
1167           iter->getInputConnections();
1168         std::string first_input_name =
1169           input_layer_names.empty() ? "" : input_layer_names[0];
1170         print_graph_layer_info(
1171           out, {iter->getName(), iter->getType(), first_dim, first_input_name});
1172         for (unsigned int i = 1; i < input_layer_names.size(); ++i) {
1173           dim_property.set(iter->getInputDimensions()[i]);
1174           print_graph_layer_info(
1175             out, {"", "", to_string(dim_property), input_layer_names[i]});
1176         }
1177         out << std::string(total_col_size,
1178                            iter == model_graph.cend() - 1 ? '=' : '-')
1179             << '\n';
1180       }
1181     } else {
1182       auto &input_connection =
1183         std::get<std::vector<props::InputConnection>>(model_props);
1184       auto model_input = std::vector<Connection>(input_connection.begin(),
1185                                                  input_connection.end());
1186       auto is_actually_an_input_node =
1187         [model_input](graph_const_iterator<LayerNode> node) {
1188           return node->hasInputShapeProperty() or
1189                  std::any_of(model_input.begin(), model_input.end(),
1190                              [node](auto &conn) {
1191                                return node->getName() == conn.getName();
1192                              });
1193         };
1194
1195       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
1196            iter++) {
1197         const std::vector<std::string> &input_layer_names =
1198           iter->getInputConnections();
1199
1200         /// @brief connection information.
1201         // Intended comment.
1202         // std::string first_input_name =
1203         //   input_layer_names.empty()
1204         //     ? (is_actually_an_input_node(iter) || iter ==
1205         //     model_graph.cbegin()
1206         //          ? ""
1207         //          : (iter - 1)->getName())
1208         //     : input_layer_names[0];
1209         print_graph_layer_info(out, {iter->getName(), iter->getType(), "", ""});
1210         for (unsigned int i = 1; i < input_layer_names.size(); ++i) {
1211           print_graph_layer_info(out, {"", "", "", ""});
1212         }
1213         out << std::string(total_col_size,
1214                            iter == model_graph.cend() - 1 ? '=' : '-')
1215             << '\n';
1216       }
1217     }
1218   }
1219
1220   if (flags & PRINT_PROP) {
1221     /// @todo print neuralnet property
1222     /// @todo print mode (if it is eval or training)
1223   }
1224
1225   if (flags & PRINT_OPTIMIZER) {
1226     /// @todo print optimizer (with print optimizer prop)
1227   }
1228
1229   if (flags & PRINT_METRIC) {
1230     /// @todo print metric (currently it is done at printPreset as a
1231     /// workaround)
1232     /// @todo print loss function when it is not initialized. (if it is
1233     /// initialized, loss layer will be printed)
1234   }
1235
1236   if (model_graph.empty()) {
1237     out << "model is empty!" << std::endl;
1238     return;
1239   }
1240
1241   /** print layer properties */
1242   for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++)
1243     (*iter)->printPreset(out, layerPrintPreset);
1244
1245   /// @todo Add status to check neuralnet has been run. #290
1246 }
1247
1248 void NeuralNetwork::forEachLayer(
1249   std::function<void(ml::train::Layer &, RunLayerContext &, void *)> fn,
1250   void *user_data) {
1251   for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
1252     auto ln = std::static_pointer_cast<LayerNode>(*iter).get();
1253     fn(*ln, std::forward<RunLayerContext &>(ln->getRunContext()), user_data);
1254   };
1255 }
1256
1257 void NeuralNetwork::exports(const ml::train::ExportMethods &method,
1258                             const std::string file_path) {
1259   switch (method) {
1260   case ml::train::ExportMethods::METHOD_TFLITE: {
1261 #ifdef ENABLE_TFLITE_INTERPRETER
1262     nntrainer::TfliteInterpreter interpreter;
1263
1264     /// We will call "serialize" method for the model which is already trained
1265     /// or allocated. So, we need to call deallocateTensors first to make sure
1266     /// `dealloc_weights == false`
1267     model_graph.deallocateTensors();
1268     model_graph.allocateTensors(ExecutionMode::INFERENCE);
1269     interpreter.serialize(graph_representation, file_path);
1270     model_graph.deallocateTensors();
1271 #else
1272     throw std::runtime_error{
1273       "Export methods METHOD_TFLITE is not supported. Please enable tflite "
1274       "interpreter by set ENABLE_TFLITE_INTERPRETER=1"};
1275 #endif
1276     break;
1277   }
1278   default:
1279     throw std::runtime_error{"Unsupported export method"};
1280   }
1281 }
1282 } /* namespace nntrainer */