{
namespace NntrainerImpl
{
- int UpdateTrainData(float **data, float **label, bool *last, void *user_data)
+ int UpdateTrainDataCb(float **data, float **label, bool *last, void *user_data)
{
+ if (!user_data) {
+ LOGE("user_data is null.");
+ return TRAINING_ENGINE_ERROR_INVALID_PARAMETER;
+ }
+
auto engine = static_cast<TrainingNntrainer *>(user_data);
- if (engine->IsLastTrainDataIdx()) {
- // if last is false then NNTrainer starts training with given data and label.
- // Otherwise, it finalizes the training.
- // 'last = true' means that all data and labels have been passed.
- *last = true;
- engine->ClearTrainDataIdx();
+ return engine->UpdateTrainDataSet(data, label, last);
+ }
- return ML_ERROR_NONE;
+ int UpdateVerifyDataCb(float **data, float **label, bool *last, void *user_data)
+ {
+ if (!user_data) {
+ LOGE("user_data is null.");
+ return TRAINING_ENGINE_ERROR_INVALID_PARAMETER;
}
- auto train = engine->GetTrainDataSet();
- const auto train_data_idx = engine->GetTrainDataIdx();
-
- std::copy(train[train_data_idx].data.begin(), train[train_data_idx].data.end(), data[0]);
- std::copy(train[train_data_idx].label.begin(), train[train_data_idx].label.end(), label[0]);
-
- engine->IncreaseTrainDataIdx();
- *last = false;
+ auto engine = static_cast<TrainingNntrainer *>(user_data);
- LOGI("Updated train data.");
+ return engine->UpdateVerifyDataSet(data, label, last);
+ }
- return ML_ERROR_NONE;
+ TrainingNntrainer::TrainingNntrainer(void) :
+ _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx()
+ {
}
- int UpdateVerifyData(float **data, float **label, bool *last, void *user_data)
+ TrainingNntrainer::~TrainingNntrainer()
{
- auto engine = static_cast<TrainingNntrainer *>(user_data);
+ }
- if (engine->IsLastVerifyDataIdx()) {
+ int TrainingNntrainer::UpdateTrainDataSet(float **data, float **label, bool *last)
+ {
+ if (IsLastTrainDataIdx()) {
// if last is false then NNTrainer starts training with given data and label.
// Otherwise, it finalizes the training.
// 'last = true' means that all data and labels have been passed.
*last = true;
- engine->ClearVerifyDataIdx();
+ _train_data_idx = 0;
- return ML_ERROR_NONE;
+ return TRAINING_ENGINE_ERROR_NONE;
}
- auto verify = engine->GetVerifyDataSet();
- const auto verify_data_idx = engine->GetVerifyDataIdx();
-
- std::copy(verify[verify_data_idx].data.begin(), verify[verify_data_idx].data.end(), data[0]);
- std::copy(verify[verify_data_idx].label.begin(), verify[verify_data_idx].label.end(), label[0]);
+ std::copy(_train_data_sets[_train_data_idx].data.begin(),
+ _train_data_sets[_train_data_idx].data.end(), data[0]);
+ std::copy(_train_data_sets[_train_data_idx].label.begin(),
+ _train_data_sets[_train_data_idx].label.end(), label[0]);
- engine->IncreaseVerifyDataIdx();
+ _train_data_idx++;
*last = false;
- LOGI("Updated verify data.");
+ LOGI("Updated train data.");
- return ML_ERROR_NONE;
+ return TRAINING_ENGINE_ERROR_NONE;
}
- TrainingNntrainer::TrainingNntrainer(void) :
- _train_data_sets(), _verify_data_sets(), _train_data_idx(), _verify_data_idx()
+ int TrainingNntrainer::UpdateVerifyDataSet(float **data, float **label, bool *last)
{
- }
+ if (IsLastVerifyDataIdx()) {
+ // if last is false then NNTrainer starts training with given data and label.
+ // Otherwise, it finalizes the training.
+ // 'last = true' means that all data and labels have been passed.
+ *last = true;
+ _verify_data_idx = 0;
- TrainingNntrainer::~TrainingNntrainer()
- {
- }
+ return TRAINING_ENGINE_ERROR_NONE;
+ }
- void TrainingNntrainer::IncreaseTrainDataIdx()
- {
- _train_data_idx++;
- }
+ std::copy(_verify_data_sets[_verify_data_idx].data.begin(),
+ _verify_data_sets[_verify_data_idx].data.end(), data[0]);
+ std::copy(_verify_data_sets[_verify_data_idx].label.begin(),
+ _verify_data_sets[_verify_data_idx].label.end(), label[0]);
- void TrainingNntrainer::IncreaseVerifyDataIdx()
- {
_verify_data_idx++;
- }
-
- int TrainingNntrainer::GetTrainDataIdx()
- {
- return _train_data_idx;
- }
-
- int TrainingNntrainer::GetVerifyDataIdx()
- {
- return _verify_data_idx;
- }
+ *last = false;
- void TrainingNntrainer::ClearTrainDataIdx()
- {
- _train_data_idx = 0;
- }
+ LOGI("Updated verify data.");
- void TrainingNntrainer::ClearVerifyDataIdx()
- {
- _verify_data_idx = 0;
+ return TRAINING_ENGINE_ERROR_NONE;
}
bool TrainingNntrainer::IsLastTrainDataIdx()
{
- return GetTrainDataIdx() == GetTrainDataCnt();
+ return _train_data_idx == static_cast<int>(_train_data_sets.size());
}
bool TrainingNntrainer::IsLastVerifyDataIdx()
{
- return GetVerifyDataIdx() == GetVerifyDataCnt();
+ return _verify_data_idx == static_cast<int>(_verify_data_sets.size());
}
int TrainingNntrainer::GetBackendCapacity(training_engine_capacity &capacity)
case TRAINING_DATASET_TYPE_TRAIN:
_train_data_sets.push_back({data,label});
dataset_mode = ML_TRAIN_DATASET_MODE_TRAIN;
- func = UpdateTrainData;
+ func = UpdateTrainDataCb;
break;
case TRAINING_DATASET_TYPE_VERIFY:
_verify_data_sets.push_back({data,label});
dataset_mode = ML_TRAIN_DATASET_MODE_VALID;
- func = UpdateVerifyData;
+ func = UpdateVerifyDataCb;
break;
class TrainingNntrainer : public TrainingEngineInterface::Common::ITrainingEngineCommon
{
private:
+ bool IsLastTrainDataIdx();
+ bool IsLastVerifyDataIdx();
ml_train_layer_type_e ConvertLayerType(training_engine_layer_type_e type);
ml_train_optimizer_type_e ConvertOptimizerType(training_engine_optimizer_type_e type);
std::vector<dataset_type> _train_data_sets;
int SetDataset(training_engine_model *model, const training_engine_dataset *dataset) final;
int CompileModel(const training_engine_model *model, training_engine_compile_property &property) final;
int TrainModel(const training_engine_model *model, training_engine_model_property &property) final;
-
- std::vector<dataset_type>& GetTrainDataSet(void)
- {
- return _train_data_sets;
- }
-
- std::vector<dataset_type>& GetVerifyDataSet(void)
- {
- return _verify_data_sets;
- }
-
- int GetTrainDataCnt(void)
- {
- return static_cast<int>(_train_data_sets.size());
- }
-
- int GetVerifyDataCnt(void)
- {
- return static_cast<int>(_verify_data_sets.size());
- }
-
- void IncreaseTrainDataIdx();
- void IncreaseVerifyDataIdx();
- int GetTrainDataIdx();
- int GetVerifyDataIdx();
- void ClearTrainDataIdx();
- void ClearVerifyDataIdx();
- bool IsLastTrainDataIdx();
- bool IsLastVerifyDataIdx();
+ int UpdateTrainDataSet(float **data, float **label, bool *last);
+ int UpdateVerifyDataSet(float **data, float **label, bool *last);
};
} /* NntrainerImpl */
} /* TrainingEngineImpl */