From 9e644681aec12db981599e71f3828fb15f437646 Mon Sep 17 00:00:00 2001 From: Inki Dae Date: Thu, 24 Feb 2022 16:00:57 +0900 Subject: [PATCH] clean up updating data set again [Version] : 0.1.4-0 [Issue type] : cleanup Cleaned up UpdateTrainData and UpdateVerifyData callback functions again by moving these functions into TrainingNntrainer class. With this patch, the public members for accessing private members are dropped. Change-Id: I199487aeaf58077a367f91850986e5c3dac01a61 Signed-off-by: Inki Dae --- packaging/training-engine-nntrainer.spec | 2 +- src/training_engine_nntrainer.cpp | 118 ++++++++++------------- src/training_engine_nntrainer_private.h | 33 +------ 3 files changed, 58 insertions(+), 95 deletions(-) diff --git a/packaging/training-engine-nntrainer.spec b/packaging/training-engine-nntrainer.spec index 0db9391..53e1470 100644 --- a/packaging/training-engine-nntrainer.spec +++ b/packaging/training-engine-nntrainer.spec @@ -1,6 +1,6 @@ Name: training-engine-nntrainer Summary: Training engine NNTrainer backend -Version: 0.1.3 +Version: 0.1.4 Release: 0 Group: Multimedia/Libraries License: Apache-2.0 diff --git a/src/training_engine_nntrainer.cpp b/src/training_engine_nntrainer.cpp index a96fcac..4e59b32 100644 --- a/src/training_engine_nntrainer.cpp +++ b/src/training_engine_nntrainer.cpp @@ -27,109 +27,97 @@ namespace TrainingEngineImpl { namespace NntrainerImpl { - int UpdateTrainData(float **data, float **label, bool *last, void *user_data) + int UpdateTrainDataCb(float **data, float **label, bool *last, void *user_data) { + if (!user_data) { + LOGE("user_data is null."); + return TRAINING_ENGINE_ERROR_INVALID_PARAMETER; + } + auto engine = static_cast(user_data); - if (engine->IsLastTrainDataIdx()) { - // if last is false then NNTrainer starts training with given data and label. - // Otherwise, it finalizes the training. - // 'last = true' means that all data and labels have been passed. - *last = true; - engine->ClearTrainDataIdx(); + return engine->UpdateTrainDataSet(data, label, last); + } - return ML_ERROR_NONE; + int UpdateVerifyDataCb(float **data, float **label, bool *last, void *user_data) + { + if (!user_data) { + LOGE("user_data is null."); + return TRAINING_ENGINE_ERROR_INVALID_PARAMETER; } - auto train = engine->GetTrainDataSet(); - const auto train_data_idx = engine->GetTrainDataIdx(); - - std::copy(train[train_data_idx].data.begin(), train[train_data_idx].data.end(), data[0]); - std::copy(train[train_data_idx].label.begin(), train[train_data_idx].label.end(), label[0]); - - engine->IncreaseTrainDataIdx(); - *last = false; + auto engine = static_cast(user_data); - LOGI("Updated train data."); + return engine->UpdateVerifyDataSet(data, label, last); + } - return ML_ERROR_NONE; + TrainingNntrainer::TrainingNntrainer(void) : + _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx() + { } - int UpdateVerifyData(float **data, float **label, bool *last, void *user_data) + TrainingNntrainer::~TrainingNntrainer() { - auto engine = static_cast(user_data); + } - if (engine->IsLastVerifyDataIdx()) { + int TrainingNntrainer::UpdateTrainDataSet(float **data, float **label, bool *last) + { + if (IsLastTrainDataIdx()) { // if last is false then NNTrainer starts training with given data and label. // Otherwise, it finalizes the training. // 'last = true' means that all data and labels have been passed. *last = true; - engine->ClearVerifyDataIdx(); + _train_data_idx = 0; - return ML_ERROR_NONE; + return TRAINING_ENGINE_ERROR_NONE; } - auto verify = engine->GetVerifyDataSet(); - const auto verify_data_idx = engine->GetVerifyDataIdx(); - - std::copy(verify[verify_data_idx].data.begin(), verify[verify_data_idx].data.end(), data[0]); - std::copy(verify[verify_data_idx].label.begin(), verify[verify_data_idx].label.end(), label[0]); + std::copy(_train_data_sets[_train_data_idx].data.begin(), + _train_data_sets[_train_data_idx].data.end(), data[0]); + std::copy(_train_data_sets[_train_data_idx].label.begin(), + _train_data_sets[_train_data_idx].label.end(), label[0]); - engine->IncreaseVerifyDataIdx(); + _train_data_idx++; *last = false; - LOGI("Updated verify data."); + LOGI("Updated train data."); - return ML_ERROR_NONE; + return TRAINING_ENGINE_ERROR_NONE; } - TrainingNntrainer::TrainingNntrainer(void) : - _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx() + int TrainingNntrainer::UpdateVerifyDataSet(float **data, float **label, bool *last) { - } + if (IsLastVerifyDataIdx()) { + // if last is false then NNTrainer starts training with given data and label. + // Otherwise, it finalizes the training. + // 'last = true' means that all data and labels have been passed. + *last = true; + _verify_data_idx = 0; - TrainingNntrainer::~TrainingNntrainer() - { - } + return TRAINING_ENGINE_ERROR_NONE; + } - void TrainingNntrainer::IncreaseTrainDataIdx() - { - _train_data_idx++; - } + std::copy(_verify_data_sets[_verify_data_idx].data.begin(), + _verify_data_sets[_verify_data_idx].data.end(), data[0]); + std::copy(_verify_data_sets[_verify_data_idx].label.begin(), + _verify_data_sets[_verify_data_idx].label.end(), label[0]); - void TrainingNntrainer::IncreaseVerifyDataIdx() - { _verify_data_idx++; - } - - int TrainingNntrainer::GetTrainDataIdx() - { - return _train_data_idx; - } - - int TrainingNntrainer::GetVerifyDataIdx() - { - return _verify_data_idx; - } + *last = false; - void TrainingNntrainer::ClearTrainDataIdx() - { - _train_data_idx = 0; - } + LOGI("Updated verify data."); - void TrainingNntrainer::ClearVerifyDataIdx() - { - _verify_data_idx = 0; + return TRAINING_ENGINE_ERROR_NONE; } bool TrainingNntrainer::IsLastTrainDataIdx() { - return GetTrainDataIdx() == GetTrainDataCnt(); + return _train_data_idx == static_cast(_train_data_sets.size()); } bool TrainingNntrainer::IsLastVerifyDataIdx() { - return GetVerifyDataIdx() == GetVerifyDataCnt(); + return _verify_data_idx == static_cast(_verify_data_sets.size()); } int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity) @@ -510,14 +498,14 @@ namespace NntrainerImpl case TRAINING_DATASET_TYPE_TRAIN: _train_data_sets.push_back({data,label}); dataset_mode = ML_TRAIN_DATASET_MODE_TRAIN; - func = UpdateTrainData; + func = UpdateTrainDataCb; break; case TRAINING_DATASET_TYPE_VERIFY: _verify_data_sets.push_back({data,label}); dataset_mode = ML_TRAIN_DATASET_MODE_VALID; - func = UpdateVerifyData; + func = UpdateVerifyDataCb; break; diff --git a/src/training_engine_nntrainer_private.h b/src/training_engine_nntrainer_private.h index 2083111..c713b18 100644 --- a/src/training_engine_nntrainer_private.h +++ b/src/training_engine_nntrainer_private.h @@ -42,6 +42,8 @@ namespace NntrainerImpl class TrainingNntrainer : public TrainingEngineInterface::Common::ITrainingEngineCommon { private: + bool IsLastTrainDataIdx(); + bool IsLastVerifyDataIdx(); ml_train_layer_type_e ConvertLayerType(training_engine_layer_type_e type); ml_train_optimizer_type_e ConvertOptimizerType(training_engine_optimizer_type_e type); std::vector _train_data_sets; @@ -73,35 +75,8 @@ namespace NntrainerImpl int SetDataset(training_engine_model *model, const training_engine_dataset *dataset) final; int CompileModel(const training_engine_model *model, training_engine_compile_property &property) final; int TrainModel(const training_engine_model *model, training_engine_model_property &property) final; - - std::vector& GetTrainDataSet(void) - { - return _train_data_sets; - } - - std::vector& GetVerifyDataSet(void) - { - return _verify_data_sets; - } - - int GetTrainDataCnt(void) - { - return static_cast(_train_data_sets.size()); - } - - int GetVerifyDataCnt(void) - { - return static_cast(_verify_data_sets.size()); - } - - void IncreaseTrainDataIdx(); - void IncreaseVerifyDataIdx(); - int GetTrainDataIdx(); - int GetVerifyDataIdx(); - void ClearTrainDataIdx(); - void ClearVerifyDataIdx(); - bool IsLastTrainDataIdx(); - bool IsLastVerifyDataIdx(); + int UpdateTrainDataSet(float **data, float **label, bool *last); + int UpdateVerifyDataSet(float **data, float **label, bool *last); }; } /* NntrainerImpl */ } /* TrainingEngineImpl */ -- 2.34.1