From c046d6460f394de2b23e967c929745cfdb193899 Mon Sep 17 00:00:00 2001
From: Marcin Kaminski
Date: Sat, 15 Jan 2022 14:00:35 +0100
Subject: [PATCH] [ML][Training] Model saving to file

Changes:
- implementation of Model.saveToFile() method for saving model on storage for re-use
- new enum for model saving selection
- minor change in dataset creation (path prefix removal)

Behaviour of Model.saveToFile(path, format):
- 'format' must be one of the SaveFormat values (FORMAT_BIN, FORMAT_INI,
  FORMAT_INI_WITH_BIN); it is mapped to ml_train_model_format_e through
  types::ModelSaveFormatEnum on the native side
- 'path' is resolved with tizen.filesystem.toURI(); when that call throws,
  an invalid-values WebAPIException is thrown instead
- when the resolved path already exists, NO_MODIFICATION_ALLOWED_ERR is
  thrown and the native layer is not called (no overwriting)
- TrainerManager::ModelSave() strips a leading FILE_PATH_PREFIX from the
  path and calls ml_train_model_save(); an unknown model id or a non-zero
  return value is reported as ABORT_ERR (the latter with the ml_strerror()
  text)

Change-Id: Ib71d21b2d4a61e2ff2ed8e8b4f983acaeaa02408
---
Review notes (not part of the commit message):
- use WebAPIException.INVALID_VALUES_ERR, matching the constant style of
  WebAPIException.NO_MODIFICATION_ALLOWED_ERR used in the same function
- fix log message typo: "Could not model to file" -> "Could not save model
  to file"
- terminate the SaveFormat and callArgs object literals with ';'

 src/ml/js/ml_trainer.js      | 54 +++++++++++++++++++++++++++++++++++-
 src/ml/ml_instance.cc        | 29 +++++++++++++++++++
 src/ml/ml_instance.h         |  1 +
 src/ml/ml_trainer_manager.cc | 41 +++++++++++++++++++++++----
 src/ml/ml_trainer_manager.h  |  3 ++
 src/ml/ml_utils.cc           |  5 ++++
 src/ml/ml_utils.h            |  1 +
 7 files changed, 127 insertions(+), 7 deletions(-)

diff --git a/src/ml/js/ml_trainer.js b/src/ml/js/ml_trainer.js
index 00a470f0..4dcc99da 100755
--- a/src/ml/js/ml_trainer.js
+++ b/src/ml/js/ml_trainer.js
@@ -54,6 +54,12 @@ var VerbosityLevel = {
     SUMMARY_TENSOR: 'SUMMARY_TENSOR'
 };
 
+var SaveFormat = {
+    FORMAT_BIN: 'FORMAT_BIN',
+    FORMAT_INI: 'FORMAT_INI',
+    FORMAT_INI_WITH_BIN: 'FORMAT_INI_WITH_BIN'
+};
+
 var Layer = function(id, type) {
     Object.defineProperties(this, {
         name: {
@@ -336,6 +342,11 @@ Model.prototype.run = function() {
     }
 };
 
+var ValidBasicExceptions = [
+    'TypeMismatchError',
+    'AbortError'
+];
+
 Model.prototype.summarize = function() {
     var args = validator_.validateArgs(arguments, [
         {
@@ -365,11 +376,52 @@ Model.prototype.summarize = function() {
     return result.summary
 };
 
-var ValidBasicExceptions = [
+var ValidModelSaveExceptions = [
+    'InvalidValuesError',
     'TypeMismatchError',
     'AbortError'
 ];
 
+Model.prototype.saveToFile = function () {
+    var args = validator_.validateArgs(arguments, [
+        {
+            name: 'path',
+            type: types_.STRING
+        },
+        {
+            name: 'format',
+            type: types_.ENUM,
+            values: Object.values(SaveFormat)
+        }
+    ]);
+
+    var callArgs = {
+        id: this._id,
+        savePath: args.path,
+        saveFormat: args.format
+    };
+
+    try {
+        callArgs.savePath = tizen.filesystem.toURI(args.path);
+    } catch (e) {
+        throw new WebAPIException(WebAPIException.INVALID_VALUES_ERR, 'Invalid file path given');
+    }
+
+    if (tizen.filesystem.pathExists(callArgs.savePath)) {
+        throw new WebAPIException(WebAPIException.NO_MODIFICATION_ALLOWED_ERR, 'Path already exists - overwriting is not allowed');
+    }
+
+    var result = native_.callSync('MLTrainerModelSave', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidModelSaveExceptions,
+            AbortError
+        );
+    }
+};
+
 Model.prototype.addLayer = function() {
     var args = validator_.validateArgs(arguments, [
         {
diff --git a/src/ml/ml_instance.cc b/src/ml/ml_instance.cc
index db9cc67d..5d1dfbda 100644
--- a/src/ml/ml_instance.cc
+++ b/src/ml/ml_instance.cc
@@ -83,6 +83,8 @@ const std::string kDatasetId = "datasetId";
 const std::string kOptimizerId = "optimizerId";
 const std::string kLevel = "level";
 const std::string kSummary = "summary";
+const std::string kSavePath = "savePath";
+const std::string kSaveFormat = "saveFormat";
 }  // namespace
 
 using namespace common;
@@ -188,6 +190,7 @@ MlInstance::MlInstance()
   REGISTER_METHOD(MLTrainerModelAddLayer);
   REGISTER_METHOD(MLTrainerModelRun);
   REGISTER_METHOD(MLTrainerModelSummarize);
+  REGISTER_METHOD(MLTrainerModelSave);
   REGISTER_METHOD(MLTrainerModelSetDataset);
   REGISTER_METHOD(MLTrainerModelSetOptimizer);
   REGISTER_METHOD(MLTrainerDatasetCreateGenerator);
@@ -1914,6 +1917,32 @@ void MlInstance::MLTrainerModelSummarize(const picojson::value& args, picojson::
   ReportSuccess(out);
 }
 
+void MlInstance::MLTrainerModelSave(const picojson::value& args,
+                                    picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kSavePath, std::string, out);
+  CHECK_ARGS(args, kSaveFormat, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto path = args.get(kSavePath).get<std::string>();
+
+  ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN;
+  PlatformResult result = types::ModelSaveFormatEnum.getValue(
+      args.get(kSaveFormat).get<std::string>(), &model_format);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.ModelSave(id, path, model_format);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
+}
+
 void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
   ScopeLogger("args: %s", args.serialize().c_str());
   CHECK_ARGS(args, kId, double, out);
diff --git a/src/ml/ml_instance.h b/src/ml/ml_instance.h
index da95fcd0..69bc0706 100644
--- a/src/ml/ml_instance.h
+++ b/src/ml/ml_instance.h
@@ -162,6 +162,7 @@ class MlInstance : public common::ParsedInstance {
   void MLTrainerModelAddLayer(const picojson::value& args, picojson::object& out);
   void MLTrainerModelRun(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSummarize(const picojson::value& args, picojson::object& out);
+  void MLTrainerModelSave(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out);
 
diff --git a/src/ml/ml_trainer_manager.cc b/src/ml/ml_trainer_manager.cc
index dbef05f7..071ccb3a 100644
--- a/src/ml/ml_trainer_manager.cc
+++ b/src/ml/ml_trainer_manager.cc
@@ -271,6 +271,35 @@ PlatformResult TrainerManager::ModelSummarize(int id,
   return PlatformResult();
 }
 
+PlatformResult TrainerManager::ModelSave(int id,
+                                         const std::string& path,
+                                         ml_train_model_format_e format) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  auto& model = models_[id];
+
+  auto tmpString = path;
+  if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
+    // remove 'file://' prefix from path before passing to native api
+    tmpString.erase(0, FILE_PATH_PREFIX.length());
+  }
+
+  LoggerI("Saving model to file: %s", tmpString.c_str());
+  int ret_val = ml_train_model_save(model, tmpString.c_str(), format);
+
+  if (ret_val != 0) {
+    LoggerE("Could not save model to file: %d (%s)", ret_val, ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+  }
+
+  return PlatformResult();
+}
+
 PlatformResult TrainerManager::CreateLayer(int& id,
                                            ml_train_layer_type_e type) {
   ScopeLogger();
@@ -363,9 +392,9 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
   if (!train_file.empty()) {
     auto tmpString = train_file;
-    if (tmpString.substr(0, 7) == "file://") {
+    if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
       // remove 'file://' prefix from path before passing to native api
-      tmpString.erase(0, 7);
+      tmpString.erase(0, FILE_PATH_PREFIX.length());
     }
 
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TRAIN,
@@ -380,9 +409,9 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
   if (!valid_file.empty()) {
     auto tmpString = valid_file;
-    if (tmpString.substr(0, 7) == "file://") {
+    if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
       // remove 'file://' prefix from path before passing to native api
-      tmpString.erase(0, 7);
+      tmpString.erase(0, FILE_PATH_PREFIX.length());
     }
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_VALID,
                                         tmpString.c_str());
@@ -396,9 +425,9 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
   if (!test_file.empty()) {
     auto tmpString = test_file;
-    if (tmpString.substr(0, 7) == "file://") {
+    if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
       // remove 'file://' prefix from path before passing to native api
-      tmpString.erase(0, 7);
+      tmpString.erase(0, FILE_PATH_PREFIX.length());
     }
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TEST,
                                         tmpString.c_str());
diff --git a/src/ml/ml_trainer_manager.h b/src/ml/ml_trainer_manager.h
index 3b154dd1..72f800b4 100644
--- a/src/ml/ml_trainer_manager.h
+++ b/src/ml/ml_trainer_manager.h
@@ -43,6 +43,9 @@ class TrainerManager {
   PlatformResult ModelSetDataset(int id, int datasetId);
   PlatformResult ModelSummarize(int id, ml_train_summary_type_e level,
                                 std::string& summary);
+  PlatformResult ModelSave(int id,
+                           const std::string& path,
+                           ml_train_model_format_e format);
 
   PlatformResult CreateLayer(int& id, ml_train_layer_type_e type);
   PlatformResult LayerSetProperty(int id, const std::string& name,
diff --git a/src/ml/ml_utils.cc b/src/ml/ml_utils.cc
index daa361c6..4b0734cf 100644
--- a/src/ml/ml_utils.cc
+++ b/src/ml/ml_utils.cc
@@ -95,6 +95,11 @@ const PlatformEnum SummaryTypeEnum{
     {"SUMMARY_LAYER", ML_TRAIN_SUMMARY_LAYER},
     {"SUMMARY_TENSOR", ML_TRAIN_SUMMARY_TENSOR}};
 
+const PlatformEnum<ml_train_model_format_e> ModelSaveFormatEnum{
+    {"FORMAT_BIN", ML_TRAIN_MODEL_FORMAT_BIN},
+    {"FORMAT_INI", ML_TRAIN_MODEL_FORMAT_INI},
+    {"FORMAT_INI_WITH_BIN", ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN}};
+
 }  // namespace types
 
 namespace util {
diff --git a/src/ml/ml_utils.h b/src/ml/ml_utils.h
index ccfeb162..6569623a 100644
--- a/src/ml/ml_utils.h
+++ b/src/ml/ml_utils.h
@@ -48,6 +48,7 @@ extern const PlatformEnum TensorTypeEnum;
 extern const PlatformEnum<ml_train_optimizer_type_e> OptimizerTypeEnum;
 extern const PlatformEnum<ml_train_layer_type_e> LayerTypeEnum;
 extern const PlatformEnum<ml_train_summary_type_e> SummaryTypeEnum;
+extern const PlatformEnum<ml_train_model_format_e> ModelSaveFormatEnum;
 
 }  // namespace types
 
-- 
2.34.1