encapsulate class members 47/271447/1
authorInki Dae <inki.dae@samsung.com>
Mon, 21 Feb 2022 10:10:10 +0000 (19:10 +0900)
committerInki Dae <inki.dae@samsung.com>
Mon, 21 Feb 2022 10:10:10 +0000 (19:10 +0900)
[Version] 0.1.2-0
[Issue type] cleanup

Encapsulated _train_data_idx and _verify_data_idx member variables.
Callback functions are called from outside the class, so this change
makes these two members accessible only through public functions.

Change-Id: I56a9afdd3866eb85ed9870bbffddf56efa0067e2
Signed-off-by: Inki Dae <inki.dae@samsung.com>
packaging/training-engine-nntrainer.spec
src/training_engine_nntrainer.cpp
src/training_engine_nntrainer_private.h

index 90e63bc4f9c07d4f70f8b123ceff7f77410b711c..b2dc0524f48514b9077bc3f6f4559a13582b50ac 100644 (file)
@@ -1,6 +1,6 @@
 Name:       training-engine-nntrainer
 Summary:    Training engine NNTrainer backend
-Version:    0.1.1
+Version:    0.1.2
 Release:    0
 Group:      Multimedia/Libraries
 License:    Apache-2.0
index 61af208ef1681ff8c88dcbbbe493d94bf2d05422..0b9c98db27f427d3e2a6269c50c9b48d5c6b4ad8 100644 (file)
@@ -32,26 +32,26 @@ namespace NntrainerImpl
                auto engine = static_cast<TrainingNntrainer *>(user_data);
                auto train = engine->GetTrainDataSet();
 
-               if (engine->_train_data_idx == engine->GetTrainDataCnt()) {
+               if (engine->GetTrainDataIdx() == engine->GetTrainDataCnt()) {
                        // if last is false then NNTrainer starts training with given data and label.
                        // Otherwise, it finalizes the training.
                        // 'last = true' means that all data and labels have been passed.
                        *last = true;
-                       engine->_train_data_idx = 0;
+                       engine->ClearTrainDataIdx();
 
                        return ML_ERROR_NONE;
                }
 
                int idx = 0;
 
-               for (auto v : train[engine->_train_data_idx].data)
+               for (auto v : train[engine->GetTrainDataIdx()].data)
                        data[0][idx++] = v;
 
                idx = 0;
-               for (auto l : train[engine->_train_data_idx].label)
+               for (auto l : train[engine->GetTrainDataIdx()].label)
                        label[0][idx++] = l;
 
-               engine->_train_data_idx++;
+               engine->IncreaseTrainDataIdx();
                *last = false;
 
                LOGI("Updated train data.");
@@ -64,26 +64,26 @@ namespace NntrainerImpl
                auto engine = static_cast<TrainingNntrainer *>(user_data);
                auto verify = engine->GetVerifyDataSet();
 
-               if (engine->_verify_data_idx == engine->GetVerifyDataCnt()) {
+               if (engine->GetVerifyDataIdx() == engine->GetVerifyDataCnt()) {
                        // if last is false then NNTrainer starts training with given data and label.
                        // Otherwise, it finalizes the training.
                        // 'last = true' means that all data and labels have been passed.
                        *last = true;
-                       engine->_verify_data_idx = 0;
+                       engine->ClearVerifyDataIdx();
 
                        return ML_ERROR_NONE;
                }
 
                int idx = 0;
 
-               for (auto v : verify[engine->_verify_data_idx].data)
+               for (auto v : verify[engine->GetVerifyDataIdx()].data)
                        data[0][idx++] = v;
 
                idx = 0;
-               for (auto l : verify[engine->_verify_data_idx].label)
+               for (auto l : verify[engine->GetVerifyDataIdx()].label)
                        label[0][idx++] = l;
 
-               engine->_verify_data_idx++;
+               engine->IncreaseVerifyDataIdx();
                *last = false;
 
                LOGI("Updated verify data.");
@@ -92,7 +92,7 @@ namespace NntrainerImpl
        }
 
        TrainingNntrainer::TrainingNntrainer(void) :
-               _train_data_sets(), _verify_data_sets()
+               _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx()
        {
        }
 
@@ -100,6 +100,36 @@ namespace NntrainerImpl
        {
        }
 
+       void TrainingNntrainer::IncreaseTrainDataIdx()
+       {
+               _train_data_idx++;
+       }
+
+       void TrainingNntrainer::IncreaseVerifyDataIdx()
+       {
+               _verify_data_idx++;
+       }
+
+       int TrainingNntrainer::GetTrainDataIdx()
+       {
+               return _train_data_idx;
+       }
+
+       int TrainingNntrainer::GetVerifyDataIdx()
+       {
+               return _verify_data_idx;
+       }
+
+       void TrainingNntrainer::ClearTrainDataIdx()
+       {
+               _train_data_idx = 0;
+       }
+
+       void TrainingNntrainer::ClearVerifyDataIdx()
+       {
+               _verify_data_idx = 0;
+       }
+
        int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity)
        {
                LOGI("ENTER");
index ec79e26009bba7b7ea4c223baf2a782f3cf62506..c5f28dceafd36b1ae566da20d721d77e58d5197e 100644 (file)
@@ -46,6 +46,8 @@ namespace NntrainerImpl
                ml_train_optimizer_type_e ConvertOptimizerType(training_engine_optimizer_type_e type);
                std::vector<dataset_type> _train_data_sets;
                std::vector<dataset_type> _verify_data_sets;
+               int _train_data_idx;
+               int _verify_data_idx;
 
        public:
                TrainingNntrainer();
@@ -92,10 +94,13 @@ namespace NntrainerImpl
                        return static_cast<int>(_verify_data_sets.size());
                }
 
-               int _train_data_idx;
-               int _verify_data_idx;
+               void IncreaseTrainDataIdx();
+               void IncreaseVerifyDataIdx();
+               int GetTrainDataIdx();
+               int GetVerifyDataIdx();
+               void ClearTrainDataIdx();
+               void ClearVerifyDataIdx();
        };
-
 } /* NntrainerImpl */
 } /* TrainingEngineImpl */