From: Inki Dae Date: Thu, 24 Feb 2022 03:10:48 +0000 (+0900) Subject: clean up updating data set X-Git-Tag: submit/tizen/20220307.024724~1 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=98a8b53ee2c0342e6b6cd661b082adb18fd207a7;p=platform%2Fupstream%2Ftraining-engine-nntrainer.git clean up updating data set [Version] : 0.1.3-0 [Issue type] : cleanup Cleaned up UpdateTrainData and UpdateVerifyData callback functions. Change-Id: If5c95dc3b402096db57edff521a1528888b5cf8d Signed-off-by: Seungbae Shin Fixed indexing data set vector and using more meaningful function name. Signed-off-by: Inki Dae --- diff --git a/packaging/training-engine-nntrainer.spec b/packaging/training-engine-nntrainer.spec index b2dc052..0db9391 100644 --- a/packaging/training-engine-nntrainer.spec +++ b/packaging/training-engine-nntrainer.spec @@ -1,6 +1,6 @@ Name: training-engine-nntrainer Summary: Training engine NNTrainer backend -Version: 0.1.2 +Version: 0.1.3 Release: 0 Group: Multimedia/Libraries License: Apache-2.0 diff --git a/src/training_engine_nntrainer.cpp b/src/training_engine_nntrainer.cpp index 0b9c98d..a96fcac 100644 --- a/src/training_engine_nntrainer.cpp +++ b/src/training_engine_nntrainer.cpp @@ -30,9 +30,8 @@ namespace NntrainerImpl int UpdateTrainData(float **data, float **label, bool *last, void *user_data) { auto engine = static_cast(user_data); - auto train = engine->GetTrainDataSet(); - if (engine->GetTrainDataIdx() == engine->GetTrainDataCnt()) { + if (engine->IsLastTrainDataIdx()) { // if last is false then NNTrainer starts training with given data and label. // Otherwise, it finalizes the training. // 'last = true' means that all data and labels have been passed. @@ -42,14 +41,11 @@ namespace NntrainerImpl return ML_ERROR_NONE; } - int idx = 0; - - for (auto v : train[engine->GetTrainDataIdx()].data) - data[0][idx++] = v; + auto train = engine->GetTrainDataSet(); + const auto train_data_idx = engine->GetTrainDataIdx(); - idx = 0; - for (auto l : train[engine->GetTrainDataIdx()].label) - label[0][idx++] = l; + std::copy(train[train_data_idx].data.begin(), train[train_data_idx].data.end(), data[0]); + std::copy(train[train_data_idx].label.begin(), train[train_data_idx].label.end(), label[0]); engine->IncreaseTrainDataIdx(); *last = false; @@ -62,9 +58,8 @@ namespace NntrainerImpl int UpdateVerifyData(float **data, float **label, bool *last, void *user_data) { auto engine = static_cast(user_data); - auto verify = engine->GetVerifyDataSet(); - if (engine->GetVerifyDataIdx() == engine->GetVerifyDataCnt()) { + if (engine->IsLastVerifyDataIdx()) { // if last is false then NNTrainer starts training with given data and label. // Otherwise, it finalizes the training. // 'last = true' means that all data and labels have been passed. @@ -74,14 +69,11 @@ namespace NntrainerImpl return ML_ERROR_NONE; } - int idx = 0; - - for (auto v : verify[engine->GetVerifyDataIdx()].data) - data[0][idx++] = v; + auto verify = engine->GetVerifyDataSet(); + const auto verify_data_idx = engine->GetVerifyDataIdx(); - idx = 0; - for (auto l : verify[engine->GetVerifyDataIdx()].label) - label[0][idx++] = l; + std::copy(verify[verify_data_idx].data.begin(), verify[verify_data_idx].data.end(), data[0]); + std::copy(verify[verify_data_idx].label.begin(), verify[verify_data_idx].label.end(), label[0]); engine->IncreaseVerifyDataIdx(); *last = false; @@ -130,6 +122,16 @@ namespace NntrainerImpl _verify_data_idx = 0; } + bool TrainingNntrainer::IsLastTrainDataIdx() + { + return GetTrainDataIdx() == GetTrainDataCnt(); + } + + bool TrainingNntrainer::IsLastVerifyDataIdx() + { + return GetVerifyDataIdx() == GetVerifyDataCnt(); + } + int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity) { LOGI("ENTER"); diff --git a/src/training_engine_nntrainer_private.h b/src/training_engine_nntrainer_private.h index c5f28dc..2083111 100644 --- a/src/training_engine_nntrainer_private.h +++ b/src/training_engine_nntrainer_private.h @@ -100,6 +100,8 @@ namespace NntrainerImpl int GetVerifyDataIdx(); void ClearTrainDataIdx(); void ClearVerifyDataIdx(); + bool IsLastTrainDataIdx(); + bool IsLastVerifyDataIdx(); }; } /* NntrainerImpl */ } /* TrainingEngineImpl */