int UpdateTrainData(float **data, float **label, bool *last, void *user_data)
{
auto engine = static_cast<TrainingNntrainer *>(user_data);
- auto train = engine->GetTrainDataSet();
- if (engine->GetTrainDataIdx() == engine->GetTrainDataCnt()) {
+ if (engine->IsLastTrainDataIdx()) {
// if last is false then NNTrainer starts training with given data and label.
// Otherwise, it finalizes the training.
// 'last = true' means that all data and labels have been passed.
return ML_ERROR_NONE;
}
- int idx = 0;
-
- for (auto v : train[engine->GetTrainDataIdx()].data)
- data[0][idx++] = v;
+ auto train = engine->GetTrainDataSet();
+ const auto train_data_idx = engine->GetTrainDataIdx();
- idx = 0;
- for (auto l : train[engine->GetTrainDataIdx()].label)
- label[0][idx++] = l;
+ std::copy(train[train_data_idx].data.begin(), train[train_data_idx].data.end(), data[0]);
+ std::copy(train[train_data_idx].label.begin(), train[train_data_idx].label.end(), label[0]);
engine->IncreaseTrainDataIdx();
*last = false;
int UpdateVerifyData(float **data, float **label, bool *last, void *user_data)
{
auto engine = static_cast<TrainingNntrainer *>(user_data);
- auto verify = engine->GetVerifyDataSet();
- if (engine->GetVerifyDataIdx() == engine->GetVerifyDataCnt()) {
+ if (engine->IsLastVerifyDataIdx()) {
// if last is false then NNTrainer starts training with given data and label.
// Otherwise, it finalizes the training.
// 'last = true' means that all data and labels have been passed.
return ML_ERROR_NONE;
}
- int idx = 0;
-
- for (auto v : verify[engine->GetVerifyDataIdx()].data)
- data[0][idx++] = v;
+ auto verify = engine->GetVerifyDataSet();
+ const auto verify_data_idx = engine->GetVerifyDataIdx();
- idx = 0;
- for (auto l : verify[engine->GetVerifyDataIdx()].label)
- label[0][idx++] = l;
+ std::copy(verify[verify_data_idx].data.begin(), verify[verify_data_idx].data.end(), data[0]);
+ std::copy(verify[verify_data_idx].label.begin(), verify[verify_data_idx].label.end(), label[0]);
engine->IncreaseVerifyDataIdx();
*last = false;
_verify_data_idx = 0;
}
+ bool TrainingNntrainer::IsLastTrainDataIdx()
+ {
+ return GetTrainDataIdx() == GetTrainDataCnt();
+ }
+
+ bool TrainingNntrainer::IsLastVerifyDataIdx()
+ {
+ return GetVerifyDataIdx() == GetVerifyDataCnt();
+ }
+
int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity)
{
LOGI("ENTER");