[ML][Training] Model saving to file 78/269478/6
authorMarcin Kaminski <marcin.ka@partner.samsung.com>
Sat, 15 Jan 2022 13:00:35 +0000 (14:00 +0100)
committerMarcin Kaminski <marcin.ka@partner.samsung.com>
Tue, 8 Feb 2022 19:05:31 +0000 (20:05 +0100)
Changes:
- implementation of Model.saveToFile() method for saving model
on storage for re-use
- new enum for model saving selection
- minor change in dataset creation (path prefix removal)

Change-Id: Ib71d21b2d4a61e2ff2ed8e8b4f983acaeaa02408

src/ml/js/ml_trainer.js
src/ml/ml_instance.cc
src/ml/ml_instance.h
src/ml/ml_trainer_manager.cc
src/ml/ml_trainer_manager.h
src/ml/ml_utils.cc
src/ml/ml_utils.h

index 00a470f..4dcc99d 100755 (executable)
@@ -54,6 +54,12 @@ var VerbosityLevel = {
     SUMMARY_TENSOR: 'SUMMARY_TENSOR'
 };
 
+var SaveFormat = {
+    FORMAT_BIN: 'FORMAT_BIN',
+    FORMAT_INI: 'FORMAT_INI',
+    FORMAT_INI_WITH_BIN: 'FORMAT_INI_WITH_BIN'
+}
+
 var Layer = function(id, type) {
     Object.defineProperties(this, {
         name: {
@@ -336,6 +342,11 @@ Model.prototype.run = function() {
     }
 };
 
+var ValidBasicExceptions = [
+    'TypeMismatchError',
+    'AbortError'
+];
+
 Model.prototype.summarize = function() {
     var args = validator_.validateArgs(arguments, [
         {
@@ -365,11 +376,52 @@ Model.prototype.summarize = function() {
     return result.summary
 };
 
-var ValidBasicExceptions = [
+var ValidModelSaveExceptions = [
+    'InvalidValuesError',
     'TypeMismatchError',
     'AbortError'
 ];
 
+Model.prototype.saveToFile = function () {
+    var args = validator_.validateArgs(arguments, [
+        {
+            name: 'path',
+            type: types_.STRING
+        },
+        {
+            name: 'format',
+            type: types_.ENUM,
+            values: Object.values(SaveFormat)
+        }
+    ]);
+
+    var callArgs = {
+        id: this._id,
+        savePath: args.path,
+        saveFormat: args.format
+    }
+
+    try {
+        callArgs.savePath = tizen.filesystem.toURI(args.path);
+    } catch (e) {
+        throw new WebAPIException(WebAPIException.InvalidValuesError, 'Invalid file path given');
+    }
+
+    if (tizen.filesystem.pathExists(callArgs.savePath)) {
+        throw new WebAPIException(WebAPIException.NO_MODIFICATION_ALLOWED_ERR, 'Path already exists - overwriting is not allowed');
+    }
+
+    var result = native_.callSync('MLTrainerModelSave', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidModelSaveExceptions,
+            AbortError
+        );
+    }
+};
+
 Model.prototype.addLayer = function() {
     var args = validator_.validateArgs(arguments, [
         {
index db9cc67..5d1dfbd 100644 (file)
@@ -83,6 +83,8 @@ const std::string kDatasetId = "datasetId";
 const std::string kOptimizerId = "optimizerId";
 const std::string kLevel = "level";
 const std::string kSummary = "summary";
+const std::string kSavePath = "savePath";
+const std::string kSaveFormat = "saveFormat";
 }  //  namespace
 
 using namespace common;
@@ -188,6 +190,7 @@ MlInstance::MlInstance()
   REGISTER_METHOD(MLTrainerModelAddLayer);
   REGISTER_METHOD(MLTrainerModelRun);
   REGISTER_METHOD(MLTrainerModelSummarize);
+  REGISTER_METHOD(MLTrainerModelSave);
   REGISTER_METHOD(MLTrainerModelSetDataset);
   REGISTER_METHOD(MLTrainerModelSetOptimizer);
   REGISTER_METHOD(MLTrainerDatasetCreateGenerator);
@@ -1914,6 +1917,32 @@ void MlInstance::MLTrainerModelSummarize(const picojson::value& args, picojson::
   ReportSuccess(out);
 }
 
+void MlInstance::MLTrainerModelSave(const picojson::value& args,
+                                    picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kSavePath, std::string, out);
+  CHECK_ARGS(args, kSaveFormat, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto path = args.get(kSavePath).get<std::string>();
+
+  ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN;
+  PlatformResult result = types::ModelSaveFormatEnum.getValue(
+      args.get(kSaveFormat).get<std::string>(), &model_format);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.ModelSave(id, path, model_format);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
+}
+
 void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
   ScopeLogger("args: %s", args.serialize().c_str());
   CHECK_ARGS(args, kId, double, out);
index da95fcd..69bc070 100644 (file)
@@ -162,6 +162,7 @@ class MlInstance : public common::ParsedInstance {
   void MLTrainerModelAddLayer(const picojson::value& args, picojson::object& out);
   void MLTrainerModelRun(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSummarize(const picojson::value& args, picojson::object& out);
+  void MLTrainerModelSave(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out);
 
index dbef05f..071ccb3 100644 (file)
@@ -271,6 +271,35 @@ PlatformResult TrainerManager::ModelSummarize(int id,
   return PlatformResult();
 }
 
+PlatformResult TrainerManager::ModelSave(int id,
+                                         const std::string& path,
+                                         ml_train_model_format_e format) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  auto& model = models_[id];
+
+  auto tmpString = path;
+  if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
+    // remove 'file://' prefix from path before passing to native api
+    tmpString.erase(0, FILE_PATH_PREFIX.length());
+  }
+
+  LoggerI("Saving model to file: %s", tmpString.c_str());
+  int ret_val = ml_train_model_save(model, tmpString.c_str(), format);
+
+  if (ret_val != 0) {
+    LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+  }
+
+  return PlatformResult();
+}
+
 PlatformResult TrainerManager::CreateLayer(int& id,
                                            ml_train_layer_type_e type) {
   ScopeLogger();
@@ -363,9 +392,9 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
   if (!train_file.empty()) {
     auto tmpString = train_file;
-    if (tmpString.substr(0, 7) == "file://") {
+    if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
       // remove 'file://' prefix from path before passing to native api
-      tmpString.erase(0, 7);
+      tmpString.erase(0, FILE_PATH_PREFIX.length());
     }
 
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TRAIN,
@@ -380,9 +409,9 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
   if (!valid_file.empty()) {
     auto tmpString = valid_file;
-    if (tmpString.substr(0, 7) == "file://") {
+    if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
       // remove 'file://' prefix from path before passing to native api
-      tmpString.erase(0, 7);
+      tmpString.erase(0, FILE_PATH_PREFIX.length());
     }
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_VALID,
                                         tmpString.c_str());
@@ -396,9 +425,9 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
   if (!test_file.empty()) {
     auto tmpString = test_file;
-    if (tmpString.substr(0, 7) == "file://") {
+    if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
       // remove 'file://' prefix from path before passing to native api
-      tmpString.erase(0, 7);
+      tmpString.erase(0, FILE_PATH_PREFIX.length());
     }
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TEST,
                                         tmpString.c_str());
index 3b154dd..72f800b 100644 (file)
@@ -43,6 +43,9 @@ class TrainerManager {
   PlatformResult ModelSetDataset(int id, int datasetId);
   PlatformResult ModelSummarize(int id, ml_train_summary_type_e level,
                                 std::string& summary);
+  PlatformResult ModelSave(int id,
+                           const std::string& path,
+                           ml_train_model_format_e format);
 
   PlatformResult CreateLayer(int& id, ml_train_layer_type_e type);
   PlatformResult LayerSetProperty(int id, const std::string& name,
index daa361c..4b0734c 100644 (file)
@@ -95,6 +95,11 @@ const PlatformEnum<ml_train_summary_type_e> SummaryTypeEnum{
     {"SUMMARY_LAYER", ML_TRAIN_SUMMARY_LAYER},
     {"SUMMARY_TENSOR", ML_TRAIN_SUMMARY_TENSOR}};
 
+const PlatformEnum<ml_train_model_format_e> ModelSaveFormatEnum{
+    {"FORMAT_BIN", ML_TRAIN_MODEL_FORMAT_BIN},
+    {"FORMAT_INI", ML_TRAIN_MODEL_FORMAT_INI},
+    {"FORMAT_INI_WITH_BIN", ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN}};
+
 }  // namespace types
 
 namespace util {
index ccfeb16..6569623 100644 (file)
@@ -48,6 +48,7 @@ extern const PlatformEnum<ml_tensor_type_e> TensorTypeEnum;
 extern const PlatformEnum<ml_train_optimizer_type_e> OptimizerTypeEnum;
 extern const PlatformEnum<ml_train_layer_type_e> LayerTypeEnum;
 extern const PlatformEnum<ml_train_summary_type_e> SummaryTypeEnum;
+extern const PlatformEnum<ml_train_model_format_e> ModelSaveFormatEnum;
 
 }  // namespace types