From 1cfb47ab166bf2eac83c36b5313d0a079a55a4a8 Mon Sep 17 00:00:00 2001 From: Marcin Kaminski Date: Fri, 21 Jan 2022 19:44:41 +0100 Subject: [PATCH] [ML][Training] Dataset.setProperty() refactoring Changes: - enum DatasetMode added - Dataset.setProperty() expanded with 'mode' param to match C API design Change-Id: Ia619a69d6367c2624fff9d402e6a3505ce33f14b --- src/ml/js/ml_trainer.js | 20 ++++++++++++++++---- src/ml/ml_instance.cc | 12 +++++++++++- src/ml/ml_trainer_manager.cc | 36 +++++++++--------------------------- src/ml/ml_trainer_manager.h | 6 ++++-- src/ml/ml_utils.cc | 5 +++++ src/ml/ml_utils.h | 1 + 6 files changed, 46 insertions(+), 34 deletions(-) diff --git a/src/ml/js/ml_trainer.js b/src/ml/js/ml_trainer.js index 643a5f6f..4c3a8412 100755 --- a/src/ml/js/ml_trainer.js +++ b/src/ml/js/ml_trainer.js @@ -58,7 +58,13 @@ var SaveFormat = { FORMAT_BIN: 'FORMAT_BIN', FORMAT_INI: 'FORMAT_INI', FORMAT_INI_WITH_BIN: 'FORMAT_INI_WITH_BIN' -} +}; + +var DatasetMode = { + MODE_TRAIN: 'MODE_TRAIN', + MODE_VALID: 'MODE_VALID', + MODE_TEST: 'MODE_TEST' +}; var Layer = function(id, type) { Object.defineProperties(this, { @@ -215,20 +221,26 @@ Dataset.prototype.setProperty = function() { { name: 'value', type: types_.STRING + }, + { + name: 'mode', + type: types_.ENUM, + values: Object.values(DatasetMode) } ]); - if (!args.has.name || !args.has.value) { + if (!args.has.name || !args.has.value || !args.has.mode) { throw new WebAPIException( WebAPIException.TYPE_MISMATCH_ERR, - 'Invalid parameter: ' + (args.has.name ? 'value' : 'name') + ' is undefined' + 'Invalid parameter: name, value and mode have to be defined' ); } var callArgs = { id: this._id, name: args.name, - value: args.value + value: args.value, + mode: args.mode }; var result = native_.callSync('MLTrainerDatasetSetProperty', callArgs); diff --git a/src/ml/ml_instance.cc b/src/ml/ml_instance.cc index d7a1ba3f..250eced0 100644 --- a/src/ml/ml_instance.cc +++ b/src/ml/ml_instance.cc @@ -85,6 +85,7 @@ const std::string kLevel = "level"; const std::string kSummary = "summary"; const std::string kSavePath = "savePath"; const std::string kSaveFormat = "saveFormat"; +const std::string kMode = "mode"; } // namespace using namespace common; @@ -2058,12 +2059,21 @@ void MlInstance::MLTrainerDatasetSetProperty(const picojson::value& args, picojs CHECK_ARGS(args, kId, double, out); CHECK_ARGS(args, kName, std::string, out); CHECK_ARGS(args, kValue, std::string, out); + CHECK_ARGS(args, kMode, std::string, out); auto id = static_cast(args.get(kId).get()); auto name = args.get(kName).get(); auto value = args.get(kValue).get(); - PlatformResult result = trainer_manager_.DatasetSetProperty(id, name, value); + ml_train_dataset_mode_e datasetMode = ML_TRAIN_DATASET_MODE_TRAIN; + PlatformResult result = types::DatasetModeEnum.getValue( + args.get(kMode).get(), &datasetMode); + if (!result) { + LogAndReportError(result, &out); + return; + } + + result = trainer_manager_.DatasetSetProperty(id, name, value, datasetMode); if (!result) { ReportError(result, &out); return; diff --git a/src/ml/ml_trainer_manager.cc b/src/ml/ml_trainer_manager.cc index 01043f99..00795d51 100644 --- a/src/ml/ml_trainer_manager.cc +++ b/src/ml/ml_trainer_manager.cc @@ -612,9 +612,11 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai // MK-TODO Add creating Dataset with generator -PlatformResult TrainerManager::DatasetSetProperty(int id, - const std::string& name, - const std::string& value) { +PlatformResult TrainerManager::DatasetSetProperty( + int id, + const std::string& name, + const std::string& value, + ml_train_dataset_mode_e mode) { ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str()); if (datasets_.find(id) == datasets_.end()) { @@ -625,32 +627,12 @@ PlatformResult TrainerManager::DatasetSetProperty(int id, auto dataset = datasets_[id]; std::string opt = name + "=" + value; - // ml_train_dataset_set_property() is marked as deprecated - // temporary set same property for all modes (all data files) if possible - int ret_val = ml_train_dataset_set_property_for_mode( - dataset->getNative(), ML_TRAIN_DATASET_MODE_TRAIN, opt.c_str(), NULL); + int ret_val = ml_train_dataset_set_property_for_mode(dataset->getNative(), + mode, opt.c_str(), NULL); if (ret_val != ML_ERROR_NONE) { - LoggerE("Could not set dataset property for train mode: %d (%s)", ret_val, - ml_strerror(ret_val)); - return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val)); - } - - ret_val = ml_train_dataset_set_property_for_mode( - dataset->getNative(), ML_TRAIN_DATASET_MODE_VALID, opt.c_str(), NULL); - if (ret_val != ML_ERROR_NONE) { - LoggerE("Could not set dataset property for validation mode: %d (%s)", + LoggerE("Could not set dataset property for mode %d: %d (%s)", mode, ret_val, ml_strerror(ret_val)); - // MK-TODO report error for each file when extracted to separate functions - // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val)); - } - - ret_val = ml_train_dataset_set_property_for_mode( - dataset->getNative(), ML_TRAIN_DATASET_MODE_TEST, opt.c_str(), NULL); - if (ret_val != ML_ERROR_NONE) { - LoggerE("Could not set dataset property for test mode: %d (%s)", ret_val, - ml_strerror(ret_val)); - // MK-TODO report error for each file when extracted to separate functions - // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val)); + return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val)); } return PlatformResult(); diff --git a/src/ml/ml_trainer_manager.h b/src/ml/ml_trainer_manager.h index 15de4f8b..4e38c5a6 100644 --- a/src/ml/ml_trainer_manager.h +++ b/src/ml/ml_trainer_manager.h @@ -62,8 +62,10 @@ class TrainerManager { PlatformResult CreateFileDataset(int& id, const std::string train_file, const std::string valid_file, const std::string test_file); - PlatformResult DatasetSetProperty(int id, const std::string& name, - const std::string& value); + PlatformResult DatasetSetProperty(int id, + const std::string& name, + const std::string& value, + ml_train_dataset_mode_e mode); PlatformResult DatasetDispose(int id); private: diff --git a/src/ml/ml_utils.cc b/src/ml/ml_utils.cc index 4b0734cf..92fea3f7 100644 --- a/src/ml/ml_utils.cc +++ b/src/ml/ml_utils.cc @@ -100,6 +100,11 @@ const PlatformEnum ModelSaveFormatEnum{ {"FORMAT_INI", ML_TRAIN_MODEL_FORMAT_INI}, {"FORMAT_INI_WITH_BIN", ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN}}; +const PlatformEnum DatasetModeEnum{ + {"MODE_TRAIN", ML_TRAIN_DATASET_MODE_TRAIN}, + {"MODE_VALID", ML_TRAIN_DATASET_MODE_VALID}, + {"MODE_TEST", ML_TRAIN_DATASET_MODE_TEST}}; + } // namespace types namespace util { diff --git a/src/ml/ml_utils.h b/src/ml/ml_utils.h index 6569623a..37cd89cf 100644 --- a/src/ml/ml_utils.h +++ b/src/ml/ml_utils.h @@ -49,6 +49,7 @@ extern const PlatformEnum OptimizerTypeEnum; extern const PlatformEnum LayerTypeEnum; extern const PlatformEnum SummaryTypeEnum; extern const PlatformEnum ModelSaveFormatEnum; +extern const PlatformEnum DatasetModeEnum; } // namespace types -- 2.34.1