2 * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ml_trainer_manager.h"
19 #include "common/tools.h"
21 using common::ErrorCode;
22 using common::PlatformResult;
// Delimiter appended after every "key=value" entry while option strings are
// assembled; the trailing one is removed before the native call.
const std::string OPTION_SEPARATOR{" | "};
// URI scheme prefix that is removed from file paths before they are passed to
// the native API.
const std::string FILE_PATH_PREFIX{"file://"};
// Default constructor of TrainerManager; takes no arguments.
// NOTE(review): the handle maps and next_*_id_ counters used by the methods of
// this class are presumably initialized in the class declaration - confirm.
30 TrainerManager::TrainerManager() {
// Destructor of TrainerManager.
// NOTE(review): none of the methods in this file destroy the native handles
// stored in models_, layers_, optimizers_ and datasets_ - confirm whether they
// are expected to be released here or elsewhere.
34 TrainerManager::~TrainerManager() {
38 PlatformResult TrainerManager::CreateModel(int& id) {
41 ml_train_model_h n_model = NULL;
43 int ret_val = ml_train_model_construct(&n_model);
45 LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
46 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
49 models_[next_model_id_] = n_model;
50 id = next_model_id_++;
52 return PlatformResult();
55 PlatformResult TrainerManager::CreateModel(int& id, const std::string config) {
58 ml_train_model_h n_model = NULL;
60 int ret_val = ml_train_model_construct_with_conf(config.c_str(), &n_model);
62 LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
63 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
66 models_[next_model_id_] = n_model;
67 id = next_model_id_++;
69 return PlatformResult();
72 PlatformResult TrainerManager::ModelCompile(int id,
73 const picojson::object& options) {
76 if (models_.find(id) == models_.end()) {
77 LoggerE("Could not find model with id: %d", id);
78 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
81 auto& model = models_[id];
84 for (const auto& opt : options) {
85 const auto& key = opt.first;
86 if (opt.second.is<std::string>()) {
87 const auto& value = opt.second.get<std::string>();
88 ss << key << "=" << value << OPTION_SEPARATOR;
89 } else if (opt.second.is<double>()) {
90 const auto& value = opt.second.get<double>();
91 ss << key << "=" << value << OPTION_SEPARATOR;
93 LoggerE("Unexpected param type for: %s", key.c_str());
94 return PlatformResult(ErrorCode::ABORT_ERR,
95 "Unexpected param type for:" + key);
100 auto compileOpts = ss.str();
101 if (compileOpts.length() < OPTION_SEPARATOR.length()) {
102 ret_val = ml_train_model_compile(model, NULL);
104 // remove trailing ' | ' from options string
106 compileOpts.substr(0, compileOpts.length() - OPTION_SEPARATOR.length());
107 LoggerI("Compiling model with options: %s", compileOpts.c_str());
108 ret_val = ml_train_model_compile(model, compileOpts.c_str(), NULL);
114 LoggerE("Could not compile model: %d (%s)", ret_val, ml_strerror(ret_val));
115 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
118 return PlatformResult();
121 PlatformResult TrainerManager::ModelRun(int id,
122 const picojson::object& options) {
125 if (models_.find(id) == models_.end()) {
126 LoggerE("Could not find model with id: %d", id);
127 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
130 auto& model = models_[id];
132 std::stringstream ss;
133 for (const auto& opt : options) {
134 const auto& key = opt.first;
135 if (opt.second.is<std::string>()) {
136 const auto& value = opt.second.get<std::string>();
137 ss << key << "=" << value << OPTION_SEPARATOR;
138 } else if (opt.second.is<double>()) {
139 const auto& value = opt.second.get<double>();
140 ss << key << "=" << value << OPTION_SEPARATOR;
142 LoggerE("Unexpected param type for: %s", key.c_str());
143 return PlatformResult(ErrorCode::ABORT_ERR,
144 "Unexpected param type for:" + key);
149 auto runOpts = ss.str();
151 if (runOpts.length() < OPTION_SEPARATOR.length()) {
152 ret_val = ml_train_model_run(model, NULL);
154 // remove trailing ' | ' from options string
155 runOpts = runOpts.substr(0, runOpts.length() - OPTION_SEPARATOR.length());
156 LoggerI("Running model with options: %s", runOpts.c_str());
157 ret_val = ml_train_model_run(model, runOpts.c_str(), NULL);
161 LoggerE("Could not run (train) model: %d (%s)", ret_val,
162 ml_strerror(ret_val));
163 return PlatformResult(ErrorCode::UNKNOWN_ERR, ml_strerror(ret_val));
166 return PlatformResult();
169 PlatformResult TrainerManager::ModelAddLayer(int id, int layerId) {
172 if (models_.find(id) == models_.end()) {
173 LoggerE("Could not find model with id: %d", id);
174 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
177 if (layers_.find(layerId) == layers_.end()) {
178 LoggerE("Could not find layer with id: %d", id);
179 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
182 auto& model = models_[id];
183 auto& layer = layers_[layerId];
185 int ret_val = ml_train_model_add_layer(model, layer);
187 LoggerE("Could not add layer to model: %d (%s)", ret_val,
188 ml_strerror(ret_val));
189 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
192 return PlatformResult();
195 PlatformResult TrainerManager::ModelSetOptimizer(int id, int optimizerId) {
198 if (models_.find(id) == models_.end()) {
199 LoggerE("Could not find model with id: %d", id);
200 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
203 if (optimizers_.find(optimizerId) == optimizers_.end()) {
204 LoggerE("Could not find optimizer with id: %d", id);
205 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
208 auto& model = models_[id];
209 auto& optimizer = optimizers_[optimizerId];
211 int ret_val = ml_train_model_set_optimizer(model, optimizer);
213 LoggerE("Could not set optimizer for model: %d (%s)", ret_val,
214 ml_strerror(ret_val));
215 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
218 return PlatformResult();
221 PlatformResult TrainerManager::ModelSetDataset(int id, int datasetId) {
224 if (models_.find(id) == models_.end()) {
225 LoggerE("Could not find model with id: %d", id);
226 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
229 if (datasets_.find(datasetId) == datasets_.end()) {
230 LoggerE("Could not find dataset with id: %d", id);
231 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
234 auto& model = models_[id];
235 auto& dataset = datasets_[datasetId];
237 int ret_val = ml_train_model_set_dataset(model, dataset);
239 LoggerE("Could not set dataset for model: %d (%s)", ret_val,
240 ml_strerror(ret_val));
241 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
244 return PlatformResult();
247 PlatformResult TrainerManager::ModelSummarize(int id,
248 ml_train_summary_type_e level,
249 std::string& summary) {
252 if (models_.find(id) == models_.end()) {
253 LoggerE("Could not find model with id: %d", id);
254 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
257 auto& model = models_[id];
258 char* tmpSummary = NULL;
260 int ret_val = ml_train_model_get_summary(model, level, &tmpSummary);
263 LoggerE("Could not get summary for model: %d (%s)", ret_val,
264 ml_strerror(ret_val));
265 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
268 summary = tmpSummary;
271 return PlatformResult();
274 PlatformResult TrainerManager::ModelSave(int id,
275 const std::string& path,
276 ml_train_model_format_e format) {
279 if (models_.find(id) == models_.end()) {
280 LoggerE("Could not find model with id: %d", id);
281 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
284 auto& model = models_[id];
286 auto tmpString = path;
287 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
288 // remove 'file://' prefix from path before passing to native api
289 tmpString.erase(0, FILE_PATH_PREFIX.length());
292 LoggerI("Saving model to file: %s", tmpString.c_str());
293 int ret_val = ml_train_model_save(model, tmpString.c_str(), format);
296 LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
297 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
300 return PlatformResult();
303 PlatformResult TrainerManager::CreateLayer(int& id,
304 ml_train_layer_type_e type) {
307 ml_train_layer_h n_layer = NULL;
309 int ret_val = ml_train_layer_create(&n_layer, type);
311 LoggerE("Could not create layer: %s", ml_strerror(ret_val));
312 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
315 layers_[next_layer_id_] = n_layer;
316 id = next_layer_id_++;
317 return PlatformResult();
320 PlatformResult TrainerManager::LayerSetProperty(int id, const std::string& name,
321 const std::string& value) {
322 ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
324 if (layers_.find(id) == layers_.end()) {
325 LoggerE("Could not find layer with id: %d", id);
326 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
329 auto layer = layers_[id];
330 std::string opt = name + "=" + value;
332 int ret_val = ml_train_layer_set_property(layer, opt.c_str(), NULL);
334 LoggerE("Could not set layer property: %d (%s)", ret_val,
335 ml_strerror(ret_val));
336 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
338 return PlatformResult();
341 PlatformResult TrainerManager::CreateOptimizer(int& id,
342 ml_train_optimizer_type_e type) {
345 ml_train_optimizer_h n_optimizer = NULL;
347 int ret_val = ml_train_optimizer_create(&n_optimizer, type);
349 LoggerE("Could not create optimizer: %d (%s)", ret_val,
350 ml_strerror(ret_val));
351 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
354 optimizers_[next_optimizer_id_] = n_optimizer;
355 id = next_optimizer_id_++;
356 return PlatformResult();
359 PlatformResult TrainerManager::OptimizerSetProperty(int id,
360 const std::string& name,
361 const std::string& value) {
362 ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
364 if (optimizers_.find(id) == optimizers_.end()) {
365 LoggerE("Could not find optimizer with id: %d", id);
366 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
369 auto optimizer = optimizers_[id];
370 std::string opt = name + "=" + value;
371 int ret_val = ml_train_optimizer_set_property(optimizer, opt.c_str(), NULL);
373 LoggerE("Could not set optimizer property: %d (%s)", ret_val,
374 ml_strerror(ret_val));
375 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
377 return PlatformResult();
380 PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string train_file,
381 const std::string valid_file,
382 const std::string test_file) {
385 ml_train_dataset_h n_dataset = NULL;
387 int ret_val = ml_train_dataset_create(&n_dataset);
389 LoggerE("Could not create dataset: %s", ml_strerror(ret_val));
390 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
393 if (!train_file.empty()) {
394 auto tmpString = train_file;
395 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
396 // remove 'file://' prefix from path before passing to native api
397 tmpString.erase(0, FILE_PATH_PREFIX.length());
400 ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TRAIN,
403 LoggerE("Could not add train file %s to dataset: %s", tmpString.c_str(),
404 ml_strerror(ret_val));
405 ml_train_dataset_destroy(n_dataset);
406 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
410 if (!valid_file.empty()) {
411 auto tmpString = valid_file;
412 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
413 // remove 'file://' prefix from path before passing to native api
414 tmpString.erase(0, FILE_PATH_PREFIX.length());
416 ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_VALID,
419 LoggerE("Could not add validation file %s to dataset: %s",
420 tmpString.c_str(), ml_strerror(ret_val));
421 ml_train_dataset_destroy(n_dataset);
422 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
426 if (!test_file.empty()) {
427 auto tmpString = test_file;
428 if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
429 // remove 'file://' prefix from path before passing to native api
430 tmpString.erase(0, FILE_PATH_PREFIX.length());
432 ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TEST,
435 LoggerE("Could not add test file %s to dataset: %s", tmpString.c_str(),
436 ml_strerror(ret_val));
437 ml_train_dataset_destroy(n_dataset);
438 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
442 datasets_[next_dataset_id_] = n_dataset;
443 id = next_dataset_id_++;
444 return PlatformResult();
// MK-TODO: add support for creating a Dataset from a generator
449 PlatformResult TrainerManager::DatasetSetProperty(int id,
450 const std::string& name,
451 const std::string& value) {
452 ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
454 if (datasets_.find(id) == datasets_.end()) {
455 LoggerE("Could not find dataset with id: %d", id);
456 return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
459 auto dataset = datasets_[id];
460 std::string opt = name + "=" + value;
462 // ml_train_dataset_set_property() is marked as deprecated
463 // temporary set same property for all modes (all data files) if possible
464 int ret_val = ml_train_dataset_set_property_for_mode(
465 dataset, ML_TRAIN_DATASET_MODE_TRAIN, opt.c_str(), NULL);
467 LoggerE("Could not set dataset property for train mode: %d (%s)", ret_val,
468 ml_strerror(ret_val));
469 return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
472 ret_val = ml_train_dataset_set_property_for_mode(
473 dataset, ML_TRAIN_DATASET_MODE_VALID, opt.c_str(), NULL);
475 LoggerE("Could not set dataset property for validation mode: %d (%s)",
476 ret_val, ml_strerror(ret_val));
477 // MK-TODO report error for each file when extracted to separate functions
478 // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
481 ret_val = ml_train_dataset_set_property_for_mode(
482 dataset, ML_TRAIN_DATASET_MODE_TEST, opt.c_str(), NULL);
484 LoggerE("Could not set dataset property for test mode: %d (%s)", ret_val,
485 ml_strerror(ret_val));
486 // MK-TODO report error for each file when extracted to separate functions
487 // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
490 return PlatformResult();
494 } // namespace extension