From d69e0aaeb80f4f2f48d609468f4be29e204c7b34 Mon Sep 17 00:00:00 2001
From: "Piotr Kosko/Tizen API (PLT) /SRPOL/Engineer/Samsung Electronics"
Date: Mon, 21 Feb 2022 12:56:59 +0100
Subject: [PATCH] [ML] Train - added load() method implementation
[ACR] https://code.sec.samsung.net/jira/browse/TWDAPI-285
[Verification] Code compiles without errors.
Verified in chrome console - load() works.
var m2 = tizen.ml.trainer.createModel()
m2.load("documents/ttt_INI_WITH_BIN.ini", "FORMAT_INI_WITH_BIN")
Change-Id: Ic9d248790814dee47f5d0b712fe15e59ee8b93b9
---
src/ml/js/ml_trainer.js | 40 +++++++++++++++++++++++++++++++++++
src/ml/ml_instance.cc | 27 ++++++++++++++++++++++++
src/ml/ml_instance.h | 1 +
src/ml/ml_trainer_manager.cc | 41 +++++++++++++++++++++++++++++++++++-
src/ml/ml_trainer_manager.h | 5 +++--
5 files changed, 111 insertions(+), 3 deletions(-)
diff --git a/src/ml/js/ml_trainer.js b/src/ml/js/ml_trainer.js
index 81e8d837..8354becb 100755
--- a/src/ml/js/ml_trainer.js
+++ b/src/ml/js/ml_trainer.js
@@ -516,6 +516,46 @@ Model.prototype.saveToFile = function () {
}
};
+Model.prototype.load = function () {
+ var args = validator_.validateArgs(arguments, [
+ {
+ name: 'path',
+ type: types_.STRING
+ },
+ {
+ name: 'format',
+ type: types_.ENUM,
+ values: Object.values(SaveFormat)
+ }
+ ]);
+
+ var callArgs = {
+ id: this._id,
+ savePath: args.path,
+ saveFormat: args.format
+ }
+
+ try {
+ callArgs.savePath = tizen.filesystem.toURI(args.path);
+ } catch (e) {
+ throw new WebAPIException(WebAPIException.InvalidValuesError, 'Invalid file path given');
+ }
+
+ if (!tizen.filesystem.pathExists(callArgs.savePath)) {
+ throw new WebAPIException(WebAPIException.NotFoundError, 'Path not found');
+ }
+
+ var result = native_.callSync('MLTrainerModelLoad', callArgs);
+
+ if (native_.isFailure(result)) {
+ throw native_.getErrorObjectAndValidate(
+ result,
+ ValidModelSaveExceptions,
+ AbortError
+ );
+ }
+};
+
Model.prototype.addLayer = function() {
var args = validator_.validateArgs(arguments, [
{
diff --git a/src/ml/ml_instance.cc b/src/ml/ml_instance.cc
index 9a4d4473..6e256d79 100644
--- a/src/ml/ml_instance.cc
+++ b/src/ml/ml_instance.cc
@@ -197,6 +197,7 @@ MlInstance::MlInstance()
REGISTER_METHOD(MLTrainerModelSummarize);
REGISTER_METHOD(MLTrainerModelCheckMetrics);
REGISTER_METHOD(MLTrainerModelSave);
+ REGISTER_METHOD(MLTrainerModelLoad);
REGISTER_METHOD(MLTrainerModelSetDataset);
REGISTER_METHOD(MLTrainerModelSetOptimizer);
REGISTER_METHOD(MLTrainerModelDispose);
@@ -2026,6 +2027,32 @@ void MlInstance::MLTrainerModelSave(const picojson::value& args,
ReportSuccess(out);
}
+void MlInstance::MLTrainerModelLoad(const picojson::value& args,
+ picojson::object& out) {
+ ScopeLogger("args: %s", args.serialize().c_str());
+ CHECK_ARGS(args, kId, double, out);
+ CHECK_ARGS(args, kSavePath, std::string, out);
+ CHECK_ARGS(args, kSaveFormat, std::string, out);
+
+ auto id = static_cast(args.get(kId).get());
+ auto path = args.get(kSavePath).get();
+
+ ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN;
+ PlatformResult result = types::ModelSaveFormatEnum.getValue(
+ args.get(kSaveFormat).get(), &model_format);
+ if (!result) {
+ LogAndReportError(result, &out);
+ return;
+ }
+
+ result = trainer_manager_.ModelLoad(id, path, model_format);
+ if (!result) {
+ ReportError(result, &out);
+ return;
+ }
+ ReportSuccess(out);
+}
+
void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
ScopeLogger("args: %s", args.serialize().c_str());
CHECK_ARGS(args, kId, double, out);
diff --git a/src/ml/ml_instance.h b/src/ml/ml_instance.h
index fbf50789..5fed8fa8 100644
--- a/src/ml/ml_instance.h
+++ b/src/ml/ml_instance.h
@@ -171,6 +171,7 @@ class MlInstance : public common::ParsedInstance {
void MLTrainerModelCheckMetrics(const picojson::value& args,
picojson::object& out);
void MLTrainerModelSave(const picojson::value& args, picojson::object& out);
+ void MLTrainerModelLoad(const picojson::value& args, picojson::object& out);
void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out);
void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out);
void MLTrainerModelDispose(const picojson::value& args,
diff --git a/src/ml/ml_trainer_manager.cc b/src/ml/ml_trainer_manager.cc
index 86d852b2..5afe64dd 100644
--- a/src/ml/ml_trainer_manager.cc
+++ b/src/ml/ml_trainer_manager.cc
@@ -468,7 +468,46 @@ PlatformResult TrainerManager::ModelSave(int id, const std::string& path,
model->instanceLock.unlock();
if (ret_val != ML_ERROR_NONE) {
- LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
+ LoggerE("Could not save model to file: %d (%s)", ret_val,
+ ml_strerror(ret_val));
+ return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+ }
+
+ return PlatformResult();
+}
+
+PlatformResult TrainerManager::ModelLoad(int id, const std::string& path,
+ ml_train_model_format_e format) {
+ ScopeLogger();
+
+ if (models_.find(id) == models_.end()) {
+ LoggerE("Could not find model with id: %d", id);
+ return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+ }
+
+ auto& model = models_[id];
+ bool available = model->instanceLock.try_lock();
+ if (!available) {
+ LoggerE("Model locked - probaly training in progress");
+ return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR,
+ "Model training in progress - cannot save");
+ }
+
+ auto tmpString = path;
+ if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
+ // remove 'file://' prefix from path before passing to native api
+ tmpString.erase(0, FILE_PATH_PREFIX.length());
+ }
+
+ LoggerI("Loading model from file: %s", tmpString.c_str());
+ int ret_val =
+ ml_train_model_load(model->getNative(), tmpString.c_str(), format);
+
+ model->instanceLock.unlock();
+
+ if (ret_val != ML_ERROR_NONE) {
+ LoggerE("Could not load model from file: %d (%s)", ret_val,
+ ml_strerror(ret_val));
return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
}
diff --git a/src/ml/ml_trainer_manager.h b/src/ml/ml_trainer_manager.h
index 009715ba..4e2b100a 100644
--- a/src/ml/ml_trainer_manager.h
+++ b/src/ml/ml_trainer_manager.h
@@ -49,8 +49,9 @@ class TrainerManager {
std::string& summary);
PlatformResult CheckMetrics(int id, double train_loss, double valid_loss,
double valid_accuracy, bool* result);
- PlatformResult ModelSave(int id,
- const std::string& path,
+ PlatformResult ModelSave(int id, const std::string& path,
+ ml_train_model_format_e format);
+ PlatformResult ModelLoad(int id, const std::string& path,
ml_train_model_format_e format);
PlatformResult ModelDispose(int id);
--
2.34.1