[ML][Training] Dataset.setProperty() refactoring 48/269948/6
authorMarcin Kaminski <marcin.ka@partner.samsung.com>
Fri, 21 Jan 2022 18:44:41 +0000 (19:44 +0100)
committerMarcin Kaminski <marcin.ka@partner.samsung.com>
Wed, 9 Feb 2022 18:33:51 +0000 (19:33 +0100)
Changes:
- enum DatasetMode added
- Dataset.setProperty() expanded with 'mode' param to match C API design

Change-Id: Ia619a69d6367c2624fff9d402e6a3505ce33f14b

src/ml/js/ml_trainer.js
src/ml/ml_instance.cc
src/ml/ml_trainer_manager.cc
src/ml/ml_trainer_manager.h
src/ml/ml_utils.cc
src/ml/ml_utils.h

index 643a5f6..4c3a841 100755 (executable)
@@ -58,7 +58,13 @@ var SaveFormat = {
     FORMAT_BIN: 'FORMAT_BIN',
     FORMAT_INI: 'FORMAT_INI',
     FORMAT_INI_WITH_BIN: 'FORMAT_INI_WITH_BIN'
-}
+};
+
+var DatasetMode = {
+    MODE_TRAIN: 'MODE_TRAIN',
+    MODE_VALID: 'MODE_VALID',
+    MODE_TEST: 'MODE_TEST'
+};
 
 var Layer = function(id, type) {
     Object.defineProperties(this, {
@@ -215,20 +221,26 @@ Dataset.prototype.setProperty = function() {
         {
             name: 'value',
             type: types_.STRING
+        },
+        {
+            name: 'mode',
+            type: types_.ENUM,
+            values: Object.values(DatasetMode)
         }
     ]);
 
-    if (!args.has.name || !args.has.value) {
+    if (!args.has.name || !args.has.value || !args.has.mode) {
         throw new WebAPIException(
             WebAPIException.TYPE_MISMATCH_ERR,
-            'Invalid parameter: ' + (args.has.name ? 'value' : 'name') + ' is undefined'
+            'Invalid parameter: name, value and mode have to be defined'
         );
     }
 
     var callArgs = {
         id: this._id,
         name: args.name,
-        value: args.value
+        value: args.value,
+        mode: args.mode
     };
 
     var result = native_.callSync('MLTrainerDatasetSetProperty', callArgs);
index d7a1ba3..250eced 100644 (file)
@@ -85,6 +85,7 @@ const std::string kLevel = "level";
 const std::string kSummary = "summary";
 const std::string kSavePath = "savePath";
 const std::string kSaveFormat = "saveFormat";
+const std::string kMode = "mode";
 }  //  namespace
 
 using namespace common;
@@ -2058,12 +2059,21 @@ void MlInstance::MLTrainerDatasetSetProperty(const picojson::value& args, picojs
   CHECK_ARGS(args, kId, double, out);
   CHECK_ARGS(args, kName, std::string, out);
   CHECK_ARGS(args, kValue, std::string, out);
+  CHECK_ARGS(args, kMode, std::string, out);
 
   auto id = static_cast<int>(args.get(kId).get<double>());
   auto name = args.get(kName).get<std::string>();
   auto value = args.get(kValue).get<std::string>();
 
-  PlatformResult result = trainer_manager_.DatasetSetProperty(id, name, value);
+  ml_train_dataset_mode_e datasetMode = ML_TRAIN_DATASET_MODE_TRAIN;
+  PlatformResult result = types::DatasetModeEnum.getValue(
+      args.get(kMode).get<std::string>(), &datasetMode);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.DatasetSetProperty(id, name, value, datasetMode);
   if (!result) {
     ReportError(result, &out);
     return;
index 01043f9..00795d5 100644 (file)
@@ -612,9 +612,11 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
 // MK-TODO Add creating Dataset with generator
 
-PlatformResult TrainerManager::DatasetSetProperty(int id,
-                                                  const std::string& name,
-                                                  const std::string& value) {
+PlatformResult TrainerManager::DatasetSetProperty(
+    int id,
+    const std::string& name,
+    const std::string& value,
+    ml_train_dataset_mode_e mode) {
   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
 
   if (datasets_.find(id) == datasets_.end()) {
@@ -625,32 +627,12 @@ PlatformResult TrainerManager::DatasetSetProperty(int id,
   auto dataset = datasets_[id];
   std::string opt = name + "=" + value;
 
-  // ml_train_dataset_set_property() is marked as deprecated
-  // temporary set same property for all modes (all data files) if possible
-  int ret_val = ml_train_dataset_set_property_for_mode(
-      dataset->getNative(), ML_TRAIN_DATASET_MODE_TRAIN, opt.c_str(), NULL);
+  int ret_val = ml_train_dataset_set_property_for_mode(dataset->getNative(),
+                                                       mode, opt.c_str(), NULL);
   if (ret_val != ML_ERROR_NONE) {
-    LoggerE("Could not set dataset property for train mode: %d (%s)", ret_val,
-            ml_strerror(ret_val));
-    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
-  }
-
-  ret_val = ml_train_dataset_set_property_for_mode(
-      dataset->getNative(), ML_TRAIN_DATASET_MODE_VALID, opt.c_str(), NULL);
-  if (ret_val != ML_ERROR_NONE) {
-    LoggerE("Could not set dataset property for validation mode: %d (%s)",
+    LoggerE("Could not set dataset property for mode %d: %d (%s)", mode,
             ret_val, ml_strerror(ret_val));
-    // MK-TODO report error for each file when extracted to separate functions
-    // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
-  }
-
-  ret_val = ml_train_dataset_set_property_for_mode(
-      dataset->getNative(), ML_TRAIN_DATASET_MODE_TEST, opt.c_str(), NULL);
-  if (ret_val != ML_ERROR_NONE) {
-    LoggerE("Could not set dataset property for test mode: %d (%s)", ret_val,
-            ml_strerror(ret_val));
-    // MK-TODO report error for each file when extracted to separate functions
-    // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
   return PlatformResult();
index 15de4f8..4e38c5a 100644 (file)
@@ -62,8 +62,10 @@ class TrainerManager {
 
   PlatformResult CreateFileDataset(int& id, const std::string train_file,
                                    const std::string valid_file, const std::string test_file);
-  PlatformResult DatasetSetProperty(int id, const std::string& name,
-                                    const std::string& value);
+  PlatformResult DatasetSetProperty(int id,
+                                    const std::string& name,
+                                    const std::string& value,
+                                    ml_train_dataset_mode_e mode);
   PlatformResult DatasetDispose(int id);
 
  private:
index 4b0734c..92fea3f 100644 (file)
@@ -100,6 +100,11 @@ const PlatformEnum<ml_train_model_format_e> ModelSaveFormatEnum{
     {"FORMAT_INI", ML_TRAIN_MODEL_FORMAT_INI},
     {"FORMAT_INI_WITH_BIN", ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN}};
 
+const PlatformEnum<ml_train_dataset_mode_e> DatasetModeEnum{
+    {"MODE_TRAIN", ML_TRAIN_DATASET_MODE_TRAIN},
+    {"MODE_VALID", ML_TRAIN_DATASET_MODE_VALID},
+    {"MODE_TEST", ML_TRAIN_DATASET_MODE_TEST}};
+
 }  // namespace types
 
 namespace util {
index 6569623..37cd89c 100644 (file)
@@ -49,6 +49,7 @@ extern const PlatformEnum<ml_train_optimizer_type_e> OptimizerTypeEnum;
 extern const PlatformEnum<ml_train_layer_type_e> LayerTypeEnum;
 extern const PlatformEnum<ml_train_summary_type_e> SummaryTypeEnum;
 extern const PlatformEnum<ml_train_model_format_e> ModelSaveFormatEnum;
+extern const PlatformEnum<ml_train_dataset_mode_e> DatasetModeEnum;
 
 }  // namespace types