2 * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 * http://www.apache.org/licenses/LICENSE-2.0
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
16 * @date 04 December 2019
17 * @brief This is Neural Network Class
18 * @see https://github.com/nnstreamer/nntrainer
19 * @author Jijoong Moon <jijoong.moon@samsung.com>
20 * @bug No known bugs except for NYI items
24 #include "layer_context.h"
25 #include "model_common_properties.h"
33 #include <activation_realizer.h>
34 #include <common_properties.h>
35 #include <databuffer.h>
36 #include <flatten_realizer.h>
37 #include <ini_interpreter.h>
38 #include <ini_wrapper.h>
39 #include <input_realizer.h>
40 #include <model_loader.h>
41 #include <multiout_realizer.h>
42 #include <neuralnet.h>
43 #include <nntrainer_error.h>
44 #include <nntrainer_log.h>
45 #include <node_exporter.h>
46 #include <optimizer_context.h>
47 #include <previous_input_realizer.h>
49 #include <recurrent_realizer.h>
50 #include <remap_realizer.h>
51 #include <slice_realizer.h>
52 #include <util_func.h>
54 #ifdef ENABLE_TFLITE_INTERPRETER
55 #include <tflite_interpreter.h>
59 * @brief Internal enum values for nntrainer to summarize model accuracy & loss
61 #define ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS 101
62 #define ML_TRAIN_SUMMARY_MODEL_VALID_LOSS 102
63 #define ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY 103
/**
 * @brief Default constructor.
 * Initializes the model properties (loss type, input/label connections,
 * global-norm gradient clipping) and the flexible training properties to
 * their defaults, then binds this network to the global AppContext.
 */
67 NeuralNetwork::NeuralNetwork() :
68 model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
70 props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
71 props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
72 props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
73 load_path(std::string()),
77 data_buffers({nullptr, nullptr, nullptr}), // train / valid / test, all unset
80 loadedFromConfig(false) {
81 app_context = AppContext(AppContext::Global()); // share the process-wide context
/**
 * @brief Constructor taking an explicit AppContext (e.g. one carrying
 * user-registered custom layers/optimizers) instead of the global one.
 * @param app_context_ context to associate with this network
 */
84 NeuralNetwork::NeuralNetwork(AppContext app_context_) :
85 model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
87 props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
88 props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
89 props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
90 load_path(std::string()),
94 data_buffers({nullptr, nullptr, nullptr}),
97 loadedFromConfig(false),
98 app_context(app_context_) {}
/**
 * @brief Load the model description from an INI configuration file.
 *
 * Loads into a temporary copy of this network first so that *this stays
 * untouched if either the context load or the config parse fails; only on
 * full success is the temporary swapped into *this (commit-or-rollback).
 *
 * @param config path to the configuration file
 * @return ML_ERROR_NONE on success, ML_ERROR_INVALID_PARAMETER if already
 *         loaded, or the loader's error code on failure
 */
100 int NeuralNetwork::loadFromConfig(const std::string &config) {
101 if (loadedFromConfig == true) {
102 ml_loge("cannot do loadFromConfig twice");
103 return ML_ERROR_INVALID_PARAMETER;
106 ModelLoader loader(app_context);
107 NeuralNetwork tempNet(*this);
109 int status = loader.loadFromContext(tempNet);
110 if (status != ML_ERROR_NONE) {
114 status = loader.loadFromConfig(config, tempNet);
115 if (status != ML_ERROR_NONE) {
119 tempNet.loadedFromConfig = true;
120 swap(tempNet, *this); // commit: adopt the fully-loaded network
122 return ML_ERROR_NONE;
/// @brief Return the epoch index the training loop is currently at
/// (the value is also emitted to the debug log).
125 unsigned int NeuralNetwork::getCurrentEpoch() {
127 ml_logd("[NNTrainer] Current epoch: %d", epoch_idx);
/// @brief Apply key=value properties; model-level keys are consumed here and
/// whatever remains is forwarded to setTrainConfig() as training options.
132 void NeuralNetwork::setProperty(const std::vector<std::string> &values) {
133 auto left_props = loadProperties(values, model_props);
134 setTrainConfig(left_props);
/// @brief Apply training-time (flexible) properties.
/// @throws std::invalid_argument if any key is not a recognized property.
137 void NeuralNetwork::setTrainConfig(const std::vector<std::string> &values) {
138 auto left_props = loadProperties(values, model_flex_props);
139 NNTR_THROW_IF(left_props.size(), std::invalid_argument)
140 << "Model has unparsed properties, size: " << left_props.size()
141 << " of first element: " << left_props.front();
/**
 * @brief Compile the model: run the graph realizers over the raw layer
 * representation, construct the NetworkGraph (with memory-swap settings),
 * and hand every realized node to it.
 * @return status of model_graph.compile(loss_type)
 */
144 int NeuralNetwork::compile() {
145 std::string loss_type = std::get<props::LossType>(model_props).empty()
147 : std::get<props::LossType>(model_props);
149 auto &input_conn = std::get<std::vector<props::InputConnection>>(model_props);
150 /// @note label layer might need to be treated in the similar way as well
152 /// @todo make NetworkGraph compiled at the construction instead of having
153 /// graph.compile(), neuralnetwork have ownership of list of layer nodes,
154 /// which will be passed at compile time.
156 std::vector<std::unique_ptr<GraphRealizer>> realizers;
// Realizer order matters: resolve implicit inputs first, then multi-out
// fan-out, then flatten/activation expansion.
158 realizers.emplace_back(new PreviousInputRealizer(
159 std::vector<Connection>(input_conn.begin(), input_conn.end())));
160 realizers.emplace_back(new MultioutRealizer());
161 realizers.emplace_back(new FlattenRealizer());
162 realizers.emplace_back(new ActivationRealizer());
164 for (auto &realizer : realizers) {
165 graph_representation = realizer->realize(graph_representation);
168 bool memory_swap = std::get<props::MemorySwap>(model_flex_props);
169 const std::string memory_swap_path =
170 std::get<props::MemorySwapPath>(model_flex_props);
171 unsigned int lookahead =
172 std::get<props::MemorySwapLookahead>(model_flex_props);
173 model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead);
175 model_graph.setMemoryOptimizations(
176 std::get<props::MemoryOptimization>(model_flex_props));
177 for (auto &node : graph_representation) {
// Propagate the model-wide clip-by-global-norm setting to each node.
178 if (auto &prop = std::get<props::ClipGradByGlobalNorm>(model_props);
180 node->setProperty({"clip_grad_by_norm=" + to_string(prop)});
182 model_graph.addLayer(node);
185 int status = model_graph.compile(loss_type);
/**
 * @brief Initialize the compiled model: finalize the graph with the
 * input/label connections, set the batch size, request optimizer variables,
 * allocate weights, and (if a checkpoint path was recorded) load weights.
 * @return ML_ERROR_NONE on success, ML_ERROR_NOT_SUPPORTED if called twice
 *         or before compile()
 */
193 int NeuralNetwork::initialize() {
194 int status = ML_ERROR_NONE;
197 ml_loge("Error: Initializing the model again");
198 return ML_ERROR_NOT_SUPPORTED;
202 ml_loge("Error: Need to compile first");
203 return ML_ERROR_NOT_SUPPORTED;
206 unsigned int n_layers = (unsigned int)model_graph.size();
208 ml_logd("initializing neural network, layer size: %d", n_layers);
209 PROFILE_MEM_ANNOTATE("Initialize");
211 auto &input_conn_prop =
212 std::get<std::vector<props::InputConnection>>(model_props);
213 auto &label_layer_prop =
214 std::get<std::vector<props::LabelLayer>>(model_props);
216 std::vector<Connection> input_conn(input_conn_prop.begin(),
217 input_conn_prop.end());
218 std::vector<std::string> label_layers;
220 if (!label_layer_prop.empty()) {
221 label_layers = std::vector<std::string>(label_layer_prop.begin(),
222 label_layer_prop.end());
225 status = model_graph.initialize(
227 std::vector<Connection>(label_layers.begin(), label_layers.end()));
230 model_graph.setBatchSize(
231 std::get<props::TrainingBatchSize>(model_flex_props));
233 // initialize optimizer and related variables
234 /// @todo: initialize should take a mode and check if mode is train but
235 /// optimizer is not given, make it as a hard error
237 /** TODO: update request of optimizer to be of same format as
238 * Layer::requestTensor */
240 std::function<std::vector<TensorDim>(const TensorDim &)> cb =
241 [this](const TensorDim &dim) {
242 return opt->getOptimizerVariableDim(dim);
244 model_graph.requestOptimizerVariable(cb, true);
248 model_graph.allocateWeights();
// A non-empty load_path (recorded e.g. by load(INI_WITH_BIN)) means a
// weight file is pending; read it now that weights are allocated.
252 if (!load_path.empty()) {
253 load(load_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
/**
 * @brief Destructor: releases the tensor memory still held via deallocate().
 */
262 NeuralNetwork::~NeuralNetwork() {
  deallocate();
}
265 * @brief forward propagation using layers object which has layer
268 NeuralNetwork::forwarding(bool training,
269 std::function<bool(void *userdata)> stop_cb) {
// Per-node operation handed to the graph: flush cached tensors not needed
// at this node's forward execution step, then run the node's forward pass.
270 std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
271 [this, stop_cb](std::shared_ptr<LayerNode> node, bool training) -> void {
273 PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName());
275 auto f = std::get<0>(node->getExecutionOrder()); // forward-order index
276 model_graph.flushCacheExcept(f);
278 node->forwarding(training);
281 return model_graph.forwarding(training, forwarding_op, stop_cb);
285 * @brief forward propagation using layers object which has layer
287 sharedConstTensors NeuralNetwork::forwarding(sharedConstTensors input,
288 sharedConstTensors label,
// The provided data batch must match the batch size the graph was set to.
290 auto current_batch = model_graph.getBatchSize();
291 NNTR_THROW_IF(input[0]->batch() != current_batch ||
292 (!label.empty() && label[0]->batch() != current_batch),
294 << "Error: mismatch in batchsize for data and model."
295 << " input_batch: " << input[0]->batch()
296 << " label_batch: " << label[0]->batch()
297 << " target_batch: " << current_batch;
299 model_graph.setInputsLabels(input, label);
301 return forwarding(training);
305 * @brief back propagation
306 * Call backwarding function of layer in reverse order
307 * No need to call at first Input Layer (No data to be updated)
309 void NeuralNetwork::backwarding(int iteration,
310 std::function<bool(void *userdata)> stop_cb) {
313 NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!";
// Per-node operation: gradient calc -> derivative calc -> gradient apply,
// flushing cached tensors before each phase via the node's execution order.
316 std::function<void(std::shared_ptr<LayerNode>, int)> backwarding_op =
317 [this, stop_cb](std::shared_ptr<LayerNode> node, int iteration) -> void {
319 * Do not change this order:
323 * 4. gradientClippingOnLastAccess
326 model_graph.flushCacheExcept(std::get<1>(node->getExecutionOrder()));
327 PROFILE_MEM_ANNOTATE("CalcGradient: " + node->getName());
329 bool apply_gradient = true;
331 /** If gradient optimization mode, then calculate gradient first */
332 if (dynamic_training_opt.isGradientMode())
333 node->calcGradient();
336 * If optimization off, or gradient must be applied, then this will be
338 * @todo This apply gradient should be passed to the each weight and later
339 * be queried when updating gradient at once. (after moving apply_gradient
340 * out of this function)
343 // auto &layer = node->getObject();
344 // apply_gradient = dynamic_training_opt.checkIfApply(
345 // layer->getWeightsRef(), layer->net_input[0], layer->net_hidden[0],
348 /** If gradient must be applied and its not gradient mode, calculate
351 if (!dynamic_training_opt.isGradientMode() && apply_gradient)
352 node->calcGradient();
354 model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder()));
355 PROFILE_MEM_ANNOTATE("CalcDerivative: " + node->getName());
357 if (stop_cb(nullptr)) {
361 if (node->needsCalcDerivative())
362 node->calcDerivative();
364 model_graph.flushCacheExcept(std::get<3>(node->getExecutionOrder()));
365 PROFILE_MEM_ANNOTATE("ApplyGradient: " + node->getName());
367 if (apply_gradient) {
368 /// Apply gradient only at the end of the last shared weight access
369 model_graph.applyGradients(
370 node.get(), [iteration, opt_ = opt.get()](Weight &w) {
371 w.calcRegularizationGradient();
372 w.calcWeightDecayGradient();
373 RunOptimizerContext opt_context(&w, iteration,
374 opt_->getLearningRate(iteration));
375 opt_->applyGradient(opt_context);
// Separate apply op for weights deferred by global-norm gradient clipping:
// same regularization/decay/apply sequence, run on last access.
380 std::function<void(Weight &, int)> apply_grad_clip_op =
381 [opt_ = opt.get()](Weight &w, int iteration) -> void {
382 w.calcRegularizationGradient();
383 w.calcWeightDecayGradient();
384 RunOptimizerContext opt_context(&w, iteration,
385 opt_->getLearningRate(iteration));
386 opt_->applyGradient(opt_context);
389 model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op,
/**
 * @brief Persist the model in the requested format.
 * BIN writes each layer's weights (plus adam optimizer variables when in
 * use) and the epoch/iteration counters; INI writes the configuration;
 * INI_WITH_BIN writes both, deriving the .bin name from file_path.
 * @throws std::runtime_error if the model is not initialized
 */
393 void NeuralNetwork::save(const std::string &file_path,
394 ml::train::ModelFormat format) {
395 NNTR_THROW_IF(!initialized, std::runtime_error)
396 << "Cannot save model if not initialized yet, path: " << file_path
397 << " format: " << static_cast<unsigned>(format);
399 /// @todo this switch case should be delegating the function call only. It's
400 /// not delegating for now as required logics are manageable for now.
402 case ml::train::ModelFormat::MODEL_FORMAT_BIN: {
403 auto model_file = checkedOpenStream<std::ofstream>(
404 file_path, std::ios::out | std::ios::binary | std::ios::trunc);
405 for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
406 (*iter)->save(model_file);
// "adam" magic marker signals that optimizer variables follow the weights.
408 if (opt && istrequal(opt->getType(), "adam")) {
409 std::string adam = "adam";
410 model_file.write(adam.c_str(), 4);
411 for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
413 (*iter)->save(model_file, true);
417 model_file.write((char *)&epoch_idx, sizeof(epoch_idx));
418 model_file.write((char *)&iter, sizeof(iter));
423 case ml::train::ModelFormat::MODEL_FORMAT_INI:
424 saveModelIni(file_path);
427 case ml::train::ModelFormat::MODEL_FORMAT_INI_WITH_BIN: {
// Temporarily point SavePath at the derived .bin so the INI references it.
428 auto old_save_path = std::get<props::SavePath>(model_flex_props);
430 file_path.substr(0, file_path.find_last_of('.')) + ".bin";
432 std::get<props::SavePath>(model_flex_props).set(bin_file_name);
433 save(file_path, ml::train::ModelFormat::MODEL_FORMAT_INI);
434 save(bin_file_name, ml::train::ModelFormat::MODEL_FORMAT_BIN);
435 std::get<props::SavePath>(model_flex_props) = old_save_path;
439 throw nntrainer::exception::not_supported(
440 "saving with given format is not supported yet");
/**
 * @brief Load model state in the requested format.
 * BIN reads layer weights (and, best-effort, adam optimizer variables plus
 * epoch/iteration counters); INI / INI_WITH_BIN parse the configuration,
 * the latter also recording the referenced .bin in load_path for later
 * reading during initialize().
 */
444 void NeuralNetwork::load(const std::string &file_path,
445 ml::train::ModelFormat format) {
446 /// @todo this switch case should be delegating the function call only. It's
447 /// not delegating for now as required logics are manageable for now.
449 case ml::train::ModelFormat::MODEL_FORMAT_BIN: {
450 NNTR_THROW_IF(!initialized, std::runtime_error)
451 << "Cannot load if not initialized yet, path: " << file_path
452 << " format: " << static_cast<unsigned>(format);
454 auto model_file = checkedOpenStream<std::ifstream>(
455 file_path, std::ios::in | std::ios::binary);
456 for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
457 (*iter)->read(model_file);
460 /// this is assuming that the failure is allowed at the end of the file
461 /// read. so, after this line, additional read shouldn't be called
462 if (opt && istrequal(opt->getType(), "adam")) {
463 std::string opt_type;
// Look for the 4-byte "adam" marker written by save(); only then read
// the optimizer variables back.
465 model_file.read((char *)&opt_type[0], 4);
466 if (istrequal(opt_type, "adam")) {
467 for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
469 (*iter)->read(model_file, true);
474 checkedRead(model_file, (char *)&epoch_idx, sizeof(epoch_idx),
475 "[NeuralNetwork::readModel] failed to read epoch_idx");
476 checkedRead(model_file, (char *)&iter, sizeof(iter),
477 "[NeuralNetwork::readModel] failed to read iteration");
479 std::cerr << "failed to read additional data like optimizer variable, "
480 "iteration, proceeding with default\n";
483 ml_logi("read modelfile: %s", file_path.c_str());
486 case ml::train::ModelFormat::MODEL_FORMAT_INI_WITH_BIN: {
487 int ret = loadFromConfig(file_path);
489 auto &save_path = std::get<props::SavePath>(model_flex_props);
490 if (!save_path.empty()) {
// Probe the weight file now (throws early on bad path), but defer the
// actual read until initialize() by stashing it in load_path.
491 checkedOpenStream<std::ifstream>(save_path,
492 std::ios::in | std::ios::binary);
493 load_path = save_path;
497 case ml::train::ModelFormat::MODEL_FORMAT_INI: {
498 int ret = loadFromConfig(file_path);
502 case ml::train::ModelFormat::MODEL_FORMAT_FLATBUFFER: {
506 throw nntrainer::exception::not_supported(
507 "loading with given format is not supported yet");
/// @brief Accumulate the loss contribution reported by every node in the
/// graph (loss layers report non-zero values).
511 float NeuralNetwork::getLoss() {
514 for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
515 loss += (*iter)->getLoss();
/**
 * @brief Overwrite the cached loss value with the caller-provided one.
 * @param l new loss value
 */
520 void NeuralNetwork::setLoss(float l) {
  this->loss = l;
}
/// @brief Copy another network's properties and graph into *this.
/// @return reference to *this for chaining
522 NeuralNetwork &NeuralNetwork::copy(NeuralNetwork &from) {
524 model_props = from.model_props;
525 model_flex_props = from.model_flex_props;
529 model_graph.copy(from.model_graph);
/**
 * @brief Serialize the model configuration to a fresh INI file:
 * a [model] section, then optional optimizer and dataset sections, then the
 * layer graph appended by the INI graph interpreter.
 * @param file_path destination path (must not already exist)
 * @throws std::invalid_argument if a file already exists at file_path
 */
534 void NeuralNetwork::saveModelIni(const std::string &file_path) {
535 NNTR_THROW_IF(isFileExist(file_path), std::invalid_argument)
536 << "There is already a file, overriding to the existing file is not "
540 std::vector<IniSection> sections;
542 IniSection model_section = IniSection::FromExportable("model", *this);
543 model_section.setEntry("type", "NeuralNetwork");
544 sections.push_back(model_section);
// Helper: append a section for obj_ptr only when pred(obj_ptr) holds.
546 auto add_section_if_any = [&sections](const std::string &section_name,
547 auto obj_ptr, auto pred) {
549 IniSection s = IniSection::FromExportable(section_name, *obj_ptr);
550 s.setEntry("type", obj_ptr->getType());
551 sections.push_back(s);
555 add_section_if_any("optimizer", opt,
556 [](const auto &obj) { return static_cast<bool>(obj); });
558 auto &[train_buffer, valid_buffer, test_buffer] = data_buffers;
559 auto data_buffer_valid = [](const auto &buffer) {
560 return buffer && buffer->isSerializable(
561 ml::train::ExportMethods::METHOD_STRINGVECTOR);
564 add_section_if_any("train_set", train_buffer, data_buffer_valid);
565 add_section_if_any("valid_set", valid_buffer, data_buffer_valid);
566 add_section_if_any("test_set", test_buffer, data_buffer_valid);
568 IniWrapper wrapper("model_saver", sections);
569 wrapper.save_ini(file_path);
571 IniGraphInterpreter interpreter;
572 interpreter.serialize(graph_representation, file_path);
/**
 * @brief Check that the given tensors match the model's input count and
 * per-input dimensions; mismatches are logged with both shapes.
 * @return true when valid, false otherwise
 */
575 bool NeuralNetwork::validateInput(sharedConstTensors X) {
576 auto input_dim = getInputDimension();
577 if (X.size() != input_dim.size()) {
578 ml_loge("Error: provided number of inputs %d, required %d", (int)X.size(),
579 (int)input_dim.size());
583 for (unsigned int dim = 0; dim < input_dim.size(); dim++) {
584 if (input_dim[dim] != X[dim]->getDim()) {
585 ml_loge("Error: provided input shape does not match required shape");
586 std::stringstream ss;
587 ss << X[dim]->getDim();
588 ml_loge("Provided tensor summary : %s", ss.str().c_str());
590 ss.str(std::string()); // reset the stream to reuse it for the required shape
591 ss << input_dim[dim];
592 ml_loge("Required tensor summary : %s", ss.str().c_str());
/// @brief Inference convenience overload: no labels supplied.
600 sharedConstTensors NeuralNetwork::inference(sharedConstTensors X,
602 return inference(X, {}, free_mem);
/**
 * @brief Run inference on the given inputs (and optional labels), adapting
 * the graph's batch size to the data, allocating inference-mode tensors,
 * and releasing non-weight tensors afterwards.
 * @throws std::invalid_argument if the inputs fail validateInput()
 */
605 sharedConstTensors NeuralNetwork::inference(sharedConstTensors X,
606 sharedConstTensors label,
608 if (model_graph.getBatchSize() != X[0]->batch()) {
609 model_graph.setBatchSize(X[0]->batch());
612 sharedConstTensors out;
613 if (!validateInput(X))
614 throw std::invalid_argument("Input validation failed.");
616 allocate(ExecutionMode::INFERENCE);
619 PROFILE_TIME_REGISTER_EVENT(nn_foward, "nn_forward");
620 PROFILE_TIME_START(nn_foward);
621 out = forwarding(X, label, false); // training = false
622 PROFILE_TIME_END(nn_foward);
626 * Free the memory needed for training before exiting.
627 * Note that this does not free the weights for the model.
628 * Weights of the model will be freed when the model is destroyed.
630 model_graph.deallocateTensors(false);
632 /** Clear the set inputs and labels */
633 model_graph.setInputsLabels({}, {});
/**
 * @brief Raw-pointer inference entry point: wraps the caller's float buffers
 * as tensors via Tensor::Map (no copy — the caller keeps ownership and must
 * keep the buffers alive), runs inference, and returns raw output pointers.
 * @param batch_size batch dimension applied to every input/label dim
 * @param input one buffer per model input
 * @param label one buffer per model output, or empty for unlabeled inference
 */
639 NeuralNetwork::inference(unsigned int batch_size,
640 const std::vector<float *> &input,
641 const std::vector<float *> &label) {
642 sharedConstTensors input_tensors, output_tensors;
643 auto in_dim = getInputDimension();
645 input_tensors.reserve(input.size());
646 for (unsigned int idx = 0; idx < in_dim.size(); idx++) {
647 in_dim[idx].batch(batch_size);
648 input_tensors.emplace_back(MAKE_SHARED_TENSOR(Tensor::Map(
649 input[idx], in_dim[idx].getDataLen() * sizeof(float), in_dim[idx], 0)));
652 if (!label.empty()) {
653 sharedConstTensors label_tensors;
654 auto label_dim = getOutputDimension();
655 label_tensors.reserve(label.size());
656 for (unsigned int idx = 0; idx < label_dim.size(); idx++) {
657 label_dim[idx].batch(batch_size);
658 label_tensors.emplace_back(MAKE_SHARED_TENSOR(
659 Tensor::Map(label[idx], label_dim[idx].getDataLen() * sizeof(float),
660 label_dim[idx], 0)));
662 output_tensors = inference(input_tensors, label_tensors, false);
664 output_tensors = inference(input_tensors, false);
// Hand back raw data pointers into the output tensors.
667 std::vector<float *> output;
668 output.reserve(output_tensors.size());
670 for (auto &out : output_tensors) {
671 auto out_t = *out.get();
672 output.push_back(out_t.getData());
/// @brief ml::train API adapter: downcast the public Dataset to the internal
/// DataBuffer and register it for the given mode.
678 int NeuralNetwork::setDataset(const DatasetModeType &mode,
679 std::shared_ptr<ml::train::Dataset> dataset) {
680 return setDataBuffer(mode, std::static_pointer_cast<DataBuffer>(dataset));
/// @brief (Re)allocate graph tensors for the given execution mode; any
/// previously allocated non-weight tensors are released first.
683 int NeuralNetwork::allocate(ExecutionMode mode) {
684 model_graph.deallocateTensors();
685 model_graph.allocateTensors(mode);
687 return ML_ERROR_NONE;
/// @brief Free all graph tensors including the weights (the `true` flag).
690 int NeuralNetwork::deallocate() {
691 model_graph.deallocateTensors(true);
693 return ML_ERROR_NONE;
/**
 * @brief Public training entry point: validates that a train buffer and an
 * optimizer exist, applies training properties, allocates train-mode
 * tensors, runs train_run(), and releases non-weight memory afterwards.
 * @param values extra key=value training properties
 * @param stop_cb callback polled to stop training early
 */
696 int NeuralNetwork::train(const std::vector<std::string> &values,
697 std::function<bool(void *)> stop_cb) {
698 int status = ML_ERROR_NONE;
700 if (data_buffers[static_cast<int>(DatasetModeType::MODE_TRAIN)] == nullptr) {
701 ml_loge("Cannot initialize the model without the train data buffer.");
702 return ML_ERROR_INVALID_PARAMETER;
706 ml_loge("Cannot train network without optimizer.");
707 return ML_ERROR_INVALID_PARAMETER;
710 setTrainConfig(values);
712 /** set batch size just before training */
713 model_graph.setBatchSize(
714 std::get<props::TrainingBatchSize>(model_flex_props));
716 status = allocate(ExecutionMode::TRAIN);
719 status = train_run(stop_cb);
723 * Free the memory needed for training before exiting.
724 * Note that this does not free the weights for the model.
725 * Weights of the model will be freed when the model is destroyed.
727 model_graph.deallocateTensors(false);
732 * @brief Run NeuralNetwork train with callback function by user
734 int NeuralNetwork::train_run(std::function<bool(void *userdata)> stop_cb) {
735 int status = ML_ERROR_NONE;
// Unless ContinueTrain is set, start from scratch: drop optimizer state.
737 if (!std::get<props::ContinueTrain>(model_flex_props)) {
740 for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
741 (*iter)->clearOptVar();
745 auto batch_size = std::get<props::TrainingBatchSize>(model_flex_props);
747 auto const &outputs = model_graph.getOutputTensors();
748 auto in_dims = model_graph.getInputDimension();
749 auto label_dims = model_graph.getOutputDimension();
751 auto &[train_buffer, valid_buffer, test_buffer] = data_buffers;
753 if (train_buffer == nullptr) {
754 ml_loge("[NeuralNetworks] there is no train dataset!");
755 return ML_ERROR_INVALID_PARAMETER;
759 * @brief run a single epoch with given callback, @a auto is used instead of
760 * std::function for performance measure
761 * @param buffer buffer to run
762 * @param shuffle whether to shuffle or not
763 * @param on_iteration_fetch function that will receive reference to stat,
764 * buffer which will be called every time data is fetched and set
765 * @param on_epoch_end function that will receive reference to stat,
766 * buffer which will be called on the epoch end
768 auto run_epoch = [this, &in_dims, &label_dims, &outputs, batch_size](
769 DataBuffer *buffer, bool shuffle,
770 auto &&on_iteration_fetch, auto &&on_iteration_update_stat,
771 auto &&on_epoch_end) {
772 /// @todo managing metrics must be handled here as well!! for now it is
773 /// handled in individual callbacks
775 std::future<std::shared_ptr<IterationQueue>> future_iq =
776 buffer->startFetchWorker(in_dims, label_dims, shuffle);
778 ScopedView<Iteration> iter_view = buffer->fetch();
779 if (iter_view.isEmpty()) {
782 auto &iteration = iter_view.get();
783 if (iteration.batch() != batch_size) {
784 /// @todo support partial batch
788 auto const &labels = iteration.getLabelsRef();
789 auto const &inputs = iteration.getInputsRef();
790 model_graph.setInputsLabels(inputs, labels);
792 on_iteration_fetch(stat, *buffer);
793 on_iteration_update_stat(stat, outputs, labels);
796 on_epoch_end(stat, *buffer);
798 if (stat.num_iterations == 0) {
799 throw std::runtime_error("No data came while buffer ran");
// Training iteration: one forward + backward pass, then flush the cache.
805 auto train_for_iteration = [this, stop_cb](RunStats &stat,
806 DataBuffer &buffer) {
807 forwarding(true, stop_cb);
808 backwarding(iter++, stop_cb);
810 // To avoid unconsidered memory leak, we need to clear the cache
811 model_graph.flushCache();
813 if (!stop_cb(nullptr)) {
814 std::cout << "#" << epoch_idx << "/" << getEpochs();
815 ml_logi("# %d / %d", epoch_idx, getEpochs());
816 auto loss = getLoss();
817 buffer.displayProgress(stat.num_iterations, loss);
821 auto update_train_stat = [this](RunStats &stat,
822 const std::vector<Tensor> &outputs,
823 const std::vector<Tensor> &labels) {
824 stat.loss += getLoss();
825 stat.num_iterations++;
// End of a training epoch: average the loss and checkpoint to SavePath.
828 auto train_epoch_end = [this, stop_cb](RunStats &stat, DataBuffer &buffer) {
829 stat.loss /= static_cast<float>(stat.num_iterations);
830 auto &save_path = std::get<props::SavePath>(model_flex_props);
831 if (!stop_cb(nullptr)) {
832 if (!save_path.empty()) {
833 save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
836 std::cout << "#" << epoch_idx << "/" << getEpochs()
837 << " - Training Loss: " << stat.loss;
838 ml_logi("# %d / %d - Training Loss: %f", epoch_idx, getEpochs(),
840 ml_logd("[NNTrainer] Training epoch %d / %d finished successfully.",
841 epoch_idx, getEpochs());
843 ml_logd("[NNTrainer] Training stopped by stop callback function during "
849 auto eval_for_iteration = [this, batch_size](RunStats &stat,
850 DataBuffer &buffer) {
// Evaluation stat: count top-1 matches between model output and labels.
854 auto update_eval_stat = [batch_size, &update_train_stat](
855 RunStats &stat, const std::vector<Tensor> &outputs,
856 const std::vector<Tensor> &labels) {
857 auto model_out = outputs[0].argmax();
858 auto label_out = labels[0].argmax();
860 for (unsigned int b = 0; b < batch_size; b++) {
861 if (model_out[b] == label_out[b])
862 stat.num_correct_predictions++;
865 update_train_stat(stat, outputs, labels);
// End of an eval epoch; max_acc/min_loss persist across epochs (mutable
// init-captures) to decide when to write SaveBestPath.
868 auto eval_epoch_end = [this, batch_size, max_acc = 0.0f,
869 min_loss = std::numeric_limits<float>::max()](
870 RunStats &stat, DataBuffer &buffer) mutable {
871 stat.loss /= static_cast<float>(stat.num_iterations);
872 stat.accuracy = stat.num_correct_predictions /
873 static_cast<float>(stat.num_iterations * batch_size) *
876 if (stat.accuracy > max_acc ||
877 (stat.accuracy == max_acc && stat.loss < min_loss)) {
878 max_acc = stat.accuracy;
879 /// @note this is not actually 'the' min loss for whole time but records
881 min_loss = stat.loss;
882 auto &save_best_path = std::get<props::SaveBestPath>(model_flex_props);
883 if (!save_best_path.empty()) {
884 save(save_best_path);
887 std::cout << " >> [ Accuracy: " << stat.accuracy
888 << "% - Validation Loss : " << stat.loss << " ]";
889 ml_logi("[ Accuracy: %.2f %% - Validataion Loss: %.5f", stat.accuracy,
893 PROFILE_MEM_ANNOTATE("TRAIN START");
894 auto epochs = getEpochs();
895 ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
896 epoch_idx + 1, getEpochs());
897 for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
898 if (stop_cb(nullptr)) {
902 training = run_epoch(train_buffer.get(), true, train_for_iteration,
903 update_train_stat, train_epoch_end);
905 validation = run_epoch(valid_buffer.get(), false, eval_for_iteration,
906 update_eval_stat, eval_epoch_end);
910 PROFILE_MEM_ANNOTATE("TRAIN END");
// Optional final pass over the test set after all epochs complete.
913 std::cout << "Evaluation with test data...\n";
914 testing = run_epoch(test_buffer.get(), false, eval_for_iteration,
915 update_eval_stat, eval_epoch_end);
918 /** Clear the set inputs and labels */
919 model_graph.setInputsLabels({}, {});
/// @brief ADL swap: exchange the complete state of two networks member by
/// member; used by loadFromConfig() to commit a temporary network atomically.
924 void swap(NeuralNetwork &lhs, NeuralNetwork &rhs) {
928 swap(lhs.model_props, rhs.model_props);
929 swap(lhs.model_flex_props, rhs.model_flex_props);
930 swap(lhs.load_path, rhs.load_path);
931 swap(lhs.epoch_idx, rhs.epoch_idx);
932 swap(lhs.iter, rhs.iter);
933 swap(lhs.loss, rhs.loss);
934 swap(lhs.opt, rhs.opt);
935 swap(lhs.data_buffers, rhs.data_buffers);
936 swap(lhs.initialized, rhs.initialized);
937 swap(lhs.model_graph, rhs.model_graph);
938 swap(lhs.graph_representation, rhs.graph_representation);
939 swap(lhs.compiled, rhs.compiled);
940 swap(lhs.loadedFromConfig, rhs.loadedFromConfig);
/// @brief Append a layer node to both the graph and the (pre-realization)
/// representation; rejected once the model can no longer be modified.
944 int NeuralNetwork::addLayer(NodeType layer) {
945 int status = ML_ERROR_NONE;
948 return ML_ERROR_NOT_SUPPORTED;
951 /** Insert the layer to the graph */
952 model_graph.addLayer(layer);
953 graph_representation.push_back(layer);
/// @brief Copy another network's configuration (properties plus cloned layer
/// configurations) without copying any trained state.
958 NeuralNetwork &NeuralNetwork::copyConfiguration(NeuralNetwork &from) {
960 model_props = from.model_props;
961 model_flex_props = from.model_flex_props;
// Clone each layer node's configuration rather than sharing the nodes.
965 NetworkGraph f_graph = from.getNetworkGraph();
966 for (auto &l_node : f_graph.getLayerNodes()) {
967 addLayer(static_cast<std::shared_ptr<ml::train::Layer>>(
968 l_node->cloneConfiguration()));
/// @brief Delegate to the graph: return the layers between input_layer and
/// output_layer without topological sorting.
974 NeuralNetwork::GraphType
975 NeuralNetwork::getUnsortedLayers(const std::string &input_layer,
976 const std::string &output_layer) {
977 return model_graph.getUnsortedLayers(input_layer, output_layer);
/// @brief Register the optimizer (downcast from the public ml::train type to
/// the internal wrapper); rejected when the model can no longer be modified.
980 int NeuralNetwork::setOptimizer(
981 std::shared_ptr<ml::train::Optimizer> optimizer) {
983 return ML_ERROR_NOT_SUPPORTED;
986 opt = std::static_pointer_cast<OptimizerWrapped>(optimizer);
988 return ML_ERROR_NONE;
/// @brief Register a data buffer for the given mode (train/valid/test).
/// @return ML_ERROR_INVALID_PARAMETER when the buffer is null
991 int NeuralNetwork::setDataBuffer(const DatasetModeType &mode,
992 std::shared_ptr<DataBuffer> data_buffer) {
993 if (data_buffer == nullptr) {
994 return ML_ERROR_INVALID_PARAMETER;
997 this->data_buffers[static_cast<int>(mode)] = data_buffer;
999 return ML_ERROR_NONE;
/// @brief Expose the layer node with the given name through the public
/// ml::train::Layer interface. Mutating it is the caller's responsibility.
1002 int NeuralNetwork::getLayer(const char *name,
1003 std::shared_ptr<ml::train::Layer> *layer) {
1004 // We provide the layer change through the api with user's responsibility.
1007 // ml_loge("Cannot get compiled layer.");
1008 // return ML_ERROR_NOT_SUPPORTED;
1011 *layer = std::static_pointer_cast<ml::train::Layer>(
1012 model_graph.getLayerNode(std::string(name)));
1013 return ML_ERROR_NONE;
/// @brief Print a single metric selected by the ML_TRAIN_SUMMARY_MODEL_*
/// flag: training loss, validation loss, or validation accuracy.
1016 void NeuralNetwork::printMetrics(std::ostream &out, unsigned int flags) {
1018 case ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS:
1019 out << training.loss << std::endl;
1022 case ML_TRAIN_SUMMARY_MODEL_VALID_LOSS:
1023 out << validation.loss << std::endl;
1026 case ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY:
1027 out << validation.accuracy << std::endl;
/**
 * @brief Print a model summary at the requested verbosity preset, mapping
 * the ML_TRAIN_SUMMARY_* level to a per-layer print preset.
 * @throws std::invalid_argument for an unknown preset
 */
1035 void NeuralNetwork::printPreset(std::ostream &out, unsigned int preset) {
1036 /** print neuralnet metrics */
1037 printMetrics(out, preset);
1038 if (preset > ML_TRAIN_SUMMARY_TENSOR)
1041 LayerNode::PrintPreset layer_preset = LayerNode::PrintPreset::PRINT_NONE;
1043 ///@todo match flags with preset
1044 unsigned int flags = PRINT_INST_INFO | PRINT_GRAPH_INFO | PRINT_PROP |
1045 PRINT_OPTIMIZER | PRINT_METRIC;
1048 case ML_TRAIN_SUMMARY_TENSOR:
1049 layer_preset = LayerNode::PrintPreset::PRINT_ALL;
1051 case ML_TRAIN_SUMMARY_LAYER:
// Before initialize() only meta information is available per layer.
1052 layer_preset = initialized ? LayerNode::PrintPreset::PRINT_SUMMARY
1053 : LayerNode::PrintPreset::PRINT_SUMMARY_META;
1055 case ML_TRAIN_SUMMARY_MODEL:
1058 throw std::invalid_argument("given verbosity is invalid");
1061 print(out, flags, layer_preset);
/// @brief Public-API adapter: downcast the ml::train::Layer references to
/// internal LayerNode pointers and delegate to the LayerNode overload.
1064 void NeuralNetwork::addWithReferenceLayers(
1065 const std::vector<std::shared_ptr<ml::train::Layer>> &reference,
1066 const std::string &scope, const std::vector<std::string> &input_layers,
1067 const std::vector<std::string> &start_layers,
1068 const std::vector<std::string> &end_layers,
1069 ml::train::ReferenceLayersType type,
1070 const std::vector<std::string> &type_properties) {
1071 std::vector<NodeType> casted_reference;
1072 casted_reference.reserve(reference.size());
1073 for (auto &node : reference) {
1074 casted_reference.emplace_back(std::static_pointer_cast<LayerNode>(node));
1077 addWithReferenceLayers(casted_reference, scope, input_layers, start_layers,
1078 end_layers, type, type_properties);
/**
 * @brief Clone the reference layers, run them through a realizer pipeline
 * (input resolution, slicing, optional recurrent unrolling, scope renaming)
 * and add the resulting nodes to this model.
 * @param scope prefix prepended to cloned layer names (inputs excluded)
 */
1080 void NeuralNetwork::addWithReferenceLayers(
1081 const std::vector<std::shared_ptr<LayerNode>> &reference,
1082 const std::string &scope, const std::vector<std::string> &input_layers,
1083 const std::vector<std::string> &start_layers,
1084 const std::vector<std::string> &end_layers,
1085 ml::train::ReferenceLayersType type,
1086 const std::vector<std::string> &type_properties) {
1087 /// @todo below configuration should be extracted as a free function to make
1088 /// it more testable, and reused inside graph interpreter
1090 /// @note we can exploit connection to connection more fine grained, for now
1091 /// it is not supported but we can easily make this supported
1092 std::vector<std::shared_ptr<LayerNode>> nodes;
1093 nodes.reserve(reference.size());
1094 for (auto &node : reference) {
1095 nodes.push_back(node->cloneConfiguration());
1099 std::vector<Connection>(start_layers.begin(), start_layers.end());
1101 std::vector<Connection>(input_layers.begin(), input_layers.end());
1103 std::vector<Connection>(end_layers.begin(), end_layers.end());
1105 std::vector<std::unique_ptr<GraphRealizer>> realizers;
1107 realizers.emplace_back(new PreviousInputRealizer(start_conns));
1108 realizers.emplace_back(new SliceRealizer(start_conns, end_conns));
1110 if (!input_conns.empty()) {
1111 realizers.emplace_back(new InputRealizer(start_conns, input_conns));
1114 if (type == ml::train::ReferenceLayersType::RECURRENT) {
1115 realizers.emplace_back(
1116 new RecurrentRealizer(type_properties, input_conns, end_conns));
1119 if (!scope.empty()) {
// Prefix every cloned node name with "scope/", but keep the names of the
// designated input connections untouched.
1120 realizers.emplace_back(
1121 new RemapRealizer([&scope, &input_conns](std::string &name) {
1122 for (auto &i : input_conns) {
1123 if (i.getName() == name) {
1127 name = scope + "/" + name;
1131 for (auto &realizer : realizers) {
1132 nodes = realizer->realize(nodes);
1135 for (auto &node : nodes) {
1140 void NeuralNetwork::exportTo(Exporter &exporter,
1141 const ml::train::ExportMethods &method) const {
1142 exporter.saveResult(model_props, method, this);
1143 exporter.saveResult(model_flex_props, method, this);
/**
 * @brief Print model information to @a out, section by section, according to
 *        the bit flags (PRINT_INST_INFO, PRINT_GRAPH_INFO, PRINT_PROP,
 *        PRINT_OPTIMIZER, PRINT_METRIC), then print each layer's preset.
 *
 * NOTE(review): this excerpt is a line-numbered paste with gaps — lambda
 * closings, an else-branch, stream continuations and loop increments are
 * missing. Code lines are kept verbatim; verify against the full file.
 *
 * @param out              stream to print to
 * @param flags            bitwise-or of the PRINT_* flags listed above
 * @param layerPrintPreset preset forwarded to each layer's printPreset()
 */
1146 void NeuralNetwork::print(std::ostream &out, unsigned int flags,
1147 LayerNode::PrintPreset layerPrintPreset) {
1148 if (flags & PRINT_INST_INFO) {
1149 /// @todo uncomment this after implement getProperty (#1875)
1150 // out << "===================";
1151 // printInstance(out, this);
// Tabular dump of the graph: one row per layer (name / type / input dim /
// input layer), 4 columns of 20 chars each within an 80-char table.
1154 if (flags & PRINT_GRAPH_INFO) {
1155 unsigned int total_col_size = 80;
1156 std::vector<unsigned int> column_size = {20, 20, 20, 20};
// Helper that prints one table row, truncating each cell to its column width.
// NOTE(review): trim_string drops the last character of exactly-column-width
// strings (uses column_width - 1); presumably to leave a separator space.
1157 auto print_graph_layer_info =
1158 [column_size](std::ostream &out, std::vector<std::string> layer_info) {
1159 auto trim_string = [](std::string str, unsigned int column_width) {
1160 return str.size() < column_width ? str
1161 : str.substr(0, column_width - 1);
1164 for (unsigned int i = 0; i < column_size.size(); ++i) {
1165 out << std::setw(column_size[i])
1166 << trim_string(layer_info[i], column_size[i]);
// Table header.
1171 out << std::string(total_col_size, '=') << '\n';
1172 print_graph_layer_info(
1173 out, {"Layer name", "Layer type", "Input dimension", "Input layer"});
1174 out << std::string(total_col_size, '=') << '\n';
1176 props::GenericShape dim_property;
// First variant of the per-layer loop: prints real input dimensions and
// input connection names. NOTE(review): the loop increment and the
// else-branch of the empty-dimension check are missing from this excerpt.
1178 for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
1180 std::string first_dim;
1181 if (iter->getInputDimensions().empty()) {
1184 dim_property.set(iter->getInputDimensions()[0]);
1185 first_dim = to_string(dim_property);
1187 const std::vector<std::string> &input_layer_names =
1188 iter->getInputConnections();
1189 std::string first_input_name =
1190 input_layer_names.empty() ? "" : input_layer_names[0];
1191 print_graph_layer_info(
1192 out, {iter->getName(), iter->getType(), first_dim, first_input_name});
// Additional rows for layers with more than one input connection.
1193 for (unsigned int i = 1; i < input_layer_names.size(); ++i) {
1194 dim_property.set(iter->getInputDimensions()[i]);
1195 print_graph_layer_info(
1196 out, {"", "", to_string(dim_property), input_layer_names[i]});
// Row separator: '=' after the last layer, '-' between layers.
1198 out << std::string(total_col_size,
1199 iter == model_graph.cend() - 1 ? '=' : '-')
// Predicate deciding whether a node is a real model input: either it has an
// explicit input shape property or its name appears in the model's declared
// input connections.
1203 auto &input_connection =
1204 std::get<std::vector<props::InputConnection>>(model_props);
1205 auto model_input = std::vector<Connection>(input_connection.begin(),
1206 input_connection.end());
1207 auto is_actually_an_input_node =
1208 [model_input](graph_const_iterator<LayerNode> node) {
1209 return node->hasInputShapeProperty() or
1210 std::any_of(model_input.begin(), model_input.end(),
1211 [node](auto &conn) {
1212 return node->getName() == conn.getName();
// Second variant of the per-layer loop: dimension/input columns left blank
// (the connection-name lookup is intentionally commented out below).
1216 for (auto iter = model_graph.cbegin(); iter != model_graph.cend();
1218 const std::vector<std::string> &input_layer_names =
1219 iter->getInputConnections();
1221 /// @brief connection information.
1222 // Intended comment.
1223 // std::string first_input_name =
1224 // input_layer_names.empty()
1225 // ? (is_actually_an_input_node(iter) || iter ==
1226 // model_graph.cbegin()
1228 // : (iter - 1)->getName())
1229 // : input_layer_names[0];
1230 print_graph_layer_info(out, {iter->getName(), iter->getType(), "", ""});
1231 for (unsigned int i = 1; i < input_layer_names.size(); ++i) {
1232 print_graph_layer_info(out, {"", "", "", ""});
1234 out << std::string(total_col_size,
1235 iter == model_graph.cend() - 1 ? '=' : '-')
// Remaining sections are not implemented yet; the @todo comments below track
// what each flag is supposed to print.
1241 if (flags & PRINT_PROP) {
1242 /// @todo print neuralnet property
1243 /// @todo print mode (if it is eval or training)
1246 if (flags & PRINT_OPTIMIZER) {
1247 /// @todo print optimizer (with print optimizer prop)
1250 if (flags & PRINT_METRIC) {
1251 /// @todo print metric (currently it is done at printPreset as a
1253 /// @todo print loss function when it is not initialized. (if it is
1254 /// initialized, loss layer will be printed)
// Guard: nothing per-layer to print for an empty graph.
1257 if (model_graph.empty()) {
1258 out << "model is empty!" << std::endl;
1262 /** print layer properties */
1263 for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++)
1264 (*iter)->printPreset(out, layerPrintPreset);
1266 /// @todo Add status to check neuralnet has been run. #290
1269 void NeuralNetwork::forEachLayer(
1270 std::function<void(ml::train::Layer &, RunLayerContext &, void *)> fn,
1272 for (auto iter = model_graph.cbegin(); iter != model_graph.cend(); iter++) {
1273 auto ln = std::static_pointer_cast<LayerNode>(*iter).get();
1274 fn(*ln, std::forward<RunLayerContext &>(ln->getRunContext()), user_data);
/**
 * @brief Export the model to a file in the format selected by @a method.
 *
 * NOTE(review): this excerpt is a line-numbered paste with gaps — the
 * `switch` opener, `#else`/`#endif` of the TFLITE guard, `break;` lines,
 * the interior of the FLATBUFFER case and the `default:` label are missing.
 * Code lines are kept verbatim; verify against the full file.
 *
 * @param method    export method (METHOD_TFLITE / METHOD_FLATBUFFER)
 * @param file_path destination path handed to the interpreter
 * @throws std::runtime_error when the requested method is unavailable in
 *         this build or unsupported
 */
1278 void NeuralNetwork::exports(const ml::train::ExportMethods &method,
1279 const std::string file_path) {
// TFLITE export is only compiled in when ENABLE_TFLITE_INTERPRETER is set;
// otherwise this case throws (see the runtime_error below the guard).
1281 case ml::train::ExportMethods::METHOD_TFLITE: {
1282 #ifdef ENABLE_TFLITE_INTERPRETER
1283 nntrainer::TfliteInterpreter interpreter;
1285 /// We will call "serialize" method for the model which is already trained
1286 /// or allocated. So, we need to call deallocateTensors first to make sure
1287 /// `dealloc_weights == false`
// Reallocate in INFERENCE mode so the serialized tensors match what the
// exported model will execute with, then release them again afterwards.
1288 model_graph.deallocateTensors();
1289 model_graph.allocateTensors(ExecutionMode::INFERENCE);
1290 interpreter.serialize(graph_representation, file_path);
1291 model_graph.deallocateTensors();
1293 throw std::runtime_error{
1294 "Export methods METHOD_TFLITE is not supported. Please enable tflite "
1295 "interpreter by set ENABLE_TFLITE_INTERPRETER=1"};
// FLATBUFFER export reallocates in TRAIN mode instead.
// NOTE(review): the serialization call of this case is missing from this
// excerpt; confirm against the full file.
1299 case ml::train::ExportMethods::METHOD_FLATBUFFER: {
1301 model_graph.deallocateTensors();
1302 model_graph.allocateTensors(ExecutionMode::TRAIN);
// Any other method value is rejected.
1306 throw std::runtime_error{"Unsupported export method"};
1309 } /* namespace nntrainer */