2 * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ml_trainer_manager.h"
19 #include "common/tools.h"
21 using common::ErrorCode;
22 using common::PlatformResult;
27 const std::string OPTION_SEPARATOR = " | ";
28 const std::string FILE_PATH_PREFIX = "file://";
30 TrainerManager::TrainerManager() {
34 TrainerManager::~TrainerManager() {
38 PlatformResult TrainerManager::CreateModel(int& id) {
41 ml_train_model_h n_model = NULL;
43 int ret_val = ml_train_model_construct(&n_model);
44 if (ret_val != ML_ERROR_NONE) {
45 LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
46 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
49 models_[next_model_id_] = std::make_shared<Model>(n_model);
50 id = next_model_id_++;
52 return PlatformResult();
55 PlatformResult TrainerManager::CreateModel(int& id, const std::string config) {
58 ml_train_model_h n_model = NULL;
60 int ret_val = ml_train_model_construct_with_conf(config.c_str(), &n_model);
61 if (ret_val != ML_ERROR_NONE) {
62 LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
63 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
66 models_[next_model_id_] = std::make_shared<Model>(n_model);
67 id = next_model_id_++;
69 return PlatformResult();
72 PlatformResult TrainerManager::ModelCompile(int id,
73 const picojson::object& options) {
76 if (models_.find(id) == models_.end()) {
77 LoggerE("Could not find model with id: %d", id);
78 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
81 auto& model = models_[id];
84 for (const auto& opt : options) {
85 const auto& key = opt.first;
86 if (opt.second.is<std::string>()) {
87 const auto& value = opt.second.get<std::string>();
88 ss << key << "=" << value << OPTION_SEPARATOR;
89 } else if (opt.second.is<double>()) {
90 const auto& value = opt.second.get<double>();
91 ss << key << "=" << value << OPTION_SEPARATOR;
93 LoggerE("Unexpected param type for: %s", key.c_str());
94 return PlatformResult(ErrorCode::ABORT_ERR,
95 "Unexpected param type for:" + key);
100 auto compileOpts = ss.str();
101 if (compileOpts.length() < OPTION_SEPARATOR.length()) {
102 ret_val = ml_train_model_compile(model->getNative(), NULL);
104 // remove trailing ' | ' from options string
106 compileOpts.substr(0, compileOpts.length() - OPTION_SEPARATOR.length());
107 LoggerI("Compiling model with options: %s", compileOpts.c_str());
109 ml_train_model_compile(model->getNative(), compileOpts.c_str(), NULL);
114 if (ret_val != ML_ERROR_NONE) {
115 LoggerE("Could not compile model: %d (%s)", ret_val, ml_strerror(ret_val));
116 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
118 model->setCompiled(true);
120 return PlatformResult();
123 PlatformResult TrainerManager::ModelRun(int id,
124 const picojson::object& options) {
127 if (models_.find(id) == models_.end()) {
128 LoggerE("Could not find model with id: %d", id);
129 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
132 auto& model = models_[id];
134 if (!model->isCompiled()) {
135 LoggerE("Trying to train model that is not compiled");
136 return PlatformResult(ErrorCode::INVALID_STATE_ERR,
137 "Cannot train model before compilation");
140 std::stringstream ss;
141 for (const auto& opt : options) {
142 const auto& key = opt.first;
143 if (opt.second.is<std::string>()) {
144 const auto& value = opt.second.get<std::string>();
145 ss << key << "=" << value << OPTION_SEPARATOR;
146 } else if (opt.second.is<double>()) {
147 const auto& value = opt.second.get<double>();
148 ss << key << "=" << value << OPTION_SEPARATOR;
150 LoggerE("Unexpected param type for: %s", key.c_str());
151 return PlatformResult(ErrorCode::ABORT_ERR,
152 "Unexpected param type for:" + key);
157 auto runOpts = ss.str();
159 if (runOpts.length() < OPTION_SEPARATOR.length()) {
160 ret_val = ml_train_model_run(model->getNative(), NULL);
162 // remove trailing ' | ' from options string
163 runOpts = runOpts.substr(0, runOpts.length() - OPTION_SEPARATOR.length());
164 LoggerI("Running model with options: %s", runOpts.c_str());
165 ret_val = ml_train_model_run(model->getNative(), runOpts.c_str(), NULL);
168 if (ret_val != ML_ERROR_NONE) {
169 LoggerE("Could not run (train) model: %d (%s)", ret_val,
170 ml_strerror(ret_val));
171 return PlatformResult(ErrorCode::UNKNOWN_ERR, ml_strerror(ret_val));
174 return PlatformResult();
177 PlatformResult TrainerManager::ModelAddLayer(int id, int layerId) {
180 if (models_.find(id) == models_.end()) {
181 LoggerE("Could not find model with id: %d", id);
182 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
185 if (layers_.find(layerId) == layers_.end()) {
186 LoggerE("Could not find layer with id: %d", id);
187 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
190 auto& model = models_[id];
191 auto& layer = layers_[layerId];
193 if (model->isCompiled()) {
194 LoggerE("Modification of compiled model");
195 return PlatformResult(ErrorCode::INVALID_STATE_ERR,
196 "Modification of compiled model not allowed");
200 ml_train_model_add_layer(model->getNative(), layer->getNative());
201 if (ret_val != ML_ERROR_NONE) {
202 LoggerE("Could not add layer to model: %d (%s)", ret_val,
203 ml_strerror(ret_val));
204 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
207 model->layerIndices.push_back(layerId);
208 layer->setAttached(true);
210 return PlatformResult();
213 PlatformResult TrainerManager::ModelSetOptimizer(int id, int optimizerId) {
216 if (models_.find(id) == models_.end()) {
217 LoggerE("Could not find model with id: %d", id);
218 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
221 if (optimizers_.find(optimizerId) == optimizers_.end()) {
222 LoggerE("Could not find optimizer with id: %d", id);
223 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
226 auto& model = models_[id];
227 auto& optimizer = optimizers_[optimizerId];
228 if (model->isCompiled()) {
229 LoggerE("Modification of compiled model");
230 return PlatformResult(ErrorCode::INVALID_STATE_ERR,
231 "Modification of compiled model not allowed");
235 ml_train_model_set_optimizer(model->getNative(), optimizer->getNative());
236 if (ret_val != ML_ERROR_NONE) {
237 LoggerE("Could not set optimizer for model: %d (%s)", ret_val,
238 ml_strerror(ret_val));
239 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
242 if (model->optimizerIndex != INVALID_INDEX) {
243 // "release" optimizer that has been set before
244 auto optPosition = optimizers_.find(model->optimizerIndex);
245 if (optPosition != optimizers_.end()) {
246 (*optPosition).second->setAttached(false);
248 // This should never happen but just in case check and log such situation
250 "Attached optimizer does not exist in map - some internal error "
255 model->optimizerIndex = optimizerId;
256 optimizer->setAttached(true);
258 return PlatformResult();
261 PlatformResult TrainerManager::ModelSetDataset(int id, int datasetId) {
264 if (models_.find(id) == models_.end()) {
265 LoggerE("Could not find model with id: %d", id);
266 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
269 if (datasets_.find(datasetId) == datasets_.end()) {
270 LoggerE("Could not find dataset with id: %d", id);
271 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
274 auto& model = models_[id];
275 auto& dataset = datasets_[datasetId];
277 if (model->isCompiled()) {
278 LoggerE("Modification of compiled model");
279 return PlatformResult(ErrorCode::INVALID_STATE_ERR,
280 "Modification of compiled model not allowed");
284 ml_train_model_set_dataset(model->getNative(), dataset->getNative());
285 if (ret_val != ML_ERROR_NONE) {
286 LoggerE("Could not set dataset for model: %d (%s)", ret_val,
287 ml_strerror(ret_val));
288 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
291 if (model->datasetIndex != INVALID_INDEX) {
292 // "release" dataset that has been set before
293 auto datPosition = datasets_.find(model->datasetIndex);
294 if (datPosition != datasets_.end()) {
295 (*datPosition).second->setAttached(false);
297 // This should never happen but just in case check and log such situation
299 "Attached dataset does not exist in map = some internal error faced");
303 model->datasetIndex = datasetId;
304 dataset->setAttached(true);
306 return PlatformResult();
309 PlatformResult TrainerManager::ModelSummarize(int id,
310 ml_train_summary_type_e level,
311 std::string& summary) {
314 if (models_.find(id) == models_.end()) {
315 LoggerE("Could not find model with id: %d", id);
316 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
319 auto& model = models_[id];
320 char* tmpSummary = NULL;
323 ml_train_model_get_summary(model->getNative(), level, &tmpSummary);
325 if (ret_val != ML_ERROR_NONE) {
326 LoggerE("Could not get summary for model: %d (%s)", ret_val,
327 ml_strerror(ret_val));
328 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
331 summary = tmpSummary;
334 return PlatformResult();
337 PlatformResult TrainerManager::ModelSave(int id,
338 const std::string& path,
339 ml_train_model_format_e format) {
342 if (models_.find(id) == models_.end()) {
343 LoggerE("Could not find model with id: %d", id);
344 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
347 auto& model = models_[id];
349 auto tmpString = path;
350 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
351 // remove 'file://' prefix from path before passing to native api
352 tmpString.erase(0, FILE_PATH_PREFIX.length());
355 LoggerI("Saving model to file: %s", tmpString.c_str());
357 ml_train_model_save(model->getNative(), tmpString.c_str(), format);
359 if (ret_val != ML_ERROR_NONE) {
360 LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
361 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
364 return PlatformResult();
367 PlatformResult TrainerManager::ModelDispose(int id) {
370 if (models_.find(id) == models_.end()) {
371 LoggerE("Could not find model with id: %d", id);
372 return PlatformResult(ErrorCode::NOT_FOUND_ERR, "Could not find model");
375 auto model = models_[id];
377 int ret_val = ml_train_model_destroy(model->getNative());
378 if (ret_val != ML_ERROR_NONE) {
379 LoggerE("Could not destroy model: %d (%s)", ret_val, ml_strerror(ret_val));
380 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
383 // When model is destroyed by ml_train_model_destroy() then all attached
384 // handles (layers, optimizer, dataset) are also destroyed. This means that
385 // after Model disposal all related objects in JS/C++ layer become invalid.
386 // Code below is removing all wrappers stored in TrainerManager based on
387 // identifiers taken from Model wrapper
388 if (model->optimizerIndex >= 0) {
389 LoggerD("Deleting attached optimizer: %d", model->optimizerIndex);
390 optimizers_.erase(model->optimizerIndex);
392 if (model->datasetIndex) {
393 LoggerD("Deleting attached dataset: %d", model->datasetIndex);
394 datasets_.erase(model->datasetIndex);
396 for (auto const& ls : model->layerIndices) {
397 LoggerD("Deleting attached layer: %d", ls);
402 return PlatformResult();
405 PlatformResult TrainerManager::CreateLayer(int& id,
406 ml_train_layer_type_e type) {
409 ml_train_layer_h n_layer = NULL;
411 int ret_val = ml_train_layer_create(&n_layer, type);
412 if (ret_val != ML_ERROR_NONE) {
413 LoggerE("Could not create layer: %s", ml_strerror(ret_val));
414 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
417 layers_[next_layer_id_] =
418 std::make_shared<NativeWrapper<ml_train_layer_h>>(n_layer);
419 id = next_layer_id_++;
420 return PlatformResult();
423 PlatformResult TrainerManager::LayerSetProperty(int id, const std::string& name,
424 const std::string& value) {
425 ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
427 if (layers_.find(id) == layers_.end()) {
428 LoggerE("Could not find layer with id: %d", id);
429 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
432 auto layer = layers_[id];
433 std::string opt = name + "=" + value;
436 ml_train_layer_set_property(layer->getNative(), opt.c_str(), NULL);
437 if (ret_val != ML_ERROR_NONE) {
438 LoggerE("Could not set layer property: %d (%s)", ret_val,
439 ml_strerror(ret_val));
440 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
442 return PlatformResult();
445 PlatformResult TrainerManager::LayerDispose(int id) {
448 if (layers_.find(id) == layers_.end()) {
449 LoggerE("Could not find layer with id: %d", id);
450 return PlatformResult(ErrorCode::NOT_FOUND_ERR, "Could not find layer");
453 auto layer = layers_[id];
454 // Layer added to Model cannot be destroyed individually.
455 // It will be destroyed when destroying the Model
456 // see comment in TrainerManager::ModelDispose()
457 if (layer->isAttached()) {
458 LoggerE("Trying to dispose layer attached to model");
459 return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR,
460 "Cannot dispose layer attached to model");
463 int ret_val = ml_train_layer_destroy(layer->getNative());
464 if (ret_val != ML_ERROR_NONE) {
465 LoggerE("Could not destroy layer: %d (%s)", ret_val, ml_strerror(ret_val));
466 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
471 return PlatformResult();
474 PlatformResult TrainerManager::CreateOptimizer(int& id,
475 ml_train_optimizer_type_e type) {
478 ml_train_optimizer_h n_optimizer = NULL;
480 int ret_val = ml_train_optimizer_create(&n_optimizer, type);
481 if (ret_val != ML_ERROR_NONE) {
482 LoggerE("Could not create optimizer: %d (%s)", ret_val,
483 ml_strerror(ret_val));
484 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
487 optimizers_[next_optimizer_id_] =
488 std::make_shared<NativeWrapper<ml_train_optimizer_h>>(n_optimizer);
489 id = next_optimizer_id_++;
490 return PlatformResult();
493 PlatformResult TrainerManager::OptimizerSetProperty(int id,
494 const std::string& name,
495 const std::string& value) {
496 ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
498 if (optimizers_.find(id) == optimizers_.end()) {
499 LoggerE("Could not find optimizer with id: %d", id);
500 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
503 auto optimizer = optimizers_[id];
504 std::string opt = name + "=" + value;
505 int ret_val = ml_train_optimizer_set_property(optimizer->getNative(),
507 if (ret_val != ML_ERROR_NONE) {
508 LoggerE("Could not set optimizer property: %d (%s)", ret_val,
509 ml_strerror(ret_val));
510 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
512 return PlatformResult();
515 PlatformResult TrainerManager::OptimizerDispose(int id) {
518 if (optimizers_.find(id) == optimizers_.end()) {
519 LoggerE("Could not find optimizer with id: %d", id);
520 return PlatformResult(ErrorCode::NOT_FOUND_ERR, "Could not find optimizer");
523 auto optimizer = optimizers_[id];
524 // Optimizer set to Model cannot be destroyed individually.
525 // It will be destroyed when destroying the Model
526 // see comment in TrainerManager::ModelDispose()
527 if (optimizer->isAttached()) {
528 LoggerE("Trying to dispose optimizer attached to model");
529 return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR,
530 "Cannot dispose optimizer attached to model");
533 int ret_val = ml_train_optimizer_destroy(optimizer->getNative());
534 if (ret_val != ML_ERROR_NONE) {
535 LoggerE("Could not destroy optimizer: %d (%s)", ret_val,
536 ml_strerror(ret_val));
537 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
540 optimizers_.erase(id);
542 return PlatformResult();
545 PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string train_file,
546 const std::string valid_file,
547 const std::string test_file) {
550 ml_train_dataset_h n_dataset = NULL;
552 int ret_val = ml_train_dataset_create(&n_dataset);
553 if (ret_val != ML_ERROR_NONE) {
554 LoggerE("Could not create dataset: %s", ml_strerror(ret_val));
555 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
558 if (!train_file.empty()) {
559 auto tmpString = train_file;
560 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
561 // remove 'file://' prefix from path before passing to native api
562 tmpString.erase(0, FILE_PATH_PREFIX.length());
565 ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TRAIN,
567 if (ret_val != ML_ERROR_NONE) {
568 LoggerE("Could not add train file %s to dataset: %s", tmpString.c_str(),
569 ml_strerror(ret_val));
570 ml_train_dataset_destroy(n_dataset);
571 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
575 if (!valid_file.empty()) {
576 auto tmpString = valid_file;
577 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
578 // remove 'file://' prefix from path before passing to native api
579 tmpString.erase(0, FILE_PATH_PREFIX.length());
581 ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_VALID,
583 if (ret_val != ML_ERROR_NONE) {
584 LoggerE("Could not add validation file %s to dataset: %s",
585 tmpString.c_str(), ml_strerror(ret_val));
586 ml_train_dataset_destroy(n_dataset);
587 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
591 if (!test_file.empty()) {
592 auto tmpString = test_file;
593 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
594 // remove 'file://' prefix from path before passing to native api
595 tmpString.erase(0, FILE_PATH_PREFIX.length());
597 ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TEST,
599 if (ret_val != ML_ERROR_NONE) {
600 LoggerE("Could not add test file %s to dataset: %s", tmpString.c_str(),
601 ml_strerror(ret_val));
602 ml_train_dataset_destroy(n_dataset);
603 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
607 datasets_[next_dataset_id_] =
608 std::make_shared<NativeWrapper<ml_train_dataset_h>>(n_dataset);
609 id = next_dataset_id_++;
610 return PlatformResult();
// TODO(mk): Add support for creating a Dataset from a data-generator callback
// (ml_train_dataset_add_generator) in addition to file-backed datasets.
615 PlatformResult TrainerManager::DatasetSetProperty(
617 const std::string& name,
618 const std::string& value,
619 ml_train_dataset_mode_e mode) {
620 ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
622 if (datasets_.find(id) == datasets_.end()) {
623 LoggerE("Could not find dataset with id: %d", id);
624 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
627 auto dataset = datasets_[id];
628 std::string opt = name + "=" + value;
630 int ret_val = ml_train_dataset_set_property_for_mode(dataset->getNative(),
631 mode, opt.c_str(), NULL);
632 if (ret_val != ML_ERROR_NONE) {
633 LoggerE("Could not set dataset property for mode %d: %d (%s)", mode,
634 ret_val, ml_strerror(ret_val));
635 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
638 return PlatformResult();
641 PlatformResult TrainerManager::DatasetDispose(int id) {
644 if (datasets_.find(id) == datasets_.end()) {
645 LoggerE("Could not find dataset with id: %d", id);
646 return PlatformResult(ErrorCode::NOT_FOUND_ERR, "Could not find dataset");
649 auto dataset = datasets_[id];
650 // Dataset set to Model cannot be destroyed individually.
651 // It will be destroyed when destroying the Model
652 // see comment in TrainerManager::ModelDispose()
653 if (dataset->isAttached()) {
654 LoggerE("Trying to dispose dataset attached to model");
655 return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR,
656 "Cannot dispose dataset attached to model");
659 int ret_val = ml_train_dataset_destroy(dataset->getNative());
660 if (ret_val != ML_ERROR_NONE) {
661 LoggerE("Could not destroy dataset: %d (%s)", ret_val,
662 ml_strerror(ret_val));
663 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
668 return PlatformResult();
672 } // namespace extension