[ML][Training] Model saving to file
[platform/core/api/webapi-plugins.git] / src / ml / ml_trainer_manager.cc
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  *    Licensed under the Apache License, Version 2.0 (the "License");
5  *    you may not use this file except in compliance with the License.
6  *    You may obtain a copy of the License at
7  *
8  *        http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *    Unless required by applicable law or agreed to in writing, software
11  *    distributed under the License is distributed on an "AS IS" BASIS,
12  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *    See the License for the specific language governing permissions and
14  *    limitations under the License.
15  */
16
17 #include "ml_trainer_manager.h"
18
19 #include "common/tools.h"
20
21 using common::ErrorCode;
22 using common::PlatformResult;
23
24 namespace extension {
25 namespace ml {
26
27 const std::string OPTION_SEPARATOR = " | ";
28 const std::string FILE_PATH_PREFIX = "file://";
29
30 TrainerManager::TrainerManager() {
31   ScopeLogger();
32 }
33
34 TrainerManager::~TrainerManager() {
35   ScopeLogger();
36 }
37
38 PlatformResult TrainerManager::CreateModel(int& id) {
39   ScopeLogger();
40
41   ml_train_model_h n_model = NULL;
42
43   int ret_val = ml_train_model_construct(&n_model);
44   if (ret_val != 0) {
45     LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
46     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
47   }
48
49   models_[next_model_id_] = n_model;
50   id = next_model_id_++;
51
52   return PlatformResult();
53 }
54
55 PlatformResult TrainerManager::CreateModel(int& id, const std::string config) {
56   ScopeLogger();
57
58   ml_train_model_h n_model = NULL;
59
60   int ret_val = ml_train_model_construct_with_conf(config.c_str(), &n_model);
61   if (ret_val != 0) {
62     LoggerE("Could not create model: %d (%s)", ret_val, ml_strerror(ret_val));
63     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
64   }
65
66   models_[next_model_id_] = n_model;
67   id = next_model_id_++;
68
69   return PlatformResult();
70 }
71
72 PlatformResult TrainerManager::ModelCompile(int id,
73                                             const picojson::object& options) {
74   ScopeLogger();
75
76   if (models_.find(id) == models_.end()) {
77     LoggerE("Could not find model with id: %d", id);
78     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
79   }
80
81   auto& model = models_[id];
82
83   std::stringstream ss;
84   for (const auto& opt : options) {
85     const auto& key = opt.first;
86     if (opt.second.is<std::string>()) {
87       const auto& value = opt.second.get<std::string>();
88       ss << key << "=" << value << OPTION_SEPARATOR;
89     } else if (opt.second.is<double>()) {
90       const auto& value = opt.second.get<double>();
91       ss << key << "=" << value << OPTION_SEPARATOR;
92     } else {
93       LoggerE("Unexpected param type for: %s", key.c_str());
94       return PlatformResult(ErrorCode::ABORT_ERR,
95                             "Unexpected param type for:" + key);
96     }
97   }
98
99   int ret_val = 0;
100   auto compileOpts = ss.str();
101   if (compileOpts.length() < OPTION_SEPARATOR.length()) {
102     ret_val = ml_train_model_compile(model, NULL);
103   } else {
104     // remove trailing ' | ' from options string
105     compileOpts =
106         compileOpts.substr(0, compileOpts.length() - OPTION_SEPARATOR.length());
107     LoggerI("Compiling model with options: %s", compileOpts.c_str());
108     ret_val = ml_train_model_compile(model, compileOpts.c_str(), NULL);
109   }
110
111   ss.clear();
112
113   if (ret_val != 0) {
114     LoggerE("Could not compile model: %d (%s)", ret_val, ml_strerror(ret_val));
115     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
116   }
117
118   return PlatformResult();
119 }
120
121 PlatformResult TrainerManager::ModelRun(int id,
122                                         const picojson::object& options) {
123   ScopeLogger();
124
125   if (models_.find(id) == models_.end()) {
126     LoggerE("Could not find model with id: %d", id);
127     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
128   }
129
130   auto& model = models_[id];
131
132   std::stringstream ss;
133   for (const auto& opt : options) {
134     const auto& key = opt.first;
135     if (opt.second.is<std::string>()) {
136       const auto& value = opt.second.get<std::string>();
137       ss << key << "=" << value << OPTION_SEPARATOR;
138     } else if (opt.second.is<double>()) {
139       const auto& value = opt.second.get<double>();
140       ss << key << "=" << value << OPTION_SEPARATOR;
141     } else {
142       LoggerE("Unexpected param type for: %s", key.c_str());
143       return PlatformResult(ErrorCode::ABORT_ERR,
144                             "Unexpected param type for:" + key);
145     }
146   }
147
148   int ret_val = 0;
149   auto runOpts = ss.str();
150
151   if (runOpts.length() < OPTION_SEPARATOR.length()) {
152     ret_val = ml_train_model_run(model, NULL);
153   } else {
154     // remove trailing ' | ' from options string
155     runOpts = runOpts.substr(0, runOpts.length() - OPTION_SEPARATOR.length());
156     LoggerI("Running model with options: %s", runOpts.c_str());
157     ret_val = ml_train_model_run(model, runOpts.c_str(), NULL);
158   }
159
160   if (ret_val != 0) {
161     LoggerE("Could not run (train) model: %d (%s)", ret_val,
162             ml_strerror(ret_val));
163     return PlatformResult(ErrorCode::UNKNOWN_ERR, ml_strerror(ret_val));
164   }
165
166   return PlatformResult();
167 }
168
169 PlatformResult TrainerManager::ModelAddLayer(int id, int layerId) {
170   ScopeLogger();
171
172   if (models_.find(id) == models_.end()) {
173     LoggerE("Could not find model with id: %d", id);
174     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
175   }
176
177   if (layers_.find(layerId) == layers_.end()) {
178     LoggerE("Could not find layer with id: %d", id);
179     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
180   }
181
182   auto& model = models_[id];
183   auto& layer = layers_[layerId];
184
185   int ret_val = ml_train_model_add_layer(model, layer);
186   if (ret_val != 0) {
187     LoggerE("Could not add layer to model: %d (%s)", ret_val,
188             ml_strerror(ret_val));
189     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
190   }
191
192   return PlatformResult();
193 }
194
195 PlatformResult TrainerManager::ModelSetOptimizer(int id, int optimizerId) {
196   ScopeLogger();
197
198   if (models_.find(id) == models_.end()) {
199     LoggerE("Could not find model with id: %d", id);
200     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
201   }
202
203   if (optimizers_.find(optimizerId) == optimizers_.end()) {
204     LoggerE("Could not find optimizer with id: %d", id);
205     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
206   }
207
208   auto& model = models_[id];
209   auto& optimizer = optimizers_[optimizerId];
210
211   int ret_val = ml_train_model_set_optimizer(model, optimizer);
212   if (ret_val != 0) {
213     LoggerE("Could not set optimizer for model: %d (%s)", ret_val,
214             ml_strerror(ret_val));
215     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
216   }
217
218   return PlatformResult();
219 }
220
221 PlatformResult TrainerManager::ModelSetDataset(int id, int datasetId) {
222   ScopeLogger();
223
224   if (models_.find(id) == models_.end()) {
225     LoggerE("Could not find model with id: %d", id);
226     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
227   }
228
229   if (datasets_.find(datasetId) == datasets_.end()) {
230     LoggerE("Could not find dataset with id: %d", id);
231     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
232   }
233
234   auto& model = models_[id];
235   auto& dataset = datasets_[datasetId];
236
237   int ret_val = ml_train_model_set_dataset(model, dataset);
238   if (ret_val != 0) {
239     LoggerE("Could not set dataset for model: %d (%s)", ret_val,
240             ml_strerror(ret_val));
241     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
242   }
243
244   return PlatformResult();
245 }
246
247 PlatformResult TrainerManager::ModelSummarize(int id,
248                                               ml_train_summary_type_e level,
249                                               std::string& summary) {
250   ScopeLogger();
251
252   if (models_.find(id) == models_.end()) {
253     LoggerE("Could not find model with id: %d", id);
254     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
255   }
256
257   auto& model = models_[id];
258   char* tmpSummary = NULL;
259
260   int ret_val = ml_train_model_get_summary(model, level, &tmpSummary);
261
262   if (ret_val != 0) {
263     LoggerE("Could not get summary for model: %d (%s)", ret_val,
264             ml_strerror(ret_val));
265     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
266   }
267
268   summary = tmpSummary;
269   free(tmpSummary);
270
271   return PlatformResult();
272 }
273
274 PlatformResult TrainerManager::ModelSave(int id,
275                                          const std::string& path,
276                                          ml_train_model_format_e format) {
277   ScopeLogger();
278
279   if (models_.find(id) == models_.end()) {
280     LoggerE("Could not find model with id: %d", id);
281     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
282   }
283
284   auto& model = models_[id];
285
286   auto tmpString = path;
287   if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
288     // remove 'file://' prefix from path before passing to native api
289     tmpString.erase(0, FILE_PATH_PREFIX.length());
290   }
291
292   LoggerI("Saving model to file: %s", tmpString.c_str());
293   int ret_val = ml_train_model_save(model, tmpString.c_str(), format);
294
295   if (ret_val != 0) {
296     LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
297     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
298   }
299
300   return PlatformResult();
301 }
302
303 PlatformResult TrainerManager::CreateLayer(int& id,
304                                            ml_train_layer_type_e type) {
305   ScopeLogger();
306
307   ml_train_layer_h n_layer = NULL;
308
309   int ret_val = ml_train_layer_create(&n_layer, type);
310   if (ret_val != 0) {
311     LoggerE("Could not create layer: %s", ml_strerror(ret_val));
312     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
313   }
314
315   layers_[next_layer_id_] = n_layer;
316   id = next_layer_id_++;
317   return PlatformResult();
318 }
319
320 PlatformResult TrainerManager::LayerSetProperty(int id, const std::string& name,
321                                                 const std::string& value) {
322   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
323
324   if (layers_.find(id) == layers_.end()) {
325     LoggerE("Could not find layer with id: %d", id);
326     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find layer");
327   }
328
329   auto layer = layers_[id];
330   std::string opt = name + "=" + value;
331
332   int ret_val = ml_train_layer_set_property(layer, opt.c_str(), NULL);
333   if (ret_val != 0) {
334     LoggerE("Could not set layer property: %d (%s)", ret_val,
335             ml_strerror(ret_val));
336     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
337   }
338   return PlatformResult();
339 }
340
341 PlatformResult TrainerManager::CreateOptimizer(int& id,
342                                                ml_train_optimizer_type_e type) {
343   ScopeLogger();
344
345   ml_train_optimizer_h n_optimizer = NULL;
346
347   int ret_val = ml_train_optimizer_create(&n_optimizer, type);
348   if (ret_val != 0) {
349     LoggerE("Could not create optimizer: %d (%s)", ret_val,
350             ml_strerror(ret_val));
351     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
352   }
353
354   optimizers_[next_optimizer_id_] = n_optimizer;
355   id = next_optimizer_id_++;
356   return PlatformResult();
357 }
358
359 PlatformResult TrainerManager::OptimizerSetProperty(int id,
360                                                     const std::string& name,
361                                                     const std::string& value) {
362   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
363
364   if (optimizers_.find(id) == optimizers_.end()) {
365     LoggerE("Could not find optimizer with id: %d", id);
366     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find optimizer");
367   }
368
369   auto optimizer = optimizers_[id];
370   std::string opt = name + "=" + value;
371   int ret_val = ml_train_optimizer_set_property(optimizer, opt.c_str(), NULL);
372   if (ret_val != 0) {
373     LoggerE("Could not set optimizer property: %d (%s)", ret_val,
374             ml_strerror(ret_val));
375     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
376   }
377   return PlatformResult();
378 }
379
380 PlatformResult TrainerManager::CreateFileDataset(int& id, const std::string train_file,
381                                                  const std::string valid_file,
382                                                  const std::string test_file) {
383   ScopeLogger();
384
385   ml_train_dataset_h n_dataset = NULL;
386
387   int ret_val = ml_train_dataset_create(&n_dataset);
388   if (ret_val != 0) {
389     LoggerE("Could not create dataset: %s", ml_strerror(ret_val));
390     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
391   }
392
393   if (!train_file.empty()) {
394     auto tmpString = train_file;
395     if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
396       // remove 'file://' prefix from path before passing to native api
397       tmpString.erase(0, FILE_PATH_PREFIX.length());
398     }
399
400     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TRAIN,
401                                         tmpString.c_str());
402     if (ret_val != 0) {
403       LoggerE("Could not add train file %s to dataset: %s", tmpString.c_str(),
404               ml_strerror(ret_val));
405       ml_train_dataset_destroy(n_dataset);
406       return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
407     }
408   }
409
410   if (!valid_file.empty()) {
411     auto tmpString = valid_file;
412     if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
413       // remove 'file://' prefix from path before passing to native api
414       tmpString.erase(0, FILE_PATH_PREFIX.length());
415     }
416     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_VALID,
417                                         tmpString.c_str());
418     if (ret_val != 0) {
419       LoggerE("Could not add validation file %s to dataset: %s",
420               tmpString.c_str(), ml_strerror(ret_val));
421       ml_train_dataset_destroy(n_dataset);
422       return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
423     }
424   }
425
426   if (!test_file.empty()) {
427     auto tmpString = test_file;
428     if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
429       // remove 'file://' prefix from path before passing to native api
430       tmpString.erase(0, FILE_PATH_PREFIX.length());
431     }
432     ret_val = ml_train_dataset_add_file(n_dataset, ML_TRAIN_DATASET_MODE_TEST,
433                                         tmpString.c_str());
434     if (ret_val != 0) {
435       LoggerE("Could not add test file %s to dataset: %s", tmpString.c_str(),
436               ml_strerror(ret_val));
437       ml_train_dataset_destroy(n_dataset);
438       return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
439     }
440   }
441
442   datasets_[next_dataset_id_] = n_dataset;
443   id = next_dataset_id_++;
444   return PlatformResult();
445 }
446
447 // MK-TODO Add creating Dataset with generator
448
449 PlatformResult TrainerManager::DatasetSetProperty(int id,
450                                                   const std::string& name,
451                                                   const std::string& value) {
452   ScopeLogger("id: %d, name: %s, value: %s", id, name.c_str(), value.c_str());
453
454   if (datasets_.find(id) == datasets_.end()) {
455     LoggerE("Could not find dataset with id: %d", id);
456     return PlatformResult(ErrorCode::ABORT_ERR, "Could not find dataset");
457   }
458
459   auto dataset = datasets_[id];
460   std::string opt = name + "=" + value;
461
462   // ml_train_dataset_set_property() is marked as deprecated
463   // temporary set same property for all modes (all data files) if possible
464   int ret_val = ml_train_dataset_set_property_for_mode(
465       dataset, ML_TRAIN_DATASET_MODE_TRAIN, opt.c_str(), NULL);
466   if (ret_val != 0) {
467     LoggerE("Could not set dataset property for train mode: %d (%s)", ret_val,
468             ml_strerror(ret_val));
469     return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
470   }
471
472   ret_val = ml_train_dataset_set_property_for_mode(
473       dataset, ML_TRAIN_DATASET_MODE_VALID, opt.c_str(), NULL);
474   if (ret_val != 0) {
475     LoggerE("Could not set dataset property for validation mode: %d (%s)",
476             ret_val, ml_strerror(ret_val));
477     // MK-TODO report error for each file when extracted to separate functions
478     // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
479   }
480
481   ret_val = ml_train_dataset_set_property_for_mode(
482       dataset, ML_TRAIN_DATASET_MODE_TEST, opt.c_str(), NULL);
483   if (ret_val != 0) {
484     LoggerE("Could not set dataset property for test mode: %d (%s)", ret_val,
485             ml_strerror(ret_val));
486     // MK-TODO report error for each file when extracted to separate functions
487     // return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
488   }
489
490   return PlatformResult();
491 }
492
493 }  // namespace ml
494 }  // namespace extension