auto engine = static_cast<TrainingNntrainer *>(user_data);
auto train = engine->GetTrainDataSet();
- if (engine->_train_data_idx == engine->GetTrainDataCnt()) {
+ if (engine->GetTrainDataIdx() == engine->GetTrainDataCnt()) {
// if last is false then NNTrainer starts training with given data and label.
// Otherwise, it finalizes the training.
// 'last = true' means that all data and labels have been passed.
*last = true;
- engine->_train_data_idx = 0;
+ engine->ClearTrainDataIdx();
return ML_ERROR_NONE;
}
int idx = 0;
- for (auto v : train[engine->_train_data_idx].data)
+ for (auto v : train[engine->GetTrainDataIdx()].data)
data[0][idx++] = v;
idx = 0;
- for (auto l : train[engine->_train_data_idx].label)
+ for (auto l : train[engine->GetTrainDataIdx()].label)
label[0][idx++] = l;
- engine->_train_data_idx++;
+ engine->IncreaseTrainDataIdx();
*last = false;
LOGI("Updated train data.");
auto engine = static_cast<TrainingNntrainer *>(user_data);
auto verify = engine->GetVerifyDataSet();
- if (engine->_verify_data_idx == engine->GetVerifyDataCnt()) {
+ if (engine->GetVerifyDataIdx() == engine->GetVerifyDataCnt()) {
// if last is false then NNTrainer starts training with given data and label.
// Otherwise, it finalizes the training.
// 'last = true' means that all data and labels have been passed.
*last = true;
- engine->_verify_data_idx = 0;
+ engine->ClearVerifyDataIdx();
return ML_ERROR_NONE;
}
int idx = 0;
- for (auto v : verify[engine->_verify_data_idx].data)
+ for (auto v : verify[engine->GetVerifyDataIdx()].data)
data[0][idx++] = v;
idx = 0;
- for (auto l : verify[engine->_verify_data_idx].label)
+ for (auto l : verify[engine->GetVerifyDataIdx()].label)
label[0][idx++] = l;
- engine->_verify_data_idx++;
+ engine->IncreaseVerifyDataIdx();
*last = false;
LOGI("Updated verify data.");
}
TrainingNntrainer::TrainingNntrainer(void) :
- _train_data_sets(), _verify_data_sets()
+ _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx()
{
}
{
}
+ void TrainingNntrainer::IncreaseTrainDataIdx()
+ {
+ _train_data_idx++;
+ }
+
+ void TrainingNntrainer::IncreaseVerifyDataIdx()
+ {
+ _verify_data_idx++;
+ }
+
+ int TrainingNntrainer::GetTrainDataIdx()
+ {
+ return _train_data_idx;
+ }
+
+ int TrainingNntrainer::GetVerifyDataIdx()
+ {
+ return _verify_data_idx;
+ }
+
+ void TrainingNntrainer::ClearTrainDataIdx()
+ {
+ _train_data_idx = 0;
+ }
+
+ void TrainingNntrainer::ClearVerifyDataIdx()
+ {
+ _verify_data_idx = 0;
+ }
+
int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity)
{
LOGI("ENTER");
ml_train_optimizer_type_e ConvertOptimizerType(training_engine_optimizer_type_e type);
std::vector<dataset_type> _train_data_sets;
std::vector<dataset_type> _verify_data_sets;
+ int _train_data_idx;
+ int _verify_data_idx;
public:
TrainingNntrainer();
return static_cast<int>(_verify_data_sets.size());
}
- int _train_data_idx;
- int _verify_data_idx;
+ void IncreaseTrainDataIdx();
+ void IncreaseVerifyDataIdx();
+ int GetTrainDataIdx();
+ int GetVerifyDataIdx();
+ void ClearTrainDataIdx();
+ void ClearVerifyDataIdx();
};
-
} /* NntrainerImpl */
} /* TrainingEngineImpl */