[ML Trainier] Add create, setProperty and model's methods 70/266470/1
authorRafal Walczyna <r.walczyna@samsung.com>
Mon, 7 Jun 2021 13:26:42 +0000 (15:26 +0200)
committerPiotr Kosko/Tizen API (PLT) /SRPOL/Engineer/Samsung Electronics <p.kosko@samsung.com>
Fri, 12 Nov 2021 10:07:10 +0000 (11:07 +0100)
original change: https://review.tizen.org/gerrit/259417

Tested partially

Change-Id: Ifae3654265a391228a090200f8877aaba5be83a9
Signed-off-by: Rafal Walczyna <r.walczyna@samsung.com>
src/ml/js/ml_trainer.js
src/ml/ml_instance.cc
src/ml/ml_instance.h
src/ml/ml_trainer_manager.cc
src/ml/ml_trainer_manager.h
src/ml/ml_utils.cc
src/ml/ml_utils.h

index 685899c..b56b9c4 100755 (executable)
@@ -53,7 +53,7 @@ var VerbosityLevel = {
     SUMMARY_TENSOR: 'SUMMARY_TENSOR'
 };
 
-var Layer = function(id) {
+var Layer = function(id, type) {
     Object.defineProperties(this, {
         name: {
             enumerable: true,
@@ -63,9 +63,8 @@ var Layer = function(id) {
         },
         type: {
             enumerable: true,
-            get: function() {
-                // TODO
-            }
+            writable: false,
+            value: type
         },
         _id: { value: id, writable: false, enumerable: false }
     });
@@ -113,13 +112,12 @@ Layer.prototype.setProperty = function() {
     }
 };
 
-var Optimizer = function(id) {
+var Optimizer = function(id, type) {
     Object.defineProperties(this, {
         type: {
             enumerable: true,
-            get: function() {
-                // TODO
-            }
+            writable: false,
+            value: type
         },
         _id: { value: id, writable: false, enumerable: false }
     });
@@ -161,13 +159,12 @@ Optimizer.prototype.setProperty = function() {
     }
 };
 
-var Dataset = function(id) {
+var Dataset = function(id, type) {
     Object.defineProperties(this, {
         type: {
             enumerable: true,
-            get: function() {
-                // TODO
-            }
+            writable: false,
+            value: type
         },
         _id: { value: id, writable: false, enumerable: false }
     });
@@ -216,7 +213,17 @@ var Model = function(id) {
 };
 
 function ValidateCompileOptions(options) {
-    // TODO:
+    var args = {};
+    if (options.hasOwnProperty('loss_val')) {
+        args.loss_val = options.loss_val;
+    }
+    if (options.hasOwnProperty('loss')) {
+        args.loss_val = options.loss;
+    }
+    if (options.hasOwnProperty('batch_size')) {
+        args.loss_val = options.batch_size;
+    }
+    return args;
 }
 
 Model.prototype.compile = function() {
@@ -228,14 +235,43 @@ Model.prototype.compile = function() {
             nullable: true
         }
     ]);
+    var options = {};
     if (args.has.options) {
-        ValidateCompileOptions(args.options);
+        options = ValidateCompileOptions(args.options);
+    }
+
+    var callArgs = {
+        id: this._id,
+        options: options
+    };
+
+    var result = native_.callSync('MLTrainerModelCompile', callArgs);
+
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidSetPropertyExceptions,
+            AbortError
+        );
     }
     // TODO:
 };
 
 function ValidateRunOptions(options) {
-    // TODO:
+    var args = {};
+    if (options.hasOwnProperty('batch_size')) {
+        args.loss_val = options.batch_size;
+    }
+    if (options.hasOwnProperty('epochs')) {
+        args.loss_val = options.epochs;
+    }
+    if (options.hasOwnProperty('save_path')) {
+        args.loss_val = options.save_path;
+    }
+    if (options.hasOwnProperty('continue_train')) {
+        args.loss_val = options.continue_train;
+    }
+    return args;
 }
 
 Model.prototype.run = function() {
@@ -311,15 +347,36 @@ MachineLearningTrainer.prototype.createLayer = function() {
         }
     ]);
 
-    // TODO
-    return new Layer(NO_ID);
+    var nativeArgs = {
+        type: args.type
+    };
+
+    var result = native_.callSync('MLTrainerLayerCreate', nativeArgs);
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidCreateLayerExceptions,
+            AbortError
+        );
+    }
+
+    return new Layer(result.id);
 };
 
-function ValidateDatasetPaths(train, validate, test) {
-    // TODO
+function ValidateAndReturnDatasetPaths(train, validate, test) {
+    try {
+        var args = {
+            train: tizen.filesystem.toURI(train),
+            validate: validate ? tizen.filesystem.toURI(validate) : '',
+            test: test ? tizen.filesystem.toURI(test) : ''
+        };
+        return args;
+    } catch (e) {
+        throw new WebAPIException(WebAPIException.NOT_FOUND_ERR, 'Path is invalid');
+    }
 }
 
-MachineLearningTrainer.prototype.createGeneratorDataset = function() {
+MachineLearningTrainer.prototype.createFileDataset = function() {
     var args = validator_.validateArgs(arguments, [
         {
             name: 'train',
@@ -338,13 +395,27 @@ MachineLearningTrainer.prototype.createGeneratorDataset = function() {
             nullable: true
         }
     ]);
-    ValidateDatasetPaths(args.train, args.validate.args.test);
+    if (!args.has.train) {
+        throw new WebAPIException(
+            WebAPIException.TYPE_MISMATCH_ERR,
+            'Invalid parameter: training set path is undefined'
+        );
+    }
+    var nativeArgs = ValidateAndReturnDatasetPaths(args.train, args.validate, args.test);
 
-    // TODO
-    return new Dataset(NO_ID);
+    var result = native_.callSync('MLTrainerDatasetCreateFromFile', nativeArgs);
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidCreateLayerExceptions,
+            AbortError
+        );
+    }
+
+    return new Dataset(result.id, 'DATASET_FILE');
 };
 
-MachineLearningTrainer.prototype.createFileDataset = function() {
+MachineLearningTrainer.prototype.createGeneratorDataset = function() {
     var args = validator_.validateArgs(arguments, [
         {
             name: 'train',
@@ -366,24 +437,49 @@ MachineLearningTrainer.prototype.createFileDataset = function() {
     ValidateDatasetPaths(args.train, args.validate.args.test);
 
     // TODO
-    return new Dataset(NO_ID);
+    return new Dataset(result.id, 'DATASET_GENERATOR');
 };
 
+var ValidCreateOptimizerExceptions = [
+    'NotSupportedError',
+    'TypeMismatchError',
+    'AbortError'
+];
+
 MachineLearningTrainer.prototype.createOptimizer = function() {
     var args = validator_.validateArgs(arguments, [
         {
-            name: 'optimizer',
+            name: 'type',
             type: types_.ENUM,
             values: Object.values(OptimizerType),
             optional: false
         }
     ]);
 
-    // TODO
-    return new Optimizer(NO_ID);
+    var nativeArgs = {
+        type: args.type
+    };
+
+    var result = native_.callSync('MLTrainerOptimizerCreate', nativeArgs);
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidCreateOptimizerExceptions,
+            AbortError
+        );
+    }
+
+    return new Optimizer(result.id, args.type);
 };
 
-MachineLearningTrainer.prototype.constructModelWithConfiguration = function() {
+var ValidCreateModelWithConfigurationExceptions = [
+    'InvalidValuesError',
+    'NotFoundError',
+    'SecurityError',
+    'AbortError'
+];
+
+MachineLearningTrainer.prototype.createModelWithConfiguration = function() {
     var args = validator_.validateArgs(arguments, [
         {
             name: 'configPath',
@@ -398,12 +494,27 @@ MachineLearningTrainer.prototype.constructModelWithConfiguration = function() {
             throw new WebAPIException(WebAPIException.NOT_FOUND_ERR, 'Path is invalid');
         }
     }
+    var nativeArgs = {
+        configPath: args.configPath
+    };
 
-    // TODO
-    return new Model(NO_ID);
+    var result = native_.callSync('MLTrainerModelCreate', nativeArgs);
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidCreateModelWithConfigurationExceptions,
+            AbortError
+        );
+    }
+
+    return new Model(result.id);
 };
 
-MachineLearningTrainer.prototype.constructModel = function() {
-    // TODO
-    return new Model(NO_ID);
+MachineLearningTrainer.prototype.createModel = function() {
+    var result = native_.callSync('MLTrainerModelCreate', {});
+    if (native_.isFailure(result)) {
+        throw new WebAPIException(WebAPIException.AbortError, 'Could not create model');
+    }
+
+    return new Model(result.id);
 };
index bf728cf..e52990b 100644 (file)
@@ -72,6 +72,12 @@ const std::string kTensorsInfoId = "tensorsInfoId";
 const std::string kTimeout = "timeout";
 const std::string kType = "type";
 const std::string kValue = "value";
+
+// TODO: sort const
+const std::string kTrainFilePath = "trainFilePath";
+const std::string kValidFilePath = "validFilePath";
+const std::string kTestFilePath = "testFilePath";
+const std::string kOptions = "options";
 }  //  namespace
 
 using namespace common;
@@ -168,6 +174,21 @@ MlInstance::MlInstance()
   REGISTER_METHOD(MLPipelineManagerCustomFilterOutput);
   REGISTER_METHOD(MLPipelineManagerUnregisterCustomFilter);
 
+  REGISTER_METHOD(MLTrainerLayerSetProperty);
+  REGISTER_METHOD(MLTrainerLayerCreate);
+  REGISTER_METHOD(MLTrainerOptimizerSetProperty);
+  REGISTER_METHOD(MLTrainerOptimizerCreate);
+  REGISTER_METHOD(MLTrainerModelCreate);
+  REGISTER_METHOD(MLTrainerModelCompile);
+  REGISTER_METHOD(MLTrainerModelAddLayer);
+  REGISTER_METHOD(MLTrainerModelRun);
+  REGISTER_METHOD(MLTrainerModelSummarize);
+  REGISTER_METHOD(MLTrainerModelSetDataset);
+  REGISTER_METHOD(MLTrainerModelSetOptimizer);
+  REGISTER_METHOD(MLTrainerDatasetCreateGenerator);
+  REGISTER_METHOD(MLTrainerDatasetCreateFromFile);
+  REGISTER_METHOD(MLTrainerDatasetSetProperty);
+
 #undef REGISTER_METHOD
 }
 
@@ -1685,33 +1706,188 @@ void MlInstance::MLPipelineValveIsOpen(const picojson::value& args, picojson::ob
 }
 
 void MlInstance::MLTrainerLayerSetProperty(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kName, std::string, out);
+  CHECK_ARGS(args, kValue, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto name = args.get(kName).get<std::string>();
+  auto value = args.get(kValue).get<std::string>();
+
+  PlatformResult result = trainer_manager_.LayerSetProperty(id, name, value);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
+
 void MlInstance::MLTrainerLayerCreate(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kType, std::string, out);
+  int id = -1;
+
+  LayerType layer_type = LayerType::LAYER_UNKNOWN;
+  PlatformResult result =
+      types::LayerTypeEnum.getValue(args.get(kType).get<std::string>(), &layer_type);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.CreateLayer(id, layer_type);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  out[kId] = picojson::value(static_cast<double>(id));
+  ReportSuccess(out);
 }
 
 void MlInstance::MLTrainerOptimizerSetProperty(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kName, std::string, out);
+  CHECK_ARGS(args, kValue, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto name = args.get(kName).get<std::string>();
+  auto value = args.get(kValue).get<std::string>();
+
+  PlatformResult result = trainer_manager_.OptimizerSetProperty(id, name, value);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
+
 void MlInstance::MLTrainerOptimizerCreate(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kType, std::string, out);
+  int id = -1;
+
+  OptimizerType optimizer_type = OptimizerType::UNKNOWN;
+  PlatformResult result =
+      types::OptimizerTypeEnum.getValue(args.get(kType).get<std::string>(), &optimizer_type);
+  if (!result) {
+    LogAndReportError(result, &out);
+    return;
+  }
+
+  result = trainer_manager_.CreateOptimizer(id, optimizer_type);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  out[kId] = picojson::value(static_cast<double>(id));
+  ReportSuccess(out);
 }
 
-void MlInstance::MLTrainerModelConstruct(const picojson::value& args, picojson::object& out) {
+void MlInstance::MLTrainerModelCreate(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  int id = -1;
+  PlatformResult result;
+  if (args.contains(kModelPath)) {
+    // create model with config file
+    CHECK_ARGS(args, kModelPath, std::string, out);
+    const auto& config_path =
+        common::tools::ConvertUriToPath(args.get(kModelPath).get<std::string>());
+    CHECK_STORAGE_ACCESS(config_path, &out);
+
+    result = trainer_manager_.CreateModel(id, config_path);
+  } else {
+    result = trainer_manager_.CreateModel(id);
+  }
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  out[kId] = picojson::value(static_cast<double>(id));
+  ReportSuccess(out);
 }
+
 void MlInstance::MLTrainerModelCompile(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kOptions, picojson::object, out);
+
+  auto options = args.get(kOptions).get<picojson::object>();
+  auto id = static_cast<int>(args.get(kId).get<double>());
+
+  PlatformResult result = trainer_manager_.ModelCompile(id, options);
+
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
+
 void MlInstance::MLTrainerModelAddLayer(const picojson::value& args, picojson::object& out) {
+  ScopeLogger();
 }
+
 void MlInstance::MLTrainerModelRun(const picojson::value& args, picojson::object& out) {
+  ScopeLogger();
 }
+
 void MlInstance::MLTrainerModelSummarize(const picojson::value& args, picojson::object& out) {
+  ScopeLogger();
 }
+
 void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
+  ScopeLogger();
 }
+
 void MlInstance::MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out) {
+  ScopeLogger();
 }
-void MlInstance::MLTrainerCreateGeneratorDataset(const picojson::value& args,
+
+void MlInstance::MLTrainerDatasetCreateGenerator(const picojson::value& args,
                                                  picojson::object& out) {
+  ScopeLogger();
+}
+
+void MlInstance::MLTrainerDatasetCreateFromFile(const picojson::value& args,
+                                                picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kTrainFilePath, std::string, out);
+  CHECK_ARGS(args, kValidFilePath, std::string, out);
+  CHECK_ARGS(args, kTestFilePath, std::string, out);
+  int id = -1;
+
+  const std::string& train_file_path = args.get(kTrainFilePath).get<std::string>();
+  const std::string& valid_file_path = args.get(kValidFilePath).get<std::string>();
+  const std::string& test_file_path = args.get(kTestFilePath).get<std::string>();
+
+  PlatformResult result =
+      trainer_manager_.CreateFileDataset(id, train_file_path, valid_file_path, test_file_path);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  out[kId] = picojson::value(static_cast<double>(id));
+  ReportSuccess(out);
 }
-void MlInstance::MLTrainerCreateFileDataset(const picojson::value& args, picojson::object& out) {
+
+void MlInstance::MLTrainerDatasetSetProperty(const picojson::value& args, picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+  CHECK_ARGS(args, kId, double, out);
+  CHECK_ARGS(args, kName, std::string, out);
+  CHECK_ARGS(args, kValue, std::string, out);
+
+  auto id = static_cast<int>(args.get(kId).get<double>());
+  auto name = args.get(kName).get<std::string>();
+  auto value = args.get(kValue).get<std::string>();
+
+  PlatformResult result = trainer_manager_.DatasetSetProperty(id, name, value);
+  if (!result) {
+    ReportError(result, &out);
+    return;
+  }
+  ReportSuccess(out);
 }
 
 #undef CHECK_EXIST
index b99890d..da95fcd 100644 (file)
@@ -149,7 +149,7 @@ class MlInstance : public common::ParsedInstance {
 
   void MLPipelineValveIsOpen(const picojson::value& args, picojson::object& out);
 
-  Trainer trainer_manager_;
+  TrainerManager trainer_manager_;
 
   void MLTrainerLayerSetProperty(const picojson::value& args, picojson::object& out);
   void MLTrainerLayerCreate(const picojson::value& args, picojson::object& out);
@@ -157,15 +157,17 @@ class MlInstance : public common::ParsedInstance {
   void MLTrainerOptimizerSetProperty(const picojson::value& args, picojson::object& out);
   void MLTrainerOptimizerCreate(const picojson::value& args, picojson::object& out);
 
-  void MLTrainerModelConstruct(const picojson::value& args, picojson::object& out);
+  void MLTrainerModelCreate(const picojson::value& args, picojson::object& out);
   void MLTrainerModelCompile(const picojson::value& args, picojson::object& out);
   void MLTrainerModelAddLayer(const picojson::value& args, picojson::object& out);
   void MLTrainerModelRun(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSummarize(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out);
   void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out);
-  void MLTrainerCreateGeneratorDataset(const picojson::value& args, picojson::object& out);
-  void MLTrainerCreateFileDataset(const picojson::value& args, picojson::object& out);
+
+  void MLTrainerDatasetCreateGenerator(const picojson::value& args, picojson::object& out);
+  void MLTrainerDatasetCreateFromFile(const picojson::value& args, picojson::object& out);
+  void MLTrainerDatasetSetProperty(const picojson::value& args, picojson::object& out);
 };
 
 }  // namespace ml
index 7cc5d07..35227b0 100644 (file)
@@ -35,6 +35,179 @@ TrainerManager::~TrainerManager() {
   ScopeLogger();
 }
 
+PlatformResult TrainerManager::CreateModel(int& id) {
+  ScopeLogger();
+
+  try {
+    auto model = train::createModel(train::ModelType::NEURAL_NET);
+    models_[next_model_id_] = std::move(model);
+    id = next_model_id_++;
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    // TODO: Add errors handling
+    LoggerE("Could not create model: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::CreateModel(int& id, const std::string config) {
+  ScopeLogger();
+
+  try {
+    auto model = train::createModel(train::ModelType::NEURAL_NET);
+    model->loadFromConfig(config);
+    models_[next_model_id_] = std::move(model);
+    id = next_model_id_++;
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    // TODO: Add errors handling
+    LoggerE("Could not create model: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::ModelCompile(int id, const picojson::object& options) {
+  ScopeLogger();
+
+  if (models_.find(id) == models_.end()) {
+    LoggerE("Could not find model with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+  }
+
+  auto& model = models_[id];
+  std::stringstream ss;
+  for (const auto& opt : options) {
+    const auto& key = opt.first;
+    const auto& value = opt.second.get<std::string>();
+    ss << key << "=" << value;
+    try {
+      model->setProperty({ss.str()});
+    } catch (const std::exception& e) {
+      LoggerE("Could not create set property: %s", e.what());
+      return PlatformResult(ErrorCode::INVALID_VALUES_ERR, e.what());
+    }
+    ss.clear();
+  }
+
+  try {
+    model->compile();
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    // TODO: Add errors handling
+    LoggerE("Could not create model: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::ModelRun(int id, const picojson::object& options) {
+  ScopeLogger();
+  return PlatformResult();
+}
+
+PlatformResult TrainerManager::CreateLayer(int& id, train::LayerType type) {
+  ScopeLogger();
+
+  try {
+    auto layer = train::createLayer(type);
+    layers_[next_layer_id_] = std::move(layer);
+    id = next_layer_id_++;
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    LoggerE("Could not create layer: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::LayerSetProperty(int& id, const std::string& name,
+                                                const std::string& value) {
+  ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
+
+  if (layers_.find(id) == layers_.end()) {
+    LoggerE("Could not find layer with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
+  }
+  auto layer = layers_[id];
+  std::stringstream ss;
+  ss << name << '=' << value;
+  try {
+    layer->setProperty({ss.str()});
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    LoggerE("Failed to set property for layer: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::CreateOptimizer(int& id, train::OptimizerType type) {
+  ScopeLogger();
+
+  try {
+    auto optimizer = train::createOptimizer(type);
+    optimizers_[next_optimizer_id_] = std::move(optimizer);
+    id = next_optimizer_id_++;
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    LoggerE("Could not create optimizer: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::OptimizerSetProperty(int& id, const std::string& name,
+                                                    const std::string& value) {
+  ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
+
+  if (optimizers_.find(id) == optimizers_.end()) {
+    LoggerE("Could not find optimizer with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
+  }
+  auto optimizer = optimizers_[id];
+  std::stringstream ss;
+  ss << name << '=' << value;
+  try {
+    optimizer->setProperty({ss.str()});
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    LoggerE("Failed to set property for optimizer: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string train_file,
+                                                 const std::string valid_file,
+                                                 const std::string test_file) {
+  ScopeLogger();
+  try {
+    auto dataset = train::createDataset(train::DatasetType::FILE, train_file.c_str(),
+                                        valid_file.c_str(), test_file.c_str());
+    datasets_[next_dataset_id_] = std::move(dataset);
+    id = next_layer_id_++;
+    return PlatformResult();
+
+  } catch (const std::exception& e) {
+    LoggerE("Failed to set property for dataset: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
+
+PlatformResult TrainerManager::DatasetSetProperty(int& id, const std::string& name,
+                                                  const std::string& value) {
+  ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
+
+  if (datasets_.find(id) == datasets_.end()) {
+    LoggerE("Could not find dataset with id: %d", id);
+    return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
+  }
+  auto dataset = datasets_[id];
+  std::stringstream ss;
+  ss << name << '=' << value;
+  try {
+    dataset->setProperty({ss.str()});
+    return PlatformResult();
+  } catch (const std::exception& e) {
+    LoggerE("Could not create layer: %s", e.what());
+    return PlatformResult(ErrorCode::ABORT_ERR, e.what());
+  }
+}
 
 }  // namespace ml
 }  // namespace extension
index 17bb576..2b98824 100644 (file)
 #ifndef ML_ML_TRAINER_MANAGER_H_
 #define ML_ML_TRAINER_MANAGER_H_
 
+#include <nntrainer/dataset.h>
+#include <nntrainer/layer.h>
+#include <nntrainer/model.h>
+#include <nntrainer/optimizer.h>
+
 #include <mutex>
 
 #include "common/platform_result.h"
 #include "ml_trainer.h"
 
-#include <nntrainer/model.h>
-
 using common::PlatformResult;
 
 namespace train = ml::train;
@@ -39,13 +42,30 @@ class TrainerManager {
   TrainerManager(const TrainerManager&) = delete;
   TrainerManager& operator=(const TrainerManager&) = delete;
 
+  PlatformResult CreateModel(int& id);
+  PlatformResult CreateModel(int& id, const std::string config);
+  PlatformResult ModelCompile(int id, const picojson::object& options);
+  PlatformResult ModelRun(int id, const picojson::object& options);
+
+  PlatformResult CreateLayer(int& id, train::LayerType type);
+  PlatformResult LayerSetProperty(int& id, const std::string& name, const std::string& value);
+
+  PlatformResult CreateOptimizer(int& id, train::OptimizerType type);
+  PlatformResult OptimizerSetProperty(int& id, const std::string& name, const std::string& value);
+
+  PlatformResult CreateFileDataset(int& id, const std::string train_file,
+                                   const std::string valid_file, const std::string test_file);
+  PlatformResult DatasetSetProperty(int& id, const std::string& name, const std::string& value);
+
  private:
   int next_model_id_ = 0;
   int next_layer_id_ = 0;
   int next_optimizer_id_ = 0;
   int next_dataset_id_ = 0;
   std::map<int, std::unique_ptr<train::Model>> models_;
-  std::mutex trainers_mutex_;
+  std::map<int, std::shared_ptr<train::Optimizer>> optimizers_;
+  std::map<int, std::shared_ptr<train::Layer>> layers_;
+  std::map<int, std::shared_ptr<train::Dataset>> datasets_;
 };
 
 }  // namespace ml
index 2cd17b7..8fe16dc 100644 (file)
@@ -60,6 +60,32 @@ const PlatformEnum<ml_tensor_type_e> TensorTypeEnum{
     {"INT64", ML_TENSOR_TYPE_INT64},     {"UINT64", ML_TENSOR_TYPE_UINT64},
     {"UNKNOWN", ML_TENSOR_TYPE_UNKNOWN}};
 
+// const PlatformEnum<DatasetType> DatasetTypeEnum{{"DATASET_GENERATOR", DatasetType::GENERATOR},
+//                                                 {"DATASET_FILE", DatasetType::FILE},
+//                                                 {"DATASET_UNKNOWN", DatasetType::UNKNOWN}};
+
+const PlatformEnum<OptimizerType> OptimizerTypeEnum{{"OPTIMIZER_ADAM", OptimizerType::ADAM},
+                                                    {"OPTIMIZER_SGD", OptimizerType::SGD},
+                                                    {"OPTIMIZER_UNKNOWN", OptimizerType::UNKNOWN}};
+
+const PlatformEnum<LayerType> LayerTypeEnum{
+    {"LAYER_IN", LayerType::LAYER_IN},
+    {"LAYER_FC", LayerType::LAYER_FC},
+    {"LAYER_BN", LayerType::LAYER_BN},
+    {"LAYER_CONV2D", LayerType::LAYER_CONV2D},
+    {"LAYER_POOLING2D", LayerType::LAYER_POOLING2D},
+    {"LAYER_FLATTEN", LayerType::LAYER_FLATTEN},
+    {"LAYER_ACTIVATION", LayerType::LAYER_ACTIVATION},
+    {"LAYER_ADDITION", LayerType::LAYER_ADDITION},
+    {"LAYER_CONCAT", LayerType::LAYER_CONCAT},
+    {"LAYER_MULTIOUT", LayerType::LAYER_MULTIOUT},
+    {"LAYER_LOSS", LayerType::LAYER_LOSS},
+    {"LAYER_BACKBONE_NNSTREAMER", LayerType::LAYER_BACKBONE_NNSTREAMER},
+    {"LAYER_BACKBONE_TFLITE", LayerType::LAYER_BACKBONE_TFLITE},
+    {"LAYER_EMBEDDING", LayerType::LAYER_EMBEDDING},
+    {"LAYER_RNN", LayerType::LAYER_RNN},
+    {"LAYER_UNKNOWN", LayerType::LAYER_UNKNOWN}};
+
 }  // types
 
 namespace util {
index e91b5aa..8edd6ec 100644 (file)
@@ -18,6 +18,9 @@
 #define ML_ML_UTILS_H_
 
 #include <nnstreamer/nnstreamer.h>
+#include <nntrainer/dataset.h>
+#include <nntrainer/layer.h>
+#include <nntrainer/optimizer.h>
 
 #if __cplusplus > 201402L
 #include <optional>
@@ -31,9 +34,13 @@ using common::optional;
 #include "common/platform_enum.h"
 #include "common/platform_result.h"
 
+using common::ErrorCode;
 using common::PlatformEnum;
 using common::PlatformResult;
-using common::ErrorCode;
+
+using ml::train::DatasetType;
+using ml::train::LayerType;
+using ml::train::OptimizerType;
 
 namespace extension {
 namespace ml {
@@ -44,27 +51,32 @@ extern const PlatformEnum<ml_nnfw_hw_e> HWTypeEnum;
 extern const PlatformEnum<ml_nnfw_type_e> NNFWTypeEnum;
 extern const PlatformEnum<ml_tensor_type_e> TensorTypeEnum;
 
-}  // types
+// extern const PlatformEnum<DatasetType> DatasetTypeEnum;
+extern const PlatformEnum<OptimizerType> OptimizerTypeEnum;
+extern const PlatformEnum<LayerType> LayerTypeEnum;
+
+}  // namespace types
 
 namespace util {
 
-PlatformResult ToPlatformResult(int ml_error_code, const std::string& error_message);
+PlatformResult ToPlatformResult(int ml_error_code,
+                                const std::string& error_message);
 
 bool CheckNNFWAvailability(const std::string& nnfw, const std::string& hw,
                            const optional<std::string> customRequirement);
 
-PlatformResult GetDimensionsFromJsonArray(const picojson::array& dim,
-                                          unsigned int dimensions[ML_TENSOR_RANK_LIMIT]);
-PlatformResult GetLocationFromJsonArray(const picojson::array& array,
-                                        unsigned int location[ML_TENSOR_RANK_LIMIT]);
+PlatformResult GetDimensionsFromJsonArray(
+    const picojson::array& dim, unsigned int dimensions[ML_TENSOR_RANK_LIMIT]);
+PlatformResult GetLocationFromJsonArray(
+    const picojson::array& array, unsigned int location[ML_TENSOR_RANK_LIMIT]);
 
-PlatformResult GetSizeFromJsonArray(const picojson::array& array,
-                                    unsigned int location[ML_TENSOR_RANK_LIMIT],
-                                    unsigned int dimensions[ML_TENSOR_RANK_LIMIT],
-                                    unsigned int size[ML_TENSOR_RANK_LIMIT]);
+PlatformResult GetSizeFromJsonArray(
+    const picojson::array& array, unsigned int location[ML_TENSOR_RANK_LIMIT],
+    unsigned int dimensions[ML_TENSOR_RANK_LIMIT],
+    unsigned int size[ML_TENSOR_RANK_LIMIT]);
 
-}  // util
-}  // ml
-}  // extension
+}  // namespace util
+}  // namespace ml
+}  // namespace extension
 
 #endif  // ML_ML_UTILS_H_