[ahub] fix ahub issue
[platform/core/ml/nntrainer.git] / nntrainer / models / neuralnet.cpp
1 /**
2  * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *   http://www.apache.org/licenses/LICENSE-2.0
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  *
15  * @file        neuralnet.cpp
16  * @date        04 December 2019
17  * @brief       This is Neural Network Class
18  * @see         https://github.com/nnstreamer/nntrainer
19  * @author      Jijoong Moon <jijoong.moon@samsung.com>
20  * @bug         No known bugs except for NYI items
21  *
22  */
23
24 #include "layer_context.h"
25 #include "model_common_properties.h"
26 #include <cmath>
27 #include <cstring>
28 #include <fstream>
29 #include <iomanip>
30 #include <iostream>
31 #include <sstream>
32
33 #include <activation_realizer.h>
34 #include <common_properties.h>
35 #include <databuffer.h>
36 #include <flatten_realizer.h>
37 #include <ini_interpreter.h>
38 #include <ini_wrapper.h>
39 #include <input_realizer.h>
40 #include <model_loader.h>
41 #include <multiout_realizer.h>
42 #include <neuralnet.h>
43 #include <nntrainer_error.h>
44 #include <nntrainer_log.h>
45 #include <node_exporter.h>
46 #include <optimizer_context.h>
47 #include <previous_input_realizer.h>
48 #include <profiler.h>
49 #include <recurrent_realizer.h>
50 #include <remap_realizer.h>
51 #include <slice_realizer.h>
52 #include <util_func.h>
53
54 #ifdef ENABLE_TFLITE_INTERPRETER
55 #include <tflite_interpreter.h>
56 #endif
57
58 /**
59  * @brief Internal enum values for nntrainer to summarize model accuracy & loss
60  */
61 #define ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS 101
62 #define ML_TRAIN_SUMMARY_MODEL_VALID_LOSS 102
63 #define ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY 103
64
65 namespace nntrainer {
66
67 NeuralNetwork::NeuralNetwork() :
68   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
69   model_flex_props(
70     props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
71     props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
72     props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
73   load_path(std::string()),
74   epoch_idx(0),
75   iter(0),
76   loss(0.0f),
77   data_buffers({nullptr, nullptr, nullptr}),
78   initialized(false),
79   compiled(false),
80   loadedFromConfig(false) {
81   app_context = AppContext(AppContext::Global());
82 }
83
84 NeuralNetwork::NeuralNetwork(AppContext app_context_) :
85   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
86   model_flex_props(
87     props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
88     props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
89     props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
90   load_path(std::string()),
91   epoch_idx(0),
92   iter(0),
93   loss(0.0f),
94   data_buffers({nullptr, nullptr, nullptr}),
95   initialized(false),
96   compiled(false),
97   loadedFromConfig(false),
98   app_context(app_context_) {}
99
100 int NeuralNetwork::loadFromConfig(const std::string &config) {
101   if (loadedFromConfig == true) {
102     ml_loge("cannnot do loadFromConfig twice");
103     return ML_ERROR_INVALID_PARAMETER;
104   }
105
106   ModelLoader loader(app_context);
107   NeuralNetwork tempNet(*this);
108
109   int status = loader.loadFromContext(tempNet);
110   if (status != ML_ERROR_NONE) {
111     return status;
112   }
113
114   status = loader.loadFromConfig(config, tempNet);
115   if (status != ML_ERROR_NONE) {
116     return status;
117   }
118
119   tempNet.loadedFromConfig = true;
120   swap(tempNet, *this);
121
122   return ML_ERROR_NONE;
123 }
124
125 unsigned int NeuralNetwork::getCurrentEpoch() {
126 #ifdef DEBUG
127   ml_logd("[NNTrainer] Current epoch: %d", epoch_idx);
128 #endif
129   return epoch_idx;
130 };
131
132 void NeuralNetwork::setProperty(const std::vector<std::string> &values) {
133   auto left_props = loadProperties(values, model_props);
134   setTrainConfig(left_props);
135 }
136
137 void NeuralNetwork::setTrainConfig(const std::vector<std::string> &values) {
138   auto left_props = loadProperties(values, model_flex_props);
139   NNTR_THROW_IF(left_props.size(), std::invalid_argument)
140     << "Model has unparsed properties, size: " << left_props.size()
141     << " of first element: " << left_props.front();
142 }
143
144 int NeuralNetwork::compile() {
145   std::string loss_type = std::get<props::LossType>(model_props).empty()
146                             ? std::string()
147                             : std::get<props::LossType>(model_props);
148
149   auto &input_conn = std::get<std::vector<props::InputConnection>>(model_props);
150   /// @note label layer might need to be treated in the similar way as well
151
152   /// @todo make NetworkGraph compiled at the construction instead of having
153   /// graph.compile(), neuralnetwork have ownership of list of layer nodes,
154   /// which will be passed at compile time.
155
156   std::vector<std::unique_ptr<GraphRealizer>> realizers;
157
158   realizers.emplace_back(new PreviousInputRealizer(
159     std::vector<Connection>(input_conn.begin(), input_conn.end())));
160   realizers.emplace_back(new MultioutRealizer());
161   realizers.emplace_back(new FlattenRealizer());
162   realizers.emplace_back(new ActivationRealizer());
163
164   for (auto &realizer : realizers) {
165     graph_representation = realizer->realize(graph_representation);
166   }
167
168   bool memory_swap = std::get<props::MemorySwap>(model_flex_props);
169   const std::string memory_swap_path =
170     std::get<props::MemorySwapPath>(model_flex_props);
171   unsigned int lookahead =
172     std::get<props::MemorySwapLookahead>(model_flex_props);
173   model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead);
174
175   model_graph.setMemoryOptimizations(
176     std::get<props::MemoryOptimization>(model_flex_props));
177   for (auto &node : graph_representation) {
178     if (auto &prop = std::get<props::ClipGradByGlobalNorm>(model_props);
179         !prop.empty()) {
180       node->setProperty({"clip_grad_by_norm=" + to_string(prop)});
181     }
182     model_graph.addLayer(node);
183   }
184
185   int status = model_graph.compile(loss_type);
186   NN_RETURN_STATUS();
187
188   compiled = true;
189
190   return status;
191 }
192
193 int NeuralNetwork::initialize() {
194   int status = ML_ERROR_NONE;
195
196   if (initialized) {
197     ml_loge("Error: Initializing the model again");
198     return ML_ERROR_NOT_SUPPORTED;
199   }
200
201   if (!compiled) {
202     ml_loge("Error: Need to compile first");
203     return ML_ERROR_NOT_SUPPORTED;
204   }
205
206   unsigned int n_layers = (unsigned int)model_graph.size();
207
208   ml_logd("initializing neural network, layer size: %d", n_layers);
209   PROFILE_MEM_ANNOTATE("Initialize");
210
211   auto &input_conn_prop =
212     std::get<std::vector<props::InputConnection>>(model_props);
213   auto &label_layer_prop =
214     std::get<std::vector<props::LabelLayer>>(model_props);
215
216   std::vector<Connection> input_conn(input_conn_prop.begin(),
217                                      input_conn_prop.end());
218   std::vector<std::string> label_layers;
219
220   if (!label_layer_prop.empty()) {
221     label_layers = std::vector<std::string>(label_layer_prop.begin(),
222                                             label_layer_prop.end());
223   }
224
225   status = model_graph.initialize(
226     input_conn,
227     std::vector<Connection>(label_layers.begin(), label_layers.end()));
228   NN_RETURN_STATUS();
229
230   model_graph.setBatchSize(
231     std::get<props::TrainingBatchSize>(model_flex_props));
232
233   // initialize optimizer and related variables
234   /// @todo: initialize should take a mode and check if mode is train but
235   /// optimizer is not given, make it as a hard error
236   if (opt) {
237     /** TODO: update request of optimizer to be of same format as
238      * Layer::requestTensor */
239     opt->finalize();
240     std::function<std::vector<TensorDim>(const TensorDim &)> cb =
241       [this](const TensorDim &dim) {
242         return opt->getOptimizerVariableDim(dim);
243       };
244     model_graph.requestOptimizerVariable(cb, true);
245   }
246
247   // Allocate weights
248   model_graph.allocateWeights();
249
250   initialized = true;
251
252   if (!load_path.empty()) {
253     load(load_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
254   }
255
256   return status;
257 }
258
259 /**
260  * @brief     free layers
261  */
262 NeuralNetwork::~NeuralNetwork() { deallocate(); }
263
264 /**
265  * @brief     forward propagation using layers object which has layer
266  */
267 sharedConstTensors
268 NeuralNetwork::forwarding(bool training,
269                           std::function<bool(void *userdata)> stop_cb) {
270   std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
271     [this, stop_cb](std::shared_ptr<LayerNode> node, bool training) -> void {
272     (void)this;
273     PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName());
274
275     auto f = std::get<0>(node->getExecutionOrder());
276     model_graph.flushCacheExcept(f);
277
278     node->forwarding(training);
279   };
280
281   return model_graph.forwarding(training, forwarding_op, stop_cb);
282 }
283
284 /**
285  * @brief     forward propagation using layers object which has layer
286  */
287 sharedConstTensors NeuralNetwork::forwarding(sharedConstTensors input,
288                                              sharedConstTensors label,
289                                              bool training) {
290   auto current_batch = model_graph.getBatchSize();
291   NNTR_THROW_IF(input[0]->batch() != current_batch ||
292                   (!label.empty() && label[0]->batch() != current_batch),
293                 std::logic_error)
294     << "Error: mismatch in batchsize for data and model."
295     << " input_batch: " << input[0]->batch()
296     << " label_batch: " << label[0]->batch()
297     << " target_batch: " << current_batch;
298
299   model_graph.setInputsLabels(input, label);
300
301   return forwarding(training);
302 }
303
304 /**
305  * @brief     back propagation
306  *            Call backwarding function of layer in reverse order
307  *            No need to call at first Input Layer (No data to be updated)
308  */
309 void NeuralNetwork::backwarding(int iteration,
310                                 std::function<bool(void *userdata)> stop_cb) {
311
312 #ifdef DEBUG
313   NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!";
314 #endif
315
316   std::function<void(std::shared_ptr<LayerNode>, int)> backwarding_op =
317     [this, stop_cb](std::shared_ptr<LayerNode> node, int iteration) -> void {
318     /**
319      * Do not change this order:
320      * 1. calcGradient
321      * 2. calcDerivative
322      * 3. applyGradient
323      * 4. gradientClippingOnLastAccess
324      */
325
326     model_graph.flushCacheExcept(std::get<1>(node->getExecutionOrder()));
327     PROFILE_MEM_ANNOTATE("CalcGradient: " + node->getName());
328
329     bool apply_gradient = true;
330
331     /** If gradient optimization mode, then calculate gradient first */
332     if (dynamic_training_opt.isGradientMode())
333       node->calcGradient();
334
335     /**
336      * If optimization off, or gradient must be applied, then this will be
337      * true
338      * @todo This apply gradient should be passed to the each weight and later
339      * be queried when updating gradient at once. (after moving apply_gradient
340      * out of this function)
341      *
342      */
343     // auto &layer = node->getObject();
344     // apply_gradient = dynamic_training_opt.checkIfApply(
345     //   layer->getWeightsRef(), layer->net_input[0], layer->net_hidden[0],
346     //   opt, iteration);
347
348     /** If gradient must be applied and its not gradient mode, calculate
349      * gradient
350      */
351     if (!dynamic_training_opt.isGradientMode() && apply_gradient)
352       node->calcGradient();
353
354     model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder()));
355     PROFILE_MEM_ANNOTATE("CalcDerivative: " + node->getName());
356
357     if (stop_cb(nullptr)) {
358       return;
359     }
360
361     if (node->needsCalcDerivative())
362       node->calcDerivative();
363
364     model_graph.flushCacheExcept(std::get<3>(node->getExecutionOrder()));
365     PROFILE_MEM_ANNOTATE("ApplyGradient: " + node->getName());
366
367     if (apply_gradient) {
368       /// Apply gradient only at the end of the last shared weight access
369       model_graph.applyGradients(
370         node.get(), [iteration, opt_ = opt.get()](Weight &w) {
371           w.calcRegularizationGradient();
372           w.calcWeightDecayGradient();
373           RunOptimizerContext opt_context(&w, iteration,
374                                           opt_->getLearningRate(iteration));
375           opt_->applyGradient(opt_context);
376         });
377     }
378   };
379
380   std::function<void(Weight &, int)> apply_grad_clip_op =
381     [opt_ = opt.get()](Weight &w, int iteration) -> void {
382     w.calcRegularizationGradient();
383     w.calcWeightDecayGradient();
384     RunOptimizerContext opt_context(&w, iteration,
385                                     opt_->getLearningRate(iteration));
386     opt_->applyGradient(opt_context);
387   };
388
389   model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op,
390                           stop_cb);
391 }
392
393 void NeuralNetwork::save(const std::string &file_path,
394                          ml::train::ModelFormat format) {
395   NNTR_THROW_IF(!initialized, std::runtime_error)
396     << "Cannot save model if not initialized yet, path: " << file_path
397     << " format: " << static_cast<unsigned>(format);
398
399   /// @todo this switch case should be delegating the function call only. It's
400   /// not delegating for now as required logics are managable for now.
401   switch (format) {
402   case ml::train::ModelFormat::MODEL_FORMAT_BIN: {
403     auto model_file = checkedOpenStream<std::ofstream>(
404       file_path, std::ios::out | std::ios::binary | std::ios::trunc);
405     for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
406       (*iter)->save(model_file);
407     }
408     if (opt && istrequal(opt->getType(), "adam")) {
409       std::string adam = "adam";
410       model_file.write(adam.c_str(), 4);
411       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
412            iter++) {
413         (*iter)->save(model_file, true);
414       }
415     }
416
417     model_file.write((char *)&epoch_idx, sizeof(epoch_idx));
418     model_file.write((char *)&iter, sizeof(iter));
419
420     model_file.close();
421     break;
422   }
423   case ml::train::ModelFormat::MODEL_FORMAT_INI:
424     saveModelIni(file_path);
425     break;
426
427   case ml::train::ModelFormat::MODEL_FORMAT_INI_WITH_BIN: {
428     auto old_save_path = std::get<props::SavePath>(model_flex_props);
429     auto bin_file_name =
430       file_path.substr(0, file_path.find_last_of('.')) + ".bin";
431
432     std::get<props::SavePath>(model_flex_props).set(bin_file_name);
433     save(file_path, ml::train::ModelFormat::MODEL_FORMAT_INI);
434     save(bin_file_name, ml::train::ModelFormat::MODEL_FORMAT_BIN);
435     std::get<props::SavePath>(model_flex_props) = old_save_path;
436     break;
437   }
438   default:
439     throw nntrainer::exception::not_supported(
440       "saving with given format is not supported yet");
441   }
442 }
443
444 void NeuralNetwork::load(const std::string &file_path,
445                          ml::train::ModelFormat format) {
446   /// @todo this switch case should be delegating the function call only. It's
447   /// not delegating for now as required logics are managable for now.
448   switch (format) {
449   case ml::train::ModelFormat::MODEL_FORMAT_BIN: {
450     NNTR_THROW_IF(!initialized, std::runtime_error)
451       << "Cannot load if not initialized yet, path: " << file_path
452       << " format: " << static_cast<unsigned>(format);
453
454     auto model_file = checkedOpenStream<std::ifstream>(
455       file_path, std::ios::in | std::ios::binary);
456     for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
457       (*iter)->read(model_file);
458     }
459     try {
460       /// this is assuming that the failure is allowed at the end of the file
461       /// read. so, after this line, additional read shouldn't be called
462       if (opt && istrequal(opt->getType(), "adam")) {
463         std::string opt_type;
464         opt_type.resize(4);
465         model_file.read((char *)&opt_type[0], 4);
466         if (istrequal(opt_type, "adam")) {
467           for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
468                iter++) {
469             (*iter)->read(model_file, true);
470           }
471         }
472       }
473
474       checkedRead(model_file, (char *)&epoch_idx, sizeof(epoch_idx),
475                   "[NeuralNetwork::readModel] failed to read epoch_idx");
476       checkedRead(model_file, (char *)&iter, sizeof(iter),
477                   "[NeuralNetwork::readModel] failed to read iteration");
478     } catch (...) {
479       std::cerr << "failed to read additional data like optimizer variable, "
480                    "iteration, proceeding with default\n";
481     }
482
483     ml_logi("read modelfile: %s", file_path.c_str());
484     break;
485   }
486   case ml::train::ModelFormat::MODEL_FORMAT_INI_WITH_BIN: {
487     int ret = loadFromConfig(file_path);
488     throw_status(ret);
489     auto &save_path = std::get<props::SavePath>(model_flex_props);
490     if (!save_path.empty()) {
491       checkedOpenStream<std::ifstream>(save_path,
492                                        std::ios::in | std::ios::binary);
493       load_path = save_path;
494     }
495     break;
496   }
497   case ml::train::ModelFormat::MODEL_FORMAT_INI: {
498     int ret = loadFromConfig(file_path);
499     throw_status(ret);
500     break;
501   }
502   case ml::train::ModelFormat::MODEL_FORMAT_FLATBUFFER: {
503     break;
504   }
505   default:
506     throw nntrainer::exception::not_supported(
507       "loading with given format is not supported yet");
508   }
509 }
510
511 float NeuralNetwork::getLoss() {
512   loss = 0.0f;
513
514   for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
515     loss += (*iter)->getLoss();
516   }
517   return loss;
518 }
519
520 void NeuralNetwork::setLoss(float l) { loss = l; }
521
522 NeuralNetwork &NeuralNetwork::copy(NeuralNetwork &from) {
523   if (this != &from) {
524     model_props = from.model_props;
525     model_flex_props = from.model_flex_props;
526     loss = from.loss;
527     opt = from.opt;
528
529     model_graph.copy(from.model_graph);
530   }
531   return *this;
532 }
533
534 void NeuralNetwork::saveModelIni(const std::string &file_path) {
535   NNTR_THROW_IF(isFileExist(file_path), std::invalid_argument)
536     << "There is already a file, overriding to the exisiting file is not "
537        "permitted, path: "
538     << file_path;
539
540   std::vector<IniSection> sections;
541
542   IniSection model_section = IniSection::FromExportable("model", *this);
543   model_section.setEntry("type", "NeuralNetwork");
544   sections.push_back(model_section);
545
546   auto add_section_if_any = [&sections](const std::string &section_name,
547                                         auto obj_ptr, auto pred) {
548     if (pred(obj_ptr)) {
549       IniSection s = IniSection::FromExportable(section_name, *obj_ptr);
550       s.setEntry("type", obj_ptr->getType());
551       sections.push_back(s);
552     }
553   };
554
555   add_section_if_any("optimizer", opt,
556                      [](const auto &obj) { return static_cast<bool>(obj); });
557
558   auto &[train_buffer, valid_buffer, test_buffer] = data_buffers;
559   auto data_buffer_valid = [](const auto &buffer) {
560     return buffer && buffer->isSerializable(
561                        ml::train::ExportMethods::METHOD_STRINGVECTOR);
562   };
563
564   add_section_if_any("train_set", train_buffer, data_buffer_valid);
565   add_section_if_any("valid_set", valid_buffer, data_buffer_valid);
566   add_section_if_any("test_set", test_buffer, data_buffer_valid);
567
568   IniWrapper wrapper("model_saver", sections);
569   wrapper.save_ini(file_path);
570
571   IniGraphInterpreter interpreter;
572   interpreter.serialize(graph_representation, file_path);
573 }
574
575 bool NeuralNetwork::validateInput(sharedConstTensors X) {
576   auto input_dim = getInputDimension();
577   if (X.size() != input_dim.size()) {
578     ml_loge("Error: provided number of inputs %d, required %d", (int)X.size(),
579             (int)input_dim.size());
580     return false;
581   }
582
583   for (unsigned int dim = 0; dim < input_dim.size(); dim++) {
584     if (input_dim[dim] != X[dim]->getDim()) {
585       ml_loge("Error: provided input shape does not match required shape");
586       std::stringstream ss;
587       ss << X[dim]->getDim();
588       ml_loge("Provided tensor summary : %s", ss.str().c_str());
589
590       ss.str(std::string());
591       ss << input_dim[dim];
592       ml_loge("Required tensor summary : %s", ss.str().c_str());
593       return false;
594     }
595   }
596
597   return true;
598 }
599
600 sharedConstTensors NeuralNetwork::inference(sharedConstTensors X,
601                                             bool free_mem) {
602   return inference(X, {}, free_mem);
603 }
604
605 sharedConstTensors NeuralNetwork::inference(sharedConstTensors X,
606                                             sharedConstTensors label,
607                                             bool free_mem) {
608   if (model_graph.getBatchSize() != X[0]->batch()) {
609     model_graph.setBatchSize(X[0]->batch());
610   }
611
612   sharedConstTensors out;
613   if (!validateInput(X))
614     throw std::invalid_argument("Input validation failed.");
615
616   allocate(ExecutionMode::INFERENCE);
617
618   int nn_foward;
619   PROFILE_TIME_REGISTER_EVENT(nn_foward, "nn_forward");
620   PROFILE_TIME_START(nn_foward);
621   out = forwarding(X, label, false);
622   PROFILE_TIME_END(nn_foward);
623
624   if (free_mem)
625     /**
626      * Free the memory needed for training before exiting.
627      * Note that this does not free the weights for the model.
628      * Weights of the model will be freed when the model is destroyed.
629      */
630     model_graph.deallocateTensors(false);
631
632   /** Clear the set inputs and labels */
633   model_graph.setInputsLabels({}, {});
634
635   return out;
636 }
637
638 std::vector<float *>
639 NeuralNetwork::inference(unsigned int batch_size,
640                          const std::vector<float *> &input,
641                          const std::vector<float *> &label) {
642   sharedConstTensors input_tensors, output_tensors;
643   auto in_dim = getInputDimension();
644
645   input_tensors.reserve(input.size());
646   for (unsigned int idx = 0; idx < in_dim.size(); idx++) {
647     in_dim[idx].batch(batch_size);
648     input_tensors.emplace_back(MAKE_SHARED_TENSOR(Tensor::Map(
649       input[idx], in_dim[idx].getDataLen() * sizeof(float), in_dim[idx], 0)));
650   }
651
652   if (!label.empty()) {
653     sharedConstTensors label_tensors;
654     auto label_dim = getOutputDimension();
655     label_tensors.reserve(label.size());
656     for (unsigned int idx = 0; idx < label_dim.size(); idx++) {
657       label_dim[idx].batch(batch_size);
658       label_tensors.emplace_back(MAKE_SHARED_TENSOR(
659         Tensor::Map(label[idx], label_dim[idx].getDataLen() * sizeof(float),
660                     label_dim[idx], 0)));
661     }
662     output_tensors = inference(input_tensors, label_tensors, false);
663   } else {
664     output_tensors = inference(input_tensors, false);
665   }
666
667   std::vector<float *> output;
668   output.reserve(output_tensors.size());
669
670   for (auto &out : output_tensors) {
671     auto out_t = *out.get();
672     output.push_back(out_t.getData());
673   }
674
675   return output;
676 }
677
678 int NeuralNetwork::setDataset(const DatasetModeType &mode,
679                               std::shared_ptr<ml::train::Dataset> dataset) {
680   return setDataBuffer(mode, std::static_pointer_cast<DataBuffer>(dataset));
681 }
682
683 int NeuralNetwork::allocate(ExecutionMode mode) {
684   model_graph.deallocateTensors();
685   model_graph.allocateTensors(mode);
686
687   return ML_ERROR_NONE;
688 }
689
690 int NeuralNetwork::deallocate() {
691   model_graph.deallocateTensors(true);
692
693   return ML_ERROR_NONE;
694 }
695
696 int NeuralNetwork::train(const std::vector<std::string> &values,
697                          std::function<bool(void *)> stop_cb) {
698   int status = ML_ERROR_NONE;
699
700   if (data_buffers[static_cast<int>(DatasetModeType::MODE_TRAIN)] == nullptr) {
701     ml_loge("Cannot initialize the model without the train data buffer.");
702     return ML_ERROR_INVALID_PARAMETER;
703   }
704
705   if (!opt) {
706     ml_loge("Cannot train network without optimizer.");
707     return ML_ERROR_INVALID_PARAMETER;
708   }
709
710   setTrainConfig(values);
711
712   /** set batch size just before training */
713   model_graph.setBatchSize(
714     std::get<props::TrainingBatchSize>(model_flex_props));
715
716   status = allocate(ExecutionMode::TRAIN);
717   NN_RETURN_STATUS();
718
719   status = train_run(stop_cb);
720   NN_RETURN_STATUS();
721
722   /**
723    * Free the memory needed for training before exiting.
724    * Note that this does not free the weights for the model.
725    * Weights of the model will be freed when the model is destroyed.
726    */
727   model_graph.deallocateTensors(false);
728   return status;
729 }
730
731 /**
732  * @brief     Run NeuralNetwork train with callback function by user
733  */
734 int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
735   int status = ML_ERROR_NONE;
736
737   if (!std::get<props::ContinueTrain>(model_flex_props)) {
738     epoch_idx = 0;
739     iter = 0;
740     for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
741       (*iter)->clearOptVar();
742     }
743   }
744
745   auto batch_size = std::get<props::TrainingBatchSize>(model_flex_props);
746
747   auto const &outputs = model_graph.getOutputTensors();
748   auto in_dims = model_graph.getInputDimension();
749   auto label_dims = model_graph.getOutputDimension();
750
751   auto &[train_buffer, valid_buffer, test_buffer] = data_buffers;
752
753   if (train_buffer == nullptr) {
754     ml_loge("[NeuralNetworks] there is no train dataset!");
755     return ML_ERROR_INVALID_PARAMETER;
756   }
757
758   /**
759    * @brief run a single epoch with given callback, @a auto is used instead of
760    * std::function for performance measure
761    * @param buffer buffer to run
762    * @param shuffle whether to shuffle or not
763    * @param on_iteration_fetch function that will recieve reference to stat,
764    * buffer which will be called every time data is fetched and set
765    * @param on_epoch_end function that will recieve reference to stat,
766    * buffer which will be called on the epoch end
767    */
768   auto run_epoch = [this, &in_dims, &label_dims, &outputs, batch_size](
769                      DataBuffer *buffer, bool shuffle,
770                      auto &&on_iteration_fetch, auto &&on_iteration_update_stat,
771                      auto &&on_epoch_end) {
772     /// @todo managing metrics must be handled here as well!! for now it is
773     /// handled in individual callbacks
774     RunStats stat;
775     std::future<std::shared_ptr<IterationQueue>> future_iq =
776       buffer->startFetchWorker(in_dims, label_dims, shuffle);
777     while (true) {
778       ScopedView<Iteration> iter_view = buffer->fetch();
779       if (iter_view.isEmpty()) {
780         break;
781       }
782       auto &iteration = iter_view.get();
783       if (iteration.batch() != batch_size) {
784         /// @todo support partial batch
785         continue;
786       }
787
788       auto const &labels = iteration.getLabelsRef();
789       auto const &inputs = iteration.getInputsRef();
790       model_graph.setInputsLabels(inputs, labels);
791
792       on_iteration_fetch(stat, *buffer);
793       on_iteration_update_stat(stat, outputs, labels);
794     }
795     future_iq.get();
796     on_epoch_end(stat, *buffer);
797
798     if (stat.num_iterations == 0) {
799       throw std::runtime_error("No data came while buffer ran");
800     }
801
802     return stat;
803   };
804
805   auto train_for_iteration = [this, stop_cb](RunStats &stat,
806                                              DataBuffer &buffer) {
807     forwarding(true, stop_cb);
808     backwarding(iter++, stop_cb);
809
810     // To avoid unconsidered memory leak, we need to clear the cache
811     model_graph.flushCache();
812
813     if (!stop_cb(nullptr)) {
814       std::cout << "#" << epoch_idx << "/" << getEpochs();
815       ml_logi("# %d / %d", epoch_idx, getEpochs());
816       auto loss = getLoss();
817       buffer.displayProgress(stat.num_iterations, loss);
818     }
819   };
820
821   auto update_train_stat = [this](RunStats &stat,
822                                   const std::vector<Tensor> &outputs,
823                                   const std::vector<Tensor> &labels) {
824     stat.loss += getLoss();
825     stat.num_iterations++;
826   };
827
828   auto train_epoch_end = [this, stop_cb](RunStats &stat, DataBuffer &buffer) {
829     stat.loss /= static_cast<float>(stat.num_iterations);
830     auto &save_path = std::get<props::SavePath>(model_flex_props);
831     if (!stop_cb(nullptr)) {
832       if (!save_path.empty()) {
833         save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
834       }
835
836       std::cout << "#" << epoch_idx << "/" << getEpochs()
837                 << " - Training Loss: " << stat.loss;
838       ml_logi("# %d / %d - Training Loss: %f", epoch_idx, getEpochs(),
839               stat.loss);
840       ml_logd("[NNTrainer] Training epoch %d / %d finished successfully.",
841               epoch_idx, getEpochs());
842     } else {
843       ml_logd("[NNTrainer] Training stopped by stop callback function during "
844               "epoch %d.",
845               epoch_idx);
846     }
847   };
848
849   auto eval_for_iteration = [this, batch_size](RunStats &stat,
850                                                DataBuffer &buffer) {
851     forwarding(false);
852   };
853
854   auto update_eval_stat = [batch_size, &update_train_stat](
855                             RunStats &stat, const std::vector<Tensor> &outputs,
856                             const std::vector<Tensor> &labels) {
857     auto model_out = outputs[0].argmax();
858     auto label_out = labels[0].argmax();
859
860     for (unsigned int b = 0; b < batch_size; b++) {
861       if (model_out[b] == label_out[b])
862         stat.num_correct_predictions++;
863     }
864
865     update_train_stat(stat, outputs, labels);
866   };
867
868   auto eval_epoch_end = [this, batch_size, max_acc = 0.0f,
869                          min_loss = std::numeric_limits<float>::max()](
870                           RunStats &stat, DataBuffer &buffer) mutable {
871     stat.loss /= static_cast<float>(stat.num_iterations);
872     stat.accuracy = stat.num_correct_predictions /
873                     static_cast<float>(stat.num_iterations * batch_size) *
874                     100.0f;
875
876     if (stat.accuracy > max_acc ||
877         (stat.accuracy == max_acc && stat.loss < min_loss)) {
878       max_acc = stat.accuracy;
879       /// @note this is not actually 'the' min loss for whole time but records
880       /// when data change
881       min_loss = stat.loss;
882       auto &save_best_path = std::get<props::SaveBestPath>(model_flex_props);
883       if (!save_best_path.empty()) {
884         save(save_best_path);
885       }
886     }
887     std::cout << " >> [ Accuracy: " << stat.accuracy
888               << "% - Validation Loss : " << stat.loss << " ]";
889     ml_logi("[ Accuracy: %.2f %% - Validataion Loss: %.5f", stat.accuracy,
890             stat.loss);
891   };
892
893   PROFILE_MEM_ANNOTATE("TRAIN START");
894   auto epochs = getEpochs();
895   ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
896           epoch_idx + 1, getEpochs());
897   for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
898     if (stop_cb(nullptr)) {
899       --epoch_idx;
900       break;
901     }
902     training = run_epoch(train_buffer.get(), true, train_for_iteration,
903                          update_train_stat, train_epoch_end);
904     if (valid_buffer) {
905       validation = run_epoch(valid_buffer.get(), false, eval_for_iteration,
906                              update_eval_stat, eval_epoch_end);
907     }
908     std::cout << '\n';
909   }
910   PROFILE_MEM_ANNOTATE("TRAIN END");
911
912   if (test_buffer) {
913     std::cout << "Evaluation with test data...\n";
914     testing = run_epoch(test_buffer.get(), false, eval_for_iteration,
915                         update_eval_stat, eval_epoch_end);
916   }
917
918   /** Clear the set inputs and labels */
919   model_graph.setInputsLabels({}, {});
920
921   return status;
922 }
923
924 void swap(NeuralNetwork &lhs, NeuralNetwork &rhs) {
925   {
926     using std::swap;
927
928     swap(lhs.model_props, rhs.model_props);
929     swap(lhs.model_flex_props, rhs.model_flex_props);
930     swap(lhs.load_path, rhs.load_path);
931     swap(lhs.epoch_idx, rhs.epoch_idx);
932     swap(lhs.iter, rhs.iter);
933     swap(lhs.loss, rhs.loss);
934     swap(lhs.opt, rhs.opt);
935     swap(lhs.data_buffers, rhs.data_buffers);
936     swap(lhs.initialized, rhs.initialized);
937     swap(lhs.model_graph, rhs.model_graph);
938     swap(lhs.graph_representation, rhs.graph_representation);
939     swap(lhs.compiled, rhs.compiled);
940     swap(lhs.loadedFromConfig, rhs.loadedFromConfig);
941   }
942 }
943
944 int NeuralNetwork::addLayer(NodeType layer) {
945   int status = ML_ERROR_NONE;
946
947   if (initialized) {
948     return ML_ERROR_NOT_SUPPORTED;
949   }
950
951   /** Insert the layer to the graph */
952   model_graph.addLayer(layer);
953   graph_representation.push_back(layer);
954
955   return status;
956 }
957
958 NeuralNetwork &NeuralNetwork::copyConfiguration(NeuralNetwork &from) {
959   if (this != &from) {
960     model_props = from.model_props;
961     model_flex_props = from.model_flex_props;
962     loss = from.loss;
963     opt = from.opt;
964
965     NetworkGraph f_graph = from.getNetworkGraph();
966     for (auto &l_node : f_graph.getLayerNodes()) {
967       addLayer(static_cast<std::shared_ptr<ml::train::Layer>>(
968         l_node->cloneConfiguration()));
969     }
970   }
971   return *this;
972 }
973
974 NeuralNetwork::GraphType
975 NeuralNetwork::getUnsortedLayers(const std::string &input_layer,
976                                  const std::string &output_layer) {
977   return model_graph.getUnsortedLayers(input_layer, output_layer);
978 }
979
980 int NeuralNetwork::setOptimizer(
981   std::shared_ptr<ml::train::Optimizer> optimizer) {
982   if (initialized) {
983     return ML_ERROR_NOT_SUPPORTED;
984   }
985
986   opt = std::static_pointer_cast<OptimizerWrapped>(optimizer);
987
988   return ML_ERROR_NONE;
989 }
990
991 int NeuralNetwork::setDataBuffer(const DatasetModeType &mode,
992                                  std::shared_ptr<DataBuffer> data_buffer) {
993   if (data_buffer == nullptr) {
994     return ML_ERROR_INVALID_PARAMETER;
995   }
996
997   this->data_buffers[static_cast<int>(mode)] = data_buffer;
998
999   return ML_ERROR_NONE;
1000 }
1001
1002 int NeuralNetwork::getLayer(const char *name,
1003                             std::shared_ptr<ml::train::Layer> *layer) {
1004   // We provide the layer change through the api with user's responsibility.
1005   //
1006   // if (compiled) {
1007   //   ml_loge("Cannot get compiled layer.");
1008   //   return ML_ERROR_NOT_SUPPORTED;
1009   // }
1010
1011   *layer = std::static_pointer_cast<ml::train::Layer>(
1012     model_graph.getLayerNode(std::string(name)));
1013   return ML_ERROR_NONE;
1014 }
1015
1016 void NeuralNetwork::printMetrics(std::ostream &out, unsigned int flags) {
1017   switch (flags) {
1018   case ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS:
1019     out << training.loss << std::endl;
1020     break;
1021
1022   case ML_TRAIN_SUMMARY_MODEL_VALID_LOSS:
1023     out << validation.loss << std::endl;
1024     break;
1025
1026   case ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY:
1027     out << validation.accuracy << std::endl;
1028     break;
1029
1030   default:
1031     break;
1032   }
1033 }
1034
1035 void NeuralNetwork::printPreset(std::ostream &out, unsigned int preset) {
1036   /** print neuralnet metrics */
1037   printMetrics(out, preset);
1038   if (preset > ML_TRAIN_SUMMARY_TENSOR)
1039     return;
1040
1041   LayerNode::PrintPreset layer_preset = LayerNode::PrintPreset::PRINT_NONE;
1042
1043   ///@todo match flags with preset
1044   unsigned int flags = PRINT_INST_INFO | PRINT_GRAPH_INFO | PRINT_PROP |
1045                        PRINT_OPTIMIZER | PRINT_METRIC;
1046
1047   switch (preset) {
1048   case ML_TRAIN_SUMMARY_TENSOR:
1049     layer_preset = LayerNode::PrintPreset::PRINT_ALL;
1050     break;
1051   case ML_TRAIN_SUMMARY_LAYER:
1052     layer_preset = initialized ? LayerNode::PrintPreset::PRINT_SUMMARY
1053                                : LayerNode::PrintPreset::PRINT_SUMMARY_META;
1054     break;
1055   case ML_TRAIN_SUMMARY_MODEL:
1056     break;
1057   default:
1058     throw std::invalid_argument("given verbosity is invalid");
1059   }
1060
1061   print(out, flags, layer_preset);
1062 }
1063
1064 void NeuralNetwork::addWithReferenceLayers(
1065   const std::vector<std::shared_ptr<ml::train::Layer>> &reference,
1066   const std::string &scope, const std::vector<std::string> &input_layers,
1067   const std::vector<std::string> &start_layers,
1068   const std::vector<std::string> &end_layers,
1069   ml::train::ReferenceLayersType type,
1070   const std::vector<std::string> &type_properties) {
1071   std::vector<NodeType> casted_reference;
1072   casted_reference.reserve(reference.size());
1073   for (auto &node : reference) {
1074     casted_reference.emplace_back(std::static_pointer_cast<LayerNode>(node));
1075   }
1076
1077   addWithReferenceLayers(casted_reference, scope, input_layers, start_layers,
1078                          end_layers, type, type_properties);
1079 }
1080 void NeuralNetwork::addWithReferenceLayers(
1081   const std::vector<std::shared_ptr<LayerNode>> &reference,
1082   const std::string &scope, const std::vector<std::string> &input_layers,
1083   const std::vector<std::string> &start_layers,
1084   const std::vector<std::string> &end_layers,
1085   ml::train::ReferenceLayersType type,
1086   const std::vector<std::string> &type_properties) {
1087   /// @todo below configuration should be extracted as a free function to make
1088   /// it more testable, and reused inside graph interpreter
1089
1090   /// @note we can exploit connection to connection more fine grained, for now
1091   /// it is not supported but we can easily make this supported
1092   std::vector<std::shared_ptr<LayerNode>> nodes;
1093   nodes.reserve(reference.size());
1094   for (auto &node : reference) {
1095     nodes.push_back(node->cloneConfiguration());
1096   }
1097
1098   auto start_conns =
1099     std::vector<Connection>(start_layers.begin(), start_layers.end());
1100   auto input_conns =
1101     std::vector<Connection>(input_layers.begin(), input_layers.end());
1102   auto end_conns =
1103     std::vector<Connection>(end_layers.begin(), end_layers.end());
1104
1105   std::vector<std::unique_ptr<GraphRealizer>> realizers;
1106
1107   realizers.emplace_back(new PreviousInputRealizer(start_conns));
1108   realizers.emplace_back(new SliceRealizer(start_conns, end_conns));
1109
1110   if (!input_conns.empty()) {
1111     realizers.emplace_back(new InputRealizer(start_conns, input_conns));
1112   }
1113
1114   if (type == ml::train::ReferenceLayersType::RECURRENT) {
1115     realizers.emplace_back(
1116       new RecurrentRealizer(type_properties, input_conns, end_conns));
1117   }
1118
1119   if (!scope.empty()) {
1120     realizers.emplace_back(
1121       new RemapRealizer([&scope, &input_conns](std::string &name) {
1122         for (auto &i : input_conns) {
1123           if (i.getName() == name) {
1124             return;
1125           }
1126         }
1127         name = scope + "/" + name;
1128       }));
1129   }
1130
1131   for (auto &realizer : realizers) {
1132     nodes = realizer->realize(nodes);
1133   }
1134
1135   for (auto &node : nodes) {
1136     addLayer(node);
1137   }
1138 }
1139
1140 void NeuralNetwork::exportTo(Exporter &exporter,
1141                              const ml::train::ExportMethods &method) const {
1142   exporter.saveResult(model_props, method, this);
1143   exporter.saveResult(model_flex_props, method, this);
1144 }
1145
1146 void NeuralNetwork::print(std::ostream &out, unsigned int flags,
1147                           LayerNode::PrintPreset layerPrintPreset) {
1148   if (flags & PRINT_INST_INFO) {
1149     /// @todo uncomment this after implement getProperty (#1875)
1150     // out << "===================";
1151     // printInstance(out, this);
1152   }
1153
1154   if (flags & PRINT_GRAPH_INFO) {
1155     unsigned int total_col_size = 80;
1156     std::vector<unsigned int> column_size = {20, 20, 20, 20};
1157     auto print_graph_layer_info =
1158       [column_size](std::ostream &out, std::vector<std::string> layer_info) {
1159         auto trim_string = [](std::string str, unsigned int column_width) {
1160           return str.size() < column_width ? str
1161                                            : str.substr(0, column_width - 1);
1162         };
1163
1164         for (unsigned int i = 0; i < column_size.size(); ++i) {
1165           out << std::setw(column_size[i])
1166               << trim_string(layer_info[i], column_size[i]);
1167         }
1168         out << "\n";
1169       };
1170
1171     out << std::string(total_col_size, '=') << '\n';
1172     print_graph_layer_info(
1173       out, {"Layer name", "Layer type", "Input dimension", "Input layer"});
1174     out << std::string(total_col_size, '=') << '\n';
1175     if (compiled) {
1176       props::GenericShape dim_property;
1177
1178       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
1179            iter++) {
1180         std::string first_dim;
1181         if (iter->getInputDimensions().empty()) {
1182           first_dim = "";
1183         } else {
1184           dim_property.set(iter->getInputDimensions()[0]);
1185           first_dim = to_string(dim_property);
1186         }
1187         const std::vector<std::string> &input_layer_names =
1188           iter->getInputConnections();
1189         std::string first_input_name =
1190           input_layer_names.empty() ? "" : input_layer_names[0];
1191         print_graph_layer_info(
1192           out, {iter->getName(), iter->getType(), first_dim, first_input_name});
1193         for (unsigned int i = 1; i < input_layer_names.size(); ++i) {
1194           dim_property.set(iter->getInputDimensions()[i]);
1195           print_graph_layer_info(
1196             out, {"", "", to_string(dim_property), input_layer_names[i]});
1197         }
1198         out << std::string(total_col_size,
1199                            iter == model_graph.cend() - 1 ? '=' : '-')
1200             << '\n';
1201       }
1202     } else {
1203       auto &input_connection =
1204         std::get<std::vector<props::InputConnection>>(model_props);
1205       auto model_input = std::vector<Connection>(input_connection.begin(),
1206                                                  input_connection.end());
1207       auto is_actually_an_input_node =
1208         [model_input](graph_const_iterator<LayerNode> node) {
1209           return node->hasInputShapeProperty() or
1210                  std::any_of(model_input.begin(), model_input.end(),
1211                              [node](auto &conn) {
1212                                return node->getName() == conn.getName();
1213                              });
1214         };
1215
1216       for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
1217            iter++) {
1218         const std::vector<std::string> &input_layer_names =
1219           iter->getInputConnections();
1220
1221         /// @brief connection information.
1222         // Intended comment.
1223         // std::string first_input_name =
1224         //   input_layer_names.empty()
1225         //     ? (is_actually_an_input_node(iter) || iter ==
1226         //     model_graph.cbegin()
1227         //          ? ""
1228         //          : (iter - 1)->getName())
1229         //     : input_layer_names[0];
1230         print_graph_layer_info(out, {iter->getName(), iter->getType(), "", ""});
1231         for (unsigned int i = 1; i < input_layer_names.size(); ++i) {
1232           print_graph_layer_info(out, {"", "", "", ""});
1233         }
1234         out << std::string(total_col_size,
1235                            iter == model_graph.cend() - 1 ? '=' : '-')
1236             << '\n';
1237       }
1238     }
1239   }
1240
1241   if (flags & PRINT_PROP) {
1242     /// @todo print neuralnet property
1243     /// @todo print mode (if it is eval or training)
1244   }
1245
1246   if (flags & PRINT_OPTIMIZER) {
1247     /// @todo print optimizer (with print optimizer prop)
1248   }
1249
1250   if (flags & PRINT_METRIC) {
1251     /// @todo print metric (currently it is done at printPreset as a
1252     /// workaround)
1253     /// @todo print loss function when it is not initialized. (if it is
1254     /// initialized, loss layer will be printed)
1255   }
1256
1257   if (model_graph.empty()) {
1258     out << "model is empty!" << std::endl;
1259     return;
1260   }
1261
1262   /** print layer properties */
1263   for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++)
1264     (*iter)->printPreset(out, layerPrintPreset);
1265
1266   /// @todo Add status to check neuralnet has been run. #290
1267 }
1268
1269 void NeuralNetwork::forEachLayer(
1270   std::function<void(ml::train::Layer &, RunLayerContext &, void *)> fn,
1271   void *user_data) {
1272   for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
1273     auto ln = std::static_pointer_cast<LayerNode>(*iter).get();
1274     fn(*ln, std::forward<RunLayerContext &>(ln->getRunContext()), user_data);
1275   };
1276 }
1277
1278 void NeuralNetwork::exports(const ml::train::ExportMethods &method,
1279                             const std::string file_path) {
1280   switch (method) {
1281   case ml::train::ExportMethods::METHOD_TFLITE: {
1282 #ifdef ENABLE_TFLITE_INTERPRETER
1283     nntrainer::TfliteInterpreter interpreter;
1284
1285     /// We will call "serialize" method for the model which is already trained
1286     /// or allocated. So, we need to call deallocateTensors first to make sure
1287     /// `dealloc_weights == false`
1288     model_graph.deallocateTensors();
1289     model_graph.allocateTensors(ExecutionMode::INFERENCE);
1290     interpreter.serialize(graph_representation, file_path);
1291     model_graph.deallocateTensors();
1292 #else
1293     throw std::runtime_error{
1294       "Export methods METHOD_TFLITE is not supported. Please enable tflite "
1295       "interpreter by set ENABLE_TFLITE_INTERPRETER=1"};
1296 #endif
1297     break;
1298   }
1299   case ml::train::ExportMethods::METHOD_FLATBUFFER: {
1300
1301     model_graph.deallocateTensors();
1302     model_graph.allocateTensors(ExecutionMode::TRAIN);
1303     break;
1304   }
1305   default:
1306     throw std::runtime_error{"Unsupported export method"};
1307   }
1308 }
1309 } /* namespace nntrainer */