clean up updating data set 39/271639/2
authorInki Dae <inki.dae@samsung.com>
Thu, 24 Feb 2022 03:10:48 +0000 (12:10 +0900)
committerInki Dae <inki.dae@samsung.com>
Thu, 24 Feb 2022 03:35:05 +0000 (12:35 +0900)
[Version] : 0.1.3-0
[Issue type] : cleanup

Cleaned up UpdateTrainData and UpdateVerifyData callback functions.

Change-Id: If5c95dc3b402096db57edff521a1528888b5cf8d
Signed-off-by: Seungbae Shin <seungbae.shin@samsung.com>
Fixed indexing data set vector and using more meaningful function name.
Signed-off-by: Inki Dae <inki.dae@samsung.com>
packaging/training-engine-nntrainer.spec
src/training_engine_nntrainer.cpp
src/training_engine_nntrainer_private.h

index b2dc0524f48514b9077bc3f6f4559a13582b50ac..0db93911922ef02eff27adebe4eefccbbd786a44 100644 (file)
@@ -1,6 +1,6 @@
 Name:       training-engine-nntrainer
 Summary:    Training engine NNTrainer backend
-Version:    0.1.2
+Version:    0.1.3
 Release:    0
 Group:      Multimedia/Libraries
 License:    Apache-2.0
index 0b9c98db27f427d3e2a6269c50c9b48d5c6b4ad8..a96fcac0218ecc0245893f70bbe03d91de65068d 100644 (file)
@@ -30,9 +30,8 @@ namespace NntrainerImpl
        int UpdateTrainData(float **data, float **label, bool *last, void *user_data)
        {
                auto engine = static_cast<TrainingNntrainer *>(user_data);
-               auto train = engine->GetTrainDataSet();
 
-               if (engine->GetTrainDataIdx() == engine->GetTrainDataCnt()) {
+               if (engine->IsLastTrainDataIdx()) {
                        // if last is false then NNTrainer starts training with given data and label.
                        // Otherwise, it finalizes the training.
                        // 'last = true' means that all data and labels have been passed.
@@ -42,14 +41,11 @@ namespace NntrainerImpl
                        return ML_ERROR_NONE;
                }
 
-               int idx = 0;
-
-               for (auto v : train[engine->GetTrainDataIdx()].data)
-                       data[0][idx++] = v;
+               auto train = engine->GetTrainDataSet();
+               const auto train_data_idx = engine->GetTrainDataIdx();
 
-               idx = 0;
-               for (auto l : train[engine->GetTrainDataIdx()].label)
-                       label[0][idx++] = l;
+               std::copy(train[train_data_idx].data.begin(), train[train_data_idx].data.end(), data[0]);
+               std::copy(train[train_data_idx].label.begin(), train[train_data_idx].label.end(), label[0]);
 
                engine->IncreaseTrainDataIdx();
                *last = false;
@@ -62,9 +58,8 @@ namespace NntrainerImpl
        int UpdateVerifyData(float **data, float **label, bool *last, void *user_data)
        {
                auto engine = static_cast<TrainingNntrainer *>(user_data);
-               auto verify = engine->GetVerifyDataSet();
 
-               if (engine->GetVerifyDataIdx() == engine->GetVerifyDataCnt()) {
+               if (engine->IsLastVerifyDataIdx()) {
                        // if last is false then NNTrainer starts training with given data and label.
                        // Otherwise, it finalizes the training.
                        // 'last = true' means that all data and labels have been passed.
@@ -74,14 +69,11 @@ namespace NntrainerImpl
                        return ML_ERROR_NONE;
                }
 
-               int idx = 0;
-
-               for (auto v : verify[engine->GetVerifyDataIdx()].data)
-                       data[0][idx++] = v;
+               auto verify = engine->GetVerifyDataSet();
+               const auto verify_data_idx = engine->GetVerifyDataIdx();
 
-               idx = 0;
-               for (auto l : verify[engine->GetVerifyDataIdx()].label)
-                       label[0][idx++] = l;
+               std::copy(verify[verify_data_idx].data.begin(), verify[verify_data_idx].data.end(), data[0]);
+               std::copy(verify[verify_data_idx].label.begin(), verify[verify_data_idx].label.end(), label[0]);
 
                engine->IncreaseVerifyDataIdx();
                *last = false;
@@ -130,6 +122,16 @@ namespace NntrainerImpl
                _verify_data_idx = 0;
        }
 
+       bool TrainingNntrainer::IsLastTrainDataIdx()
+       {
+               return GetTrainDataIdx() == GetTrainDataCnt();
+       }
+
+       bool TrainingNntrainer::IsLastVerifyDataIdx()
+       {
+               return GetVerifyDataIdx() == GetVerifyDataCnt();
+       }
+
        int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity)
        {
                LOGI("ENTER");
index c5f28dceafd36b1ae566da20d721d77e58d5197e..20831117aa25aba2d3c35dccebf82d530e9246a4 100644 (file)
@@ -100,6 +100,8 @@ namespace NntrainerImpl
                int GetVerifyDataIdx();
                void ClearTrainDataIdx();
                void ClearVerifyDataIdx();
+               bool IsLastTrainDataIdx();
+               bool IsLastVerifyDataIdx();
        };
 } /* NntrainerImpl */
 } /* TrainingEngineImpl */