[ML][Training] Model saving to file
[platform/core/api/webapi-plugins.git] / src / ml / ml_instance.cc
index db9cc67..5d1dfbd 100644 (file)
@@ -83,6 +83,8 @@ const std::string kDatasetId = "datasetId";
 const std::string kOptimizerId = "optimizerId";
 const std::string kLevel = "level";
 const std::string kSummary = "summary";
+const std::string kSavePath = "savePath";
+const std::string kSaveFormat = "saveFormat";
 }  //  namespace
 
 using namespace common;
@@ -188,6 +190,7 @@ MlInstance::MlInstance()
   REGISTER_METHOD(MLTrainerModelAddLayer);
   REGISTER_METHOD(MLTrainerModelRun);
   REGISTER_METHOD(MLTrainerModelSummarize);
+  REGISTER_METHOD(MLTrainerModelSave);
   REGISTER_METHOD(MLTrainerModelSetDataset);
   REGISTER_METHOD(MLTrainerModelSetOptimizer);
   REGISTER_METHOD(MLTrainerDatasetCreateGenerator);
@@ -1914,6 +1917,32 @@ void MlInstance::MLTrainerModelSummarize(const picojson::value& args, picojson::
   ReportSuccess(out);
 }
 
+void MlInstance::MLTrainerModelSave(const picojson::value& args,
+                                    picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kSavePath, std::string, out);
+  CHECK_ARGS(args, kSaveFormat, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto path = args.get(kSavePath).get<std::string>();
+
+  ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN;
+  PlatformResult result = types::ModelSaveFormatEnum.getValue(
+      args.get(kSaveFormat).get<std::string>(), &model_format);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.ModelSave(id, path, model_format);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
+}
+
 void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
   ScopeLogger("args: %s", args.serialize().c_str());
   CHECK_ARGS(args, kId, double, out);