clean up updating data set again 59/271659/2 accepted/tizen/unified/20220308.015808 submit/tizen/20220307.024724 submit/tizen/20220307.041636 submit/tizen/20220308.013950
authorInki Dae <inki.dae@samsung.com>
Thu, 24 Feb 2022 07:00:57 +0000 (16:00 +0900)
committerInki Dae <inki.dae@samsung.com>
Thu, 24 Feb 2022 07:50:41 +0000 (16:50 +0900)
[Version] : 0.1.4-0
[Issue type] : cleanup

Cleaned up UpdateTrainData and UpdateVerifyData callback functions again
by moving these functions into TrainingNntrainer class. With this patch,
the public members for accessing private members are dropped.

Change-Id: I199487aeaf58077a367f91850986e5c3dac01a61
Signed-off-by: Inki Dae <inki.dae@samsung.com>
packaging/training-engine-nntrainer.spec
src/training_engine_nntrainer.cpp
src/training_engine_nntrainer_private.h

index 0db93911922ef02eff27adebe4eefccbbd786a44..53e14709d8c5ee667f40d4b14d622ff9a700caab 100644 (file)
@@ -1,6 +1,6 @@
 Name:       training-engine-nntrainer
 Summary:    Training engine NNTrainer backend
-Version:    0.1.3
+Version:    0.1.4
 Release:    0
 Group:      Multimedia/Libraries
 License:    Apache-2.0
index a96fcac0218ecc0245893f70bbe03d91de65068d..4e59b325bd143343a9b3de2cfa2c46d17ccb8c10 100644 (file)
@@ -27,109 +27,97 @@ namespace TrainingEngineImpl
 {
 namespace NntrainerImpl
 {
-       int UpdateTrainData(float **data, float **label, bool *last, void *user_data)
+       int UpdateTrainDataCb(float **data, float **label, bool *last, void *user_data)
        {
+               if (!user_data) {
+                       LOGE("user_data is null.");
+                       return TRAINING_ENGINE_ERROR_INVALID_PARAMETER;
+               }
+
                auto engine = static_cast<TrainingNntrainer *>(user_data);
 
-               if (engine->IsLastTrainDataIdx()) {
-                       // if last is false then NNTrainer starts training with given data and label.
-                       // Otherwise, it finalizes the training.
-                       // 'last = true' means that all data and labels have been passed.
-                       *last = true;
-                       engine->ClearTrainDataIdx();
+               return engine->UpdateTrainDataSet(data, label, last);
+       }
 
-                       return ML_ERROR_NONE;
+       int UpdateVerifyDataCb(float **data, float **label, bool *last, void *user_data)
+       {
+               if (!user_data) {
+                       LOGE("user_data is null.");
+                       return TRAINING_ENGINE_ERROR_INVALID_PARAMETER;
                }
 
-               auto train = engine->GetTrainDataSet();
-               const auto train_data_idx = engine->GetTrainDataIdx();
-
-               std::copy(train[train_data_idx].data.begin(), train[train_data_idx].data.end(), data[0]);
-               std::copy(train[train_data_idx].label.begin(), train[train_data_idx].label.end(), label[0]);
-
-               engine->IncreaseTrainDataIdx();
-               *last = false;
+               auto engine = static_cast<TrainingNntrainer *>(user_data);
 
-               LOGI("Updated train data.");
+               return engine->UpdateVerifyDataSet(data, label, last);
+       }
 
-               return ML_ERROR_NONE;
+       TrainingNntrainer::TrainingNntrainer(void) :
+               _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx()
+       {
        }
 
-       int UpdateVerifyData(float **data, float **label, bool *last, void *user_data)
+       TrainingNntrainer::~TrainingNntrainer()
        {
-               auto engine = static_cast<TrainingNntrainer *>(user_data);
+       }
 
-               if (engine->IsLastVerifyDataIdx()) {
+       int TrainingNntrainer::UpdateTrainDataSet(float **data, float **label, bool *last)
+       {
+               if (IsLastTrainDataIdx()) {
                        // if last is false then NNTrainer starts training with given data and label.
                        // Otherwise, it finalizes the training.
                        // 'last = true' means that all data and labels have been passed.
                        *last = true;
-                       engine->ClearVerifyDataIdx();
+                       _train_data_idx = 0;
 
-                       return ML_ERROR_NONE;
+                       return TRAINING_ENGINE_ERROR_NONE;
                }
 
-               auto verify = engine->GetVerifyDataSet();
-               const auto verify_data_idx = engine->GetVerifyDataIdx();
-
-               std::copy(verify[verify_data_idx].data.begin(), verify[verify_data_idx].data.end(), data[0]);
-               std::copy(verify[verify_data_idx].label.begin(), verify[verify_data_idx].label.end(), label[0]);
+               std::copy(_train_data_sets[_train_data_idx].data.begin(),
+                                 _train_data_sets[_train_data_idx].data.end(), data[0]);
+               std::copy(_train_data_sets[_train_data_idx].label.begin(),
+                                 _train_data_sets[_train_data_idx].label.end(), label[0]);
 
-               engine->IncreaseVerifyDataIdx();
+               _train_data_idx++;
                *last = false;
 
-               LOGI("Updated verify data.");
+               LOGI("Updated train data.");
 
-               return ML_ERROR_NONE;
+               return TRAINING_ENGINE_ERROR_NONE;
        }
 
-       TrainingNntrainer::TrainingNntrainer(void) :
-               _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx()
+       int TrainingNntrainer::UpdateVerifyDataSet(float **data, float **label, bool *last)
        {
-       }
+               if (IsLastVerifyDataIdx()) {
+                       // if last is false then NNTrainer starts training with given data and label.
+                       // Otherwise, it finalizes the training.
+                       // 'last = true' means that all data and labels have been passed.
+                       *last = true;
+                       _verify_data_idx = 0;
 
-       TrainingNntrainer::~TrainingNntrainer()
-       {
-       }
+                       return TRAINING_ENGINE_ERROR_NONE;
+               }
 
-       void TrainingNntrainer::IncreaseTrainDataIdx()
-       {
-               _train_data_idx++;
-       }
+               std::copy(_verify_data_sets[_verify_data_idx].data.begin(),
+                                 _verify_data_sets[_verify_data_idx].data.end(), data[0]);
+               std::copy(_verify_data_sets[_verify_data_idx].label.begin(),
+                                 _verify_data_sets[_verify_data_idx].label.end(), label[0]);
 
-       void TrainingNntrainer::IncreaseVerifyDataIdx()
-       {
                _verify_data_idx++;
-       }
-
-       int TrainingNntrainer::GetTrainDataIdx()
-       {
-               return _train_data_idx;
-       }
-
-       int TrainingNntrainer::GetVerifyDataIdx()
-       {
-               return _verify_data_idx;
-       }
+               *last = false;
 
-       void TrainingNntrainer::ClearTrainDataIdx()
-       {
-               _train_data_idx = 0;
-       }
+               LOGI("Updated verify data.");
 
-       void TrainingNntrainer::ClearVerifyDataIdx()
-       {
-               _verify_data_idx = 0;
+               return TRAINING_ENGINE_ERROR_NONE;
        }
 
        bool TrainingNntrainer::IsLastTrainDataIdx()
        {
-               return GetTrainDataIdx() == GetTrainDataCnt();
+               return _train_data_idx == static_cast<int>(_train_data_sets.size());
        }
 
        bool TrainingNntrainer::IsLastVerifyDataIdx()
        {
-               return GetVerifyDataIdx() == GetVerifyDataCnt();
+               return _verify_data_idx == static_cast<int>(_verify_data_sets.size());
        }
 
        int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity)
@@ -510,14 +498,14 @@ namespace NntrainerImpl
                        case TRAINING_DATASET_TYPE_TRAIN:
                                _train_data_sets.push_back({data,label});
                                dataset_mode = ML_TRAIN_DATASET_MODE_TRAIN;
-                               func = UpdateTrainData;
+                               func = UpdateTrainDataCb;
 
                                break;
 
                        case TRAINING_DATASET_TYPE_VERIFY:
                                _verify_data_sets.push_back({data,label});
                                dataset_mode = ML_TRAIN_DATASET_MODE_VALID;
-                               func = UpdateVerifyData;
+                               func = UpdateVerifyDataCb;
 
                                break;
 
index 20831117aa25aba2d3c35dccebf82d530e9246a4..c713b182f2fb2eae676afaae784611533b4120c5 100644 (file)
@@ -42,6 +42,8 @@ namespace NntrainerImpl
        class TrainingNntrainer : public TrainingEngineInterface::Common::ITrainingEngineCommon
        {
        private:
+               bool IsLastTrainDataIdx();
+               bool IsLastVerifyDataIdx();
                ml_train_layer_type_e ConvertLayerType(training_engine_layer_type_e type);
                ml_train_optimizer_type_e ConvertOptimizerType(training_engine_optimizer_type_e type);
                std::vector<dataset_type> _train_data_sets;
@@ -73,35 +75,8 @@ namespace NntrainerImpl
                int SetDataset(training_engine_model *model, const training_engine_dataset *dataset) final;
                int CompileModel(const training_engine_model *model, training_engine_compile_property &property) final;
                int TrainModel(const training_engine_model *model, training_engine_model_property &property) final;
-
-               std::vector<dataset_type>& GetTrainDataSet(void)
-               {
-                       return _train_data_sets;
-               }
-
-               std::vector<dataset_type>& GetVerifyDataSet(void)
-               {
-                       return _verify_data_sets;
-               }
-
-               int GetTrainDataCnt(void)
-               {
-                       return static_cast<int>(_train_data_sets.size());
-               }
-
-               int GetVerifyDataCnt(void)
-               {
-                       return static_cast<int>(_verify_data_sets.size());
-               }
-
-               void IncreaseTrainDataIdx();
-               void IncreaseVerifyDataIdx();
-               int GetTrainDataIdx();
-               int GetVerifyDataIdx();
-               void ClearTrainDataIdx();
-               void ClearVerifyDataIdx();
-               bool IsLastTrainDataIdx();
-               bool IsLastVerifyDataIdx();
+               int UpdateTrainDataSet(float **data, float **label, bool *last);
+               int UpdateVerifyDataSet(float **data, float **label, bool *last);
        };
 } /* NntrainerImpl */
 } /* TrainingEngineImpl */