From d69e0aaeb80f4f2f48d609468f4be29e204c7b34 Mon Sep 17 00:00:00 2001 From: "Piotr Kosko/Tizen API (PLT) /SRPOL/Engineer/Samsung Electronics" Date: Mon, 21 Feb 2022 12:56:59 +0100 Subject: [PATCH] [ML] Train - added load() method implementation [ACR] https://code.sec.samsung.net/jira/browse/TWDAPI-285 [Verification] Code compiles without errors. Verified in chrome console - load() works. var m2 = tizen.ml.trainer.createModel() m2.load("documents/ttt_INI_WITH_BIN.ini", "FORMAT_INI_WITH_BIN") Change-Id: Ic9d248790814dee47f5d0b712fe15e59ee8b93b9 --- src/ml/js/ml_trainer.js | 40 +++++++++++++++++++++++++++++++++++ src/ml/ml_instance.cc | 27 ++++++++++++++++++++++++ src/ml/ml_instance.h | 1 + src/ml/ml_trainer_manager.cc | 41 +++++++++++++++++++++++++++++++++++- src/ml/ml_trainer_manager.h | 5 +++-- 5 files changed, 111 insertions(+), 3 deletions(-) diff --git a/src/ml/js/ml_trainer.js b/src/ml/js/ml_trainer.js index 81e8d837..8354becb 100755 --- a/src/ml/js/ml_trainer.js +++ b/src/ml/js/ml_trainer.js @@ -516,6 +516,46 @@ Model.prototype.saveToFile = function () { } }; +Model.prototype.load = function () { + var args = validator_.validateArgs(arguments, [ + { + name: 'path', + type: types_.STRING + }, + { + name: 'format', + type: types_.ENUM, + values: Object.values(SaveFormat) + } + ]); + + var callArgs = { + id: this._id, + savePath: args.path, + saveFormat: args.format + } + + try { + callArgs.savePath = tizen.filesystem.toURI(args.path); + } catch (e) { + throw new WebAPIException(WebAPIException.InvalidValuesError, 'Invalid file path given'); + } + + if (!tizen.filesystem.pathExists(callArgs.savePath)) { + throw new WebAPIException(WebAPIException.NotFoundError, 'Path not found'); + } + + var result = native_.callSync('MLTrainerModelLoad', callArgs); + + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + ValidModelSaveExceptions, + AbortError + ); + } +}; + Model.prototype.addLayer = function() { var args = validator_.validateArgs(arguments, [ { diff --git a/src/ml/ml_instance.cc b/src/ml/ml_instance.cc index 9a4d4473..6e256d79 100644 --- a/src/ml/ml_instance.cc +++ b/src/ml/ml_instance.cc @@ -197,6 +197,7 @@ MlInstance::MlInstance() REGISTER_METHOD(MLTrainerModelSummarize); REGISTER_METHOD(MLTrainerModelCheckMetrics); REGISTER_METHOD(MLTrainerModelSave); + REGISTER_METHOD(MLTrainerModelLoad); REGISTER_METHOD(MLTrainerModelSetDataset); REGISTER_METHOD(MLTrainerModelSetOptimizer); REGISTER_METHOD(MLTrainerModelDispose); @@ -2026,6 +2027,32 @@ void MlInstance::MLTrainerModelSave(const picojson::value& args, ReportSuccess(out); } +void MlInstance::MLTrainerModelLoad(const picojson::value& args, + picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + CHECK_ARGS(args, kId, double, out); + CHECK_ARGS(args, kSavePath, std::string, out); + CHECK_ARGS(args, kSaveFormat, std::string, out); + + auto id = static_cast(args.get(kId).get()); + auto path = args.get(kSavePath).get(); + + ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN; + PlatformResult result = types::ModelSaveFormatEnum.getValue( + args.get(kSaveFormat).get(), &model_format); + if (!result) { + LogAndReportError(result, &out); + return; + } + + result = trainer_manager_.ModelLoad(id, path, model_format); + if (!result) { + ReportError(result, &out); + return; + } + ReportSuccess(out); +} + void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) { ScopeLogger("args: %s", args.serialize().c_str()); CHECK_ARGS(args, kId, double, out); diff --git a/src/ml/ml_instance.h b/src/ml/ml_instance.h index fbf50789..5fed8fa8 100644 --- a/src/ml/ml_instance.h +++ b/src/ml/ml_instance.h @@ -171,6 +171,7 @@ class MlInstance : public common::ParsedInstance { void MLTrainerModelCheckMetrics(const picojson::value& args, picojson::object& out); void MLTrainerModelSave(const picojson::value& args, picojson::object& out); + void MLTrainerModelLoad(const picojson::value& args, picojson::object& out); void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out); void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out); void MLTrainerModelDispose(const picojson::value& args, diff --git a/src/ml/ml_trainer_manager.cc b/src/ml/ml_trainer_manager.cc index 86d852b2..5afe64dd 100644 --- a/src/ml/ml_trainer_manager.cc +++ b/src/ml/ml_trainer_manager.cc @@ -468,7 +468,46 @@ PlatformResult TrainerManager::ModelSave(int id, const std::string& path, model->instanceLock.unlock(); if (ret_val != ML_ERROR_NONE) { - LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val)); + LoggerE("Could not save model to file: %d (%s)", ret_val, + ml_strerror(ret_val)); + return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val)); + } + + return PlatformResult(); +} + +PlatformResult TrainerManager::ModelLoad(int id, const std::string& path, + ml_train_model_format_e format) { + ScopeLogger(); + + if (models_.find(id) == models_.end()) { + LoggerE("Could not find model with id: %d", id); + return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model"); + } + + auto& model = models_[id]; + bool available = model->instanceLock.try_lock(); + if (!available) { + LoggerE("Model locked - probaly training in progress"); + return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR, + "Model training in progress - cannot save"); + } + + auto tmpString = path; + if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) { + // remove 'file://' prefix from path before passing to native api + tmpString.erase(0, FILE_PATH_PREFIX.length()); + } + + LoggerI("Loading model from file: %s", tmpString.c_str()); + int ret_val = + ml_train_model_load(model->getNative(), tmpString.c_str(), format); + + model->instanceLock.unlock(); + + if (ret_val != ML_ERROR_NONE) { + LoggerE("Could not load model from file: %d (%s)", ret_val, + ml_strerror(ret_val)); return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val)); } diff --git a/src/ml/ml_trainer_manager.h b/src/ml/ml_trainer_manager.h index 009715ba..4e2b100a 100644 --- a/src/ml/ml_trainer_manager.h +++ b/src/ml/ml_trainer_manager.h @@ -49,8 +49,9 @@ class TrainerManager { std::string& summary); PlatformResult CheckMetrics(int id, double train_loss, double valid_loss, double valid_accuracy, bool* result); - PlatformResult ModelSave(int id, - const std::string& path, + PlatformResult ModelSave(int id, const std::string& path, + ml_train_model_format_e format); + PlatformResult ModelLoad(int id, const std::string& path, ml_train_model_format_e format); PlatformResult ModelDispose(int id); -- 2.34.1