[ML][Training] Model implementation 71/268071/15
authorMarcin Kaminski <marcin.ka@partner.samsung.com>
Wed, 15 Dec 2021 18:34:35 +0000 (19:34 +0100)
committerMarcin Kaminski <marcin.ka@partner.samsung.com>
Fri, 4 Feb 2022 16:24:03 +0000 (16:24 +0000)
Implementation of Model methods:
- addLayer()
- setOptimizer()
- setDataset()
- summarize()
- compile()
- run() - uhnandled exception to be checked

Minor fix in Dataset creation and property setting.

Change-Id: I2f45bb449a34d0d959411a5120aee8c1e6a39da5

src/ml/js/ml_trainer.js
src/ml/ml_instance.cc
src/ml/ml_trainer_manager.cc
src/ml/ml_trainer_manager.h
src/ml/ml_utils.cc
src/ml/ml_utils.h

index 6339efa..00a470f 100755 (executable)
@@ -14,7 +14,7 @@
  *    limitations under the License.
  */
 
-var MachineLearningTrainer = function() {};
+var MachineLearningTrainer = function () { };
 
 var OptimizerType = {
     OPTIMIZER_ADAM: 'OPTIMIZER_ADAM',
@@ -216,22 +216,28 @@ var Model = function(id) {
 function ValidateCompileOptions(options) {
     var args = {};
     if (options.hasOwnProperty('loss_val')) {
-        args.loss_val = options.loss_val;
+        args.loss = options.loss_val;
     }
     if (options.hasOwnProperty('loss')) {
-        args.loss_val = options.loss;
+        args.loss = options.loss;
     }
     if (options.hasOwnProperty('batch_size')) {
-        args.loss_val = options.batch_size;
+        args.batch_size = options.batch_size;
     }
     return args;
 }
 
+var ValidModelCompileExceptions = [
+    'InvalidValuesError',
+    'TypeMismatchError',
+    'AbortError'
+];
+
 Model.prototype.compile = function() {
-    var args = validator.validateArgs(arguments, [
+    var args = validator_.validateArgs(arguments, [
         {
             name: 'options',
-            type: validator.Types.DICTIONARY,
+            type: validator_.Types.DICTIONARY,
             optional: true,
             nullable: true
         }
@@ -251,43 +257,83 @@ Model.prototype.compile = function() {
     if (native_.isFailure(result)) {
         throw native_.getErrorObjectAndValidate(
             result,
-            ValidSetPropertyExceptions,
+            ValidModelCompileExceptions,
             AbortError
         );
     }
-    // TODO:
 };
 
 function ValidateRunOptions(options) {
     var args = {};
     if (options.hasOwnProperty('batch_size')) {
-        args.loss_val = options.batch_size;
+        args.batch_size = options.batch_size;
     }
     if (options.hasOwnProperty('epochs')) {
-        args.loss_val = options.epochs;
+        args.epochs = options.epochs;
     }
     if (options.hasOwnProperty('save_path')) {
-        args.loss_val = options.save_path;
-    }
-    if (options.hasOwnProperty('continue_train')) {
-        args.loss_val = options.continue_train;
+        args.save_path = options.save_path;
     }
     return args;
 }
 
+var ValidModelRunExceptions = [
+    'InvalidValuesError',
+    'TypeMismatchError'
+];
+
 Model.prototype.run = function() {
-    var args = validator.validateArgs(arguments, [
+    var args = validator_.validateArgs(arguments, [
         {
             name: 'options',
-            type: validator.Types.DICTIONARY,
+            type: validator_.Types.DICTIONARY,
+            optional: true,
+            nullable: true
+        },
+        {
+            name: 'successCallback',
+            type: types_.FUNCTION
+        },
+        {
+            name: 'errorCallback',
+            type: types_.FUNCTION,
             optional: true,
             nullable: true
         }
     ]);
+    var runOptions = {};
     if (args.has.options) {
-        ValidateRunOptions(args.options);
+        runOptions = ValidateRunOptions(args.options);
+    }
+
+    var callArgs = {
+        id: this._id,
+        options: runOptions
+    };
+
+    var callback = function (result) {
+        if (native_.isFailure(result)) {
+            native_.callIfPossible(
+                args.errorCallback,
+                native_.getErrorObjectAndValidate(
+                    result,
+                    ValidModelRunExceptions,
+                    AbortError
+                )
+            );
+        } else {
+            args.successCallback();
+        }
+    };
+
+    var result = native_.call('MLTrainerModelRun', callArgs, callback);
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidModelRunExceptions,
+            AbortError
+        );
     }
-    // TODO
 };
 
 Model.prototype.summarize = function() {
@@ -296,12 +342,34 @@ Model.prototype.summarize = function() {
             name: 'level',
             type: types_.ENUM,
             values: Object.values(VerbosityLevel),
-            optional: false
+            optional: true,
+            nullable: true
         }
     ]);
-    // TODO
+
+    var callArgs = {
+        id: this._id,
+        level: args.level ? args.level : "SUMMARY_MODEL"
+    }
+
+    var result = native_.callSync('MLTrainerModelSummarize', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidBasicExceptions,
+            AbortError
+        );
+    }
+
+    return result.summary
 };
 
+var ValidBasicExceptions = [
+    'TypeMismatchError',
+    'AbortError'
+];
+
 Model.prototype.addLayer = function() {
     var args = validator_.validateArgs(arguments, [
         {
@@ -310,7 +378,27 @@ Model.prototype.addLayer = function() {
             values: Layer
         }
     ]);
-    // TODO
+
+    if (!args.has.layer) {
+        throw new WebAPIException(
+            WebAPIException.TYPE_MISMATCH_ERR, 'Invalid parameter: layer is undefined'
+        );
+    }
+
+    var callArgs = {
+        id: this._id,
+        layerId: args.layer._id
+    };
+
+    var result = native_.callSync('MLTrainerModelAddLayer', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidBasicExceptions,
+            AbortError
+        );
+    }
 };
 
 Model.prototype.setDataset = function() {
@@ -321,7 +409,27 @@ Model.prototype.setDataset = function() {
             values: Dataset
         }
     ]);
-    // TODO
+
+    if (!args.has.dataset) {
+        throw new WebAPIException(
+            WebAPIException.TYPE_MISMATCH_ERR, 'Invalid parameter: dataset is undefined'
+        );
+    }
+
+    var callArgs = {
+        id: this._id,
+        datasetId: args.dataset._id
+    };
+
+    var result = native_.callSync('MLTrainerModelSetDataset', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidBasicExceptions,
+            AbortError
+        );
+    }
 };
 
 Model.prototype.setOptimizer = function() {
@@ -332,7 +440,27 @@ Model.prototype.setOptimizer = function() {
             values: Optimizer
         }
     ]);
-    // TODO
+
+    if (!args.has.optimizer) {
+        throw new WebAPIException(
+            WebAPIException.TYPE_MISMATCH_ERR, 'Invalid parameter: optimizer is undefined'
+        );
+    }
+
+    var callArgs = {
+        id: this._id,
+        optimizerId: args.optimizer._id
+    };
+
+    var result = native_.callSync('MLTrainerModelSetOptimizer', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidBasicExceptions,
+            AbortError
+        );
+    }
 };
 
 var ValidCreateLayerExceptions = ['NotSupportedError', 'TypeMismatchError', 'AbortError'];
@@ -361,7 +489,7 @@ MachineLearningTrainer.prototype.createLayer = function() {
         );
     }
 
-    return new Layer(result.id);
+    return new Layer(result.id, args.type);
 };
 
 function ValidateAndReturnDatasetPaths(train, valid, test) {
index 6d0f74c..db9cc67 100644 (file)
@@ -78,6 +78,11 @@ const std::string kTrain = "train";
 const std::string kValid = "valid";
 const std::string kTest = "test";
 const std::string kOptions = "options";
+const std::string kLayerId = "layerId";
+const std::string kDatasetId = "datasetId";
+const std::string kOptimizerId = "optimizerId";
+const std::string kLevel = "level";
+const std::string kSummary = "summary";
 }  //  namespace
 
 using namespace common;
@@ -1811,10 +1816,12 @@ void MlInstance::MLTrainerModelCreate(const picojson::value& args, picojson::obj
 void MlInstance::MLTrainerModelCompile(const picojson::value& args, picojson::object& out) {
   ScopeLogger("args: %s", args.serialize().c_str());
   CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kOptions, picojson::object, out);
 
   auto id = static_cast<int>(args.get(kId).get<double>());
+  auto options = args.get(kOptions).get<picojson::object>();
 
-  PlatformResult result = trainer_manager_.ModelCompile(id);
+  PlatformResult result = trainer_manager_.ModelCompile(id, options);
 
   if (!result) {
     ReportError(result, &out);
@@ -1824,28 +1831,126 @@ void MlInstance::MLTrainerModelCompile(const picojson::value& args, picojson::ob
 }
 
 void MlInstance::MLTrainerModelAddLayer(const picojson::value& args, picojson::object& out) {
-  ScopeLogger();
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kLayerId, double, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto layerId = static_cast<int>(args.get(kLayerId).get<double>());
+
+  PlatformResult result = trainer_manager_.ModelAddLayer(id, layerId);
+
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
 
 void MlInstance::MLTrainerModelRun(const picojson::value& args, picojson::object& out) {
-  ScopeLogger();
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kOptions, picojson::object, out);
+  CHECK_ARGS(args, kCallbackId, double, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto options = args.get(kOptions).get<picojson::object>();
+  auto cb_id = args.get(kCallbackId).get<double>();
+
+  auto async_logic = [this, id, options](decltype(out) out) {
+    PlatformResult result;
+
+    try {
+      result = trainer_manager_.ModelRun(id, options);
+    } catch (...) {  // MK-TODO verify why this exception occurs
+      LoggerE("Unhandled and unexpected exception!!");
+      ReportError(result, &out);
+    }
+
+    if (!result) {
+      ReportError(result, &out);
+      return;
+    }
+
+    ReportSuccess(out);
+  };
+
+  this->worker_.add_job([this, cb_id, async_logic] {
+    picojson::value response = picojson::value(picojson::object());
+    picojson::object& async_out = response.get<picojson::object>();
+    async_out[kCallbackId] = picojson::value(cb_id);
+    async_logic(async_out);
+    this->PostMessage(response.serialize().c_str());
+  });
+
+  ReportSuccess(out);
 }
 
 void MlInstance::MLTrainerModelSummarize(const picojson::value& args, picojson::object& out) {
-  ScopeLogger();
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kLevel, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+
+  ml_train_summary_type_e summaryType = ML_TRAIN_SUMMARY_MODEL;
+  PlatformResult result = types::SummaryTypeEnum.getValue(
+      args.get(kLevel).get<std::string>(), &summaryType);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  std::string summary;
+
+  result = trainer_manager_.ModelSummarize(id, summaryType, summary);
+
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+
+  out[kSummary] = picojson::value(summary);
+  ReportSuccess(out);
 }
 
 void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
-  ScopeLogger();
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kDatasetId, double, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto datasetId = static_cast<int>(args.get(kDatasetId).get<double>());
+
+  PlatformResult result = trainer_manager_.ModelSetDataset(id, datasetId);
+
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
 
 void MlInstance::MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out) {
-  ScopeLogger();
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kOptimizerId, double, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto optimizerId = static_cast<int>(args.get(kOptimizerId).get<double>());
+
+  PlatformResult result = trainer_manager_.ModelSetOptimizer(id, optimizerId);
+
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
 
 void MlInstance::MLTrainerDatasetCreateGenerator(const picojson::value& args,
                                                  picojson::object& out) {
-  ScopeLogger();
+  ScopeLogger("args: %s", args.serialize().c_str());
 }
 
 void MlInstance::MLTrainerDatasetCreateFromFile(const picojson::value& args,
index 91d6ec6..dbef05f 100644 (file)
@@ -24,6 +24,9 @@ using common::PlatformResult;
 namespace extension {
 namespace ml {
 
+const std::string OPTION_SEPARATOR = " | ";
+const std::string FILE_PATH_PREFIX = "file://";
+
 TrainerManager::TrainerManager() {
   ScopeLogger();
 }
@@ -39,7 +42,7 @@ PlatformResult TrainerManager::CreateModel(int& id) {
 
   int ret_val = ml_train_model_construct(&n_model);
   if (ret_val != 0) {
-    LoggerE("Could not create model: %s", ml_strerror(ret_val));
+    LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
@@ -56,7 +59,7 @@ PlatformResult TrainerManager::CreateModel(int& id, const std::string config) {
 
   int ret_val = ml_train_model_construct_with_conf(config.c_str(), &n_model);
   if (ret_val != 0) {
-    LoggerE("Could not create model: %s", ml_strerror(ret_val));
+    LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
@@ -66,7 +69,8 @@ PlatformResult TrainerManager::CreateModel(int& id, const std::string config) {
   return PlatformResult();
 }
 
-PlatformResult TrainerManager::ModelCompile(int id) {
+PlatformResult TrainerManager::ModelCompile(int id,
+                                            const picojson::object& options) {
   ScopeLogger();
 
   if (models_.find(id) == models_.end()) {
@@ -76,16 +80,46 @@ PlatformResult TrainerManager::ModelCompile(int id) {
 
   auto& model = models_[id];
 
-  int ret_val = ml_train_model_compile(model, NULL);
+  std::stringstream ss;
+  for (const auto& opt : options) {
+    const auto& key = opt.first;
+    if (opt.second.is<std::string>()) {
+      const auto& value = opt.second.get<std::string>();
+      ss << key << "=" << value << OPTION_SEPARATOR;
+    } else if (opt.second.is<double>()) {
+      const auto& value = opt.second.get<double>();
+      ss << key << "=" << value << OPTION_SEPARATOR;
+    } else {
+      LoggerE("Unexpected param type for: %s", key.c_str());
+      return PlatformResult(ErrorCode::ABORT_ERR,
+                            "Unexpected param type for:" + key);
+    }
+  }
+
+  int ret_val = 0;
+  auto compileOpts = ss.str();
+  if (compileOpts.length() < OPTION_SEPARATOR.length()) {
+    ret_val = ml_train_model_compile(model, NULL);
+  } else {
+    // remove trailing ' | ' from options string
+    compileOpts =
+        compileOpts.substr(0, compileOpts.length() - OPTION_SEPARATOR.length());
+    LoggerI("Compiling model with options: %s", compileOpts.c_str());
+    ret_val = ml_train_model_compile(model, compileOpts.c_str(), NULL);
+  }
+
+  ss.clear();
+
   if (ret_val != 0) {
-    LoggerE("Could not compile model: %s", ml_strerror(ret_val));
+    LoggerE("Could not compile model: %d (%s)", ret_val, ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
   return PlatformResult();
 }
 
-PlatformResult TrainerManager::ModelRun(int id) {
+PlatformResult TrainerManager::ModelRun(int id,
+                                        const picojson::object& options) {
   ScopeLogger();
 
   if (models_.find(id) == models_.end()) {
@@ -95,15 +129,148 @@ PlatformResult TrainerManager::ModelRun(int id) {
 
   auto& model = models_[id];
 
-  int ret_val = ml_train_model_run(model, NULL);
+  std::stringstream ss;
+  for (const auto& opt : options) {
+    const auto& key = opt.first;
+    if (opt.second.is<std::string>()) {
+      const auto& value = opt.second.get<std::string>();
+      ss << key << "=" << value << OPTION_SEPARATOR;
+    } else if (opt.second.is<double>()) {
+      const auto& value = opt.second.get<double>();
+      ss << key << "=" << value << OPTION_SEPARATOR;
+    } else {
+      LoggerE("Unexpected param type for: %s", key.c_str());
+      return PlatformResult(ErrorCode::ABORT_ERR,
+                            "Unexpected param type for:" + key);
+    }
+  }
+
+  int ret_val = 0;
+  auto runOpts = ss.str();
+
+  if (runOpts.length() < OPTION_SEPARATOR.length()) {
+    ret_val = ml_train_model_run(model, NULL);
+  } else {
+    // remove trailing ' | ' from options string
+    runOpts = runOpts.substr(0, runOpts.length() - OPTION_SEPARATOR.length());
+    LoggerI("Running model with options: %s", runOpts.c_str());
+    ret_val = ml_train_model_run(model, runOpts.c_str(), NULL);
+  }
+
   if (ret_val != 0) {
-    LoggerE("Could not run model: %s", ml_strerror(ret_val));
+    LoggerE("Could not run (train) model: %d (%s)", ret_val,
+            ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::UNKNOWN_ERR, ml_strerror(ret_val));
+  }
+
+  return PlatformResult();
+}
+
+PlatformResult TrainerManager::ModelAddLayer(int id, int layerId) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  if (layers_.find(layerId) == layers_.end()) {
+    LoggerE("Could not find layer with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
+  }
+
+  auto& model = models_[id];
+  auto& layer = layers_[layerId];
+
+  int ret_val = ml_train_model_add_layer(model, layer);
+  if (ret_val != 0) {
+    LoggerE("Could not add layer to model: %d (%s)", ret_val,
+            ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
   return PlatformResult();
 }
 
+PlatformResult TrainerManager::ModelSetOptimizer(int id, int optimizerId) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  if (optimizers_.find(optimizerId) == optimizers_.end()) {
+    LoggerE("Could not find optimizer with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
+  }
+
+  auto& model = models_[id];
+  auto& optimizer = optimizers_[optimizerId];
+
+  int ret_val = ml_train_model_set_optimizer(model, optimizer);
+  if (ret_val != 0) {
+    LoggerE("Could not set optimizer for model: %d (%s)", ret_val,
+            ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+  }
+
+  return PlatformResult();
+}
+
+PlatformResult TrainerManager::ModelSetDataset(int id, int datasetId) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  if (datasets_.find(datasetId) == datasets_.end()) {
+    LoggerE("Could not find dataset with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
+  }
+
+  auto& model = models_[id];
+  auto& dataset = datasets_[datasetId];
+
+  int ret_val = ml_train_model_set_dataset(model, dataset);
+  if (ret_val != 0) {
+    LoggerE("Could not set dataset for model: %d (%s)", ret_val,
+            ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+  }
+
+  return PlatformResult();
+}
+
+PlatformResult TrainerManager::ModelSummarize(int id,
+                                              ml_train_summary_type_e level,
+                                              std::string& summary) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  auto& model = models_[id];
+  char* tmpSummary = NULL;
+
+  int ret_val = ml_train_model_get_summary(model, level, &tmpSummary);
+
+  if (ret_val != 0) {
+    LoggerE("Could not get summary for model: %d (%s)", ret_val,
+            ml_strerror(ret_val));
+    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+  }
+
+  summary = tmpSummary;
+  free(tmpSummary);
+
+  return PlatformResult();
+}
+
 PlatformResult TrainerManager::CreateLayer(int& id,
                                            ml_train_layer_type_e type) {
   ScopeLogger();
@@ -121,7 +288,7 @@ PlatformResult TrainerManager::CreateLayer(int& id,
   return PlatformResult();
 }
 
-PlatformResult TrainerManager::LayerSetProperty(int& id, const std::string& name,
+PlatformResult TrainerManager::LayerSetProperty(int id, const std::string& name,
                                                 const std::string& value) {
   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
 
@@ -135,7 +302,8 @@ PlatformResult TrainerManager::LayerSetProperty(int& id, const std::string& name
 
   int ret_val = ml_train_layer_set_property(layer, opt.c_str(), NULL);
   if (ret_val != 0) {
-    LoggerE("Could not set layer property: %s", ml_strerror(ret_val));
+    LoggerE("Could not set layer property: %d (%s)", ret_val,
+            ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
   return PlatformResult();
@@ -149,7 +317,8 @@ PlatformResult TrainerManager::CreateOptimizer(int& id,
 
   int ret_val = ml_train_optimizer_create(&n_optimizer, type);
   if (ret_val != 0) {
-    LoggerE("Could not create optimizer: %s", ml_strerror(ret_val));
+    LoggerE("Could not create optimizer: %d (%s)", ret_val,
+            ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
@@ -158,7 +327,8 @@ PlatformResult TrainerManager::CreateOptimizer(int& id,
   return PlatformResult();
 }
 
-PlatformResult TrainerManager::OptimizerSetProperty(int& id, const std::string& name,
+PlatformResult TrainerManager::OptimizerSetProperty(int id,
+                                                    const std::string& name,
                                                     const std::string& value) {
   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
 
@@ -171,7 +341,8 @@ PlatformResult TrainerManager::OptimizerSetProperty(int& id, const std::string&
   std::string opt = name + "=" + value;
   int ret_val = ml_train_optimizer_set_property(optimizer, opt.c_str(), NULL);
   if (ret_val != 0) {
-    LoggerE("Could not set optimizer property: %s", ml_strerror(ret_val));
+    LoggerE("Could not set optimizer property: %d (%s)", ret_val,
+            ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
   return PlatformResult();
@@ -191,10 +362,16 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
   }
 
   if (!train_file.empty()) {
+    auto tmpString = train_file;
+    if (tmpString.substr(0, 7) == "file://") {
+      // remove 'file://' prefix from path before passing to native api
+      tmpString.erase(0, 7);
+    }
+
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TRAIN,
-                                        train_file.c_str());
+                                        tmpString.c_str());
     if (ret_val != 0) {
-      LoggerE("Could not add train file %s to dataset: %s", train_file.c_str(),
+      LoggerE("Could not add train file %s to dataset: %s", tmpString.c_str(),
               ml_strerror(ret_val));
       ml_train_dataset_destroy(n_dataset);
       return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
@@ -202,21 +379,31 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
   }
 
   if (!valid_file.empty()) {
+    auto tmpString = valid_file;
+    if (tmpString.substr(0, 7) == "file://") {
+      // remove 'file://' prefix from path before passing to native api
+      tmpString.erase(0, 7);
+    }
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_VALID,
-                                        valid_file.c_str());
+                                        tmpString.c_str());
     if (ret_val != 0) {
       LoggerE("Could not add validation file %s to dataset: %s",
-              valid_file.c_str(), ml_strerror(ret_val));
+              tmpString.c_str(), ml_strerror(ret_val));
       ml_train_dataset_destroy(n_dataset);
       return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
     }
   }
 
   if (!test_file.empty()) {
+    auto tmpString = test_file;
+    if (tmpString.substr(0, 7) == "file://") {
+      // remove 'file://' prefix from path before passing to native api
+      tmpString.erase(0, 7);
+    }
     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TEST,
-                                        test_file.c_str());
+                                        tmpString.c_str());
     if (ret_val != 0) {
-      LoggerE("Could not add test file %s to dataset: %s", test_file.c_str(),
+      LoggerE("Could not add test file %s to dataset: %s", tmpString.c_str(),
               ml_strerror(ret_val));
       ml_train_dataset_destroy(n_dataset);
       return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
@@ -230,7 +417,8 @@ PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string trai
 
 // MK-TODO Add creating Dataset with generator
 
-PlatformResult TrainerManager::DatasetSetProperty(int& id, const std::string& name,
+PlatformResult TrainerManager::DatasetSetProperty(int id,
+                                                  const std::string& name,
                                                   const std::string& value) {
   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
 
@@ -243,11 +431,11 @@ PlatformResult TrainerManager::DatasetSetProperty(int& id, const std::string& na
   std::string opt = name + "=" + value;
 
   // ml_train_dataset_set_property() is marked as deprecated
-  // temporary set same property for all modes (all data files)
+  // temporary set same property for all modes (all data files) if possible
   int ret_val = ml_train_dataset_set_property_for_mode(
       dataset, ML_TRAIN_DATASET_MODE_TRAIN, opt.c_str(), NULL);
   if (ret_val != 0) {
-    LoggerE("Could not set dataset property for train mode: %s",
+    LoggerE("Could not set dataset property for train mode: %d (%s)", ret_val,
             ml_strerror(ret_val));
     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
@@ -255,17 +443,19 @@ PlatformResult TrainerManager::DatasetSetProperty(int& id, const std::string& na
   ret_val = ml_train_dataset_set_property_for_mode(
       dataset, ML_TRAIN_DATASET_MODE_VALID, opt.c_str(), NULL);
   if (ret_val != 0) {
-    LoggerE("Could not set dataset property for validation mode: %s",
-            ml_strerror(ret_val));
-    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+    LoggerE("Could not set dataset property for validation mode: %d (%s)",
+            ret_val, ml_strerror(ret_val));
+    // MK-TODO report error for each file when extracted to separate functions
+    // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
   ret_val = ml_train_dataset_set_property_for_mode(
       dataset, ML_TRAIN_DATASET_MODE_TEST, opt.c_str(), NULL);
   if (ret_val != 0) {
-    LoggerE("Could not set dataset property for test mode: %s",
+    LoggerE("Could not set dataset property for test mode: %d (%s)", ret_val,
             ml_strerror(ret_val));
-    return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+    // MK-TODO report error for each file when extracted to separate functions
+    // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
   }
 
   return PlatformResult();
index ac2cfde..3b154dd 100644 (file)
@@ -36,19 +36,26 @@ class TrainerManager {
 
   PlatformResult CreateModel(int& id);
   PlatformResult CreateModel(int& id, const std::string config);
-  PlatformResult ModelCompile(int id);
-  PlatformResult ModelRun(int id);
+  PlatformResult ModelCompile(int id, const picojson::object& options);
+  PlatformResult ModelRun(int id, const picojson::object& options);
+  PlatformResult ModelAddLayer(int id, int layerId);
+  PlatformResult ModelSetOptimizer(int id, int optimizerId);
+  PlatformResult ModelSetDataset(int id, int datasetId);
+  PlatformResult ModelSummarize(int id, ml_train_summary_type_e level,
+                                std::string& summary);
 
   PlatformResult CreateLayer(int& id, ml_train_layer_type_e type);
-  PlatformResult LayerSetProperty(int& id, const std::string& name,
+  PlatformResult LayerSetProperty(int id, const std::string& name,
                                   const std::string& value);
 
   PlatformResult CreateOptimizer(int& id, ml_train_optimizer_type_e type);
-  PlatformResult OptimizerSetProperty(int& id, const std::string& name, const std::string& value);
+  PlatformResult OptimizerSetProperty(int id, const std::string& name,
+                                      const std::string& value);
 
   PlatformResult CreateFileDataset(int& id, const std::string train_file,
                                    const std::string valid_file, const std::string test_file);
-  PlatformResult DatasetSetProperty(int& id, const std::string& name, const std::string& value);
+  PlatformResult DatasetSetProperty(int id, const std::string& name,
+                                    const std::string& value);
 
  private:
   int next_model_id_ = 0;
index 5855369..daa361c 100644 (file)
@@ -64,10 +64,6 @@ const PlatformEnum<ml_tensor_type_e> TensorTypeEnum{
     {"INT64", ML_TENSOR_TYPE_INT64},     {"UINT64", ML_TENSOR_TYPE_UINT64},
     {"UNKNOWN", ML_TENSOR_TYPE_UNKNOWN}};
 
-// const PlatformEnum<TODO> DatasetTypeEnum{{"DATASET_GENERATOR", TODO},
-//     {"DATASET_FILE", TODO},
-//     {"DATASET_UNKNOWN",TODO}};
-
 const PlatformEnum<ml_train_optimizer_type_e> OptimizerTypeEnum{
     {"OPTIMIZER_ADAM", ML_TRAIN_OPTIMIZER_TYPE_ADAM},
     {"OPTIMIZER_SGD", ML_TRAIN_OPTIMIZER_TYPE_SGD},
@@ -94,6 +90,11 @@ const PlatformEnum<ml_train_layer_type_e> LayerTypeEnum{
     {"LAYER_BACKBONE_NNSTREAMER", ML_TRAIN_LAYER_TYPE_BACKBONE_NNSTREAMER},
     {"LAYER_UNKNOWN", ML_TRAIN_LAYER_TYPE_UNKNOWN}};
 
+const PlatformEnum<ml_train_summary_type_e> SummaryTypeEnum{
+    {"SUMMARY_MODEL", ML_TRAIN_SUMMARY_MODEL},
+    {"SUMMARY_LAYER", ML_TRAIN_SUMMARY_LAYER},
+    {"SUMMARY_TENSOR", ML_TRAIN_SUMMARY_TENSOR}};
+
 }  // namespace types
 
 namespace util {
index 5deae9c..ccfeb16 100644 (file)
@@ -45,10 +45,9 @@ extern const PlatformEnum<ml_nnfw_hw_e> HWTypeEnum;
 extern const PlatformEnum<ml_nnfw_type_e> NNFWTypeEnum;
 extern const PlatformEnum<ml_tensor_type_e> TensorTypeEnum;
 
-// MK-TODO implement internal enum or remove from API design if not needed
-// extern const PlatformEnum<TODO> DatasetTypeEnum;
 extern const PlatformEnum<ml_train_optimizer_type_e> OptimizerTypeEnum;
 extern const PlatformEnum<ml_train_layer_type_e> LayerTypeEnum;
+extern const PlatformEnum<ml_train_summary_type_e> SummaryTypeEnum;
 
 }  // namespace types