[ML] Train - added load() method implementation 58/271458/3
authorPiotr Kosko/Tizen API (PLT) /SRPOL/Engineer/Samsung Electronics <p.kosko@samsung.com>
Mon, 21 Feb 2022 11:56:59 +0000 (12:56 +0100)
committerPiotr Kosko/Tizen API (PLT) /SRPOL/Engineer/Samsung Electronics <p.kosko@samsung.com>
Thu, 24 Feb 2022 07:46:21 +0000 (08:46 +0100)
[ACR] https://code.sec.samsung.net/jira/browse/TWDAPI-285

[Verification] Code compiles without errors.
Verified in chrome console - load() works.
var m2 = tizen.ml.trainer.createModel()
m2.load("documents/ttt_INI_WITH_BIN.ini", "FORMAT_INI_WITH_BIN")

Change-Id: Ic9d248790814dee47f5d0b712fe15e59ee8b93b9

src/ml/js/ml_trainer.js
src/ml/ml_instance.cc
src/ml/ml_instance.h
src/ml/ml_trainer_manager.cc
src/ml/ml_trainer_manager.h

index 81e8d83..8354bec 100755 (executable)
@@ -516,6 +516,46 @@ Model.prototype.saveToFile = function () {
     }
 };
 
+Model.prototype.load = function () {
+    var args = validator_.validateArgs(arguments, [
+        {
+            name: 'path',
+            type: types_.STRING
+        },
+        {
+            name: 'format',
+            type: types_.ENUM,
+            values: Object.values(SaveFormat)
+        }
+    ]);
+
+    var callArgs = {
+        id: this._id,
+        savePath: args.path,
+        saveFormat: args.format
+    }
+
+    try {
+        callArgs.savePath = tizen.filesystem.toURI(args.path);
+    } catch (e) {
+        throw new WebAPIException(WebAPIException.InvalidValuesError, 'Invalid file path given');
+    }
+
+    if (!tizen.filesystem.pathExists(callArgs.savePath)) {
+        throw new WebAPIException(WebAPIException.NotFoundError, 'Path not found');
+    }
+
+    var result = native_.callSync('MLTrainerModelLoad', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidModelSaveExceptions,
+            AbortError
+        );
+    }
+};
+
 Model.prototype.addLayer = function() {
     var args = validator_.validateArgs(arguments, [
         {
index 9a4d447..6e256d7 100644 (file)
@@ -197,6 +197,7 @@ MlInstance::MlInstance()
   REGISTER_METHOD(MLTrainerModelSummarize);
   REGISTER_METHOD(MLTrainerModelCheckMetrics);
   REGISTER_METHOD(MLTrainerModelSave);
+  REGISTER_METHOD(MLTrainerModelLoad);
   REGISTER_METHOD(MLTrainerModelSetDataset);
   REGISTER_METHOD(MLTrainerModelSetOptimizer);
   REGISTER_METHOD(MLTrainerModelDispose);
@@ -2026,6 +2027,32 @@ void MlInstance::MLTrainerModelSave(const picojson::value& args,
   ReportSuccess(out);
 }
 
+void MlInstance::MLTrainerModelLoad(const picojson::value& args,
+                                    picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kSavePath, std::string, out);
+  CHECK_ARGS(args, kSaveFormat, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto path = args.get(kSavePath).get<std::string>();
+
+  ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN;
+  PlatformResult result = types::ModelSaveFormatEnum.getValue(
+      args.get(kSaveFormat).get<std::string>(), &model_format);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.ModelLoad(id, path, model_format);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
+}
+
 void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
   ScopeLogger("args: %s", args.serialize().c_str());
   CHECK_ARGS(args, kId, double, out);
index fbf5078..5fed8fa 100644 (file)
@@ -171,6 +171,7 @@ class MlInstance : public common::ParsedInstance {
   void MLTrainerModelCheckMetrics(const picojson::value& args,
                                   picojson::object& out);
   void MLTrainerModelSave(const picojson::value& args, picojson::object& out);
+  void MLTrainerModelLoad(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out);
   void MLTrainerModelDispose(const picojson::value& args,
index 86d852b..5afe64d 100644 (file)
@@ -468,7 +468,46 @@ PlatformResult TrainerManager::ModelSave(int id, const std::string& path,
   model->instanceLock.unlock();
 
   if (ret_val != ML_ERROR_NONE) {
-    LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
+    LoggerE("Could not save model to file: %d (%s)", ret_val,
+            ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+  }
+
+  return PlatformResult();
+}
+
+PlatformResult TrainerManager::ModelLoad(int id, const std::string& path,
+                                         ml_train_model_format_e format) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  auto& model = models_[id];
+  bool available = model->instanceLock.try_lock();
+  if (!available) {
+    LoggerE("Model locked - probaly training in progress");
+    return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR,
+                          "Model training in progress - cannot save");
+  }
+
+  auto tmpString = path;
+  if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
+    // remove 'file://' prefix from path before passing to native api
+    tmpString.erase(0, FILE_PATH_PREFIX.length());
+  }
+
+  LoggerI("Loading model from file: %s", tmpString.c_str());
+  int ret_val =
+      ml_train_model_load(model->getNative(), tmpString.c_str(), format);
+
+  model->instanceLock.unlock();
+
+  if (ret_val != ML_ERROR_NONE) {
+    LoggerE("Could not load model from file: %d (%s)", ret_val,
+            ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
index 009715b..4e2b100 100644 (file)
@@ -49,8 +49,9 @@ class TrainerManager {
                                 std::string& summary);
   PlatformResult CheckMetrics(int id, double train_loss, double valid_loss,
                               double valid_accuracy, bool* result);
-  PlatformResult ModelSave(int id,
-                           const std::string& path,
+  PlatformResult ModelSave(int id, const std::string& path,
+                           ml_train_model_format_e format);
+  PlatformResult ModelLoad(int id, const std::string& path,
                            ml_train_model_format_e format);
   PlatformResult ModelDispose(int id);