}
};
+Model.prototype.load = function () {
+ var args = validator_.validateArgs(arguments, [
+ {
+ name: 'path',
+ type: types_.STRING
+ },
+ {
+ name: 'format',
+ type: types_.ENUM,
+ values: Object.values(SaveFormat)
+ }
+ ]);
+
+ var callArgs = {
+ id: this._id,
+ savePath: args.path,
+ saveFormat: args.format
+ }
+
+ try {
+ callArgs.savePath = tizen.filesystem.toURI(args.path);
+ } catch (e) {
+ throw new WebAPIException(WebAPIException.InvalidValuesError, 'Invalid file path given');
+ }
+
+ if (!tizen.filesystem.pathExists(callArgs.savePath)) {
+ throw new WebAPIException(WebAPIException.NotFoundError, 'Path not found');
+ }
+
+ var result = native_.callSync('MLTrainerModelLoad', callArgs);
+
+ if (native_.isFailure(result)) {
+ throw native_.getErrorObjectAndValidate(
+ result,
+ ValidModelSaveExceptions,
+ AbortError
+ );
+ }
+};
+
Model.prototype.addLayer = function() {
var args = validator_.validateArgs(arguments, [
{
REGISTER_METHOD(MLTrainerModelSummarize);
REGISTER_METHOD(MLTrainerModelCheckMetrics);
REGISTER_METHOD(MLTrainerModelSave);
+ REGISTER_METHOD(MLTrainerModelLoad);
REGISTER_METHOD(MLTrainerModelSetDataset);
REGISTER_METHOD(MLTrainerModelSetOptimizer);
REGISTER_METHOD(MLTrainerModelDispose);
ReportSuccess(out);
}
+// Handles the synchronous 'MLTrainerModelLoad' call from the JS layer:
+// validates the arguments, resolves the format enum and delegates the
+// actual load to TrainerManager::ModelLoad.
+void MlInstance::MLTrainerModelLoad(const picojson::value& args,
+ picojson::object& out) {
+ ScopeLogger("args: %s", args.serialize().c_str());
+ // Argument keys are shared with the save path (kSavePath/kSaveFormat).
+ CHECK_ARGS(args, kId, double, out);
+ CHECK_ARGS(args, kSavePath, std::string, out);
+ CHECK_ARGS(args, kSaveFormat, std::string, out);
+
+ auto id = static_cast<int>(args.get(kId).get<double>());
+ auto path = args.get(kSavePath).get<std::string>();
+
+ // Default format; overwritten by the enum lookup below when it succeeds.
+ ml_train_model_format_e model_format = ML_TRAIN_MODEL_FORMAT_INI_WITH_BIN;
+ PlatformResult result = types::ModelSaveFormatEnum.getValue(
+ args.get(kSaveFormat).get<std::string>(), &model_format);
+ if (!result) {
+ LogAndReportError(result, &out);
+ return;
+ }
+
+ result = trainer_manager_.ModelLoad(id, path, model_format);
+ if (!result) {
+ ReportError(result, &out);
+ return;
+ }
+ ReportSuccess(out);
+}
+
void MlInstance::MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out) {
ScopeLogger("args: %s", args.serialize().c_str());
CHECK_ARGS(args, kId, double, out);
void MLTrainerModelCheckMetrics(const picojson::value& args,
picojson::object& out);
void MLTrainerModelSave(const picojson::value& args, picojson::object& out);
+ void MLTrainerModelLoad(const picojson::value& args, picojson::object& out);
void MLTrainerModelSetDataset(const picojson::value& args, picojson::object& out);
void MLTrainerModelSetOptimizer(const picojson::value& args, picojson::object& out);
void MLTrainerModelDispose(const picojson::value& args,
model->instanceLock.unlock();
if (ret_val != ML_ERROR_NONE) {
- LoggerE("Could not model to file: %d (%s)", ret_val, ml_strerror(ret_val));
+ LoggerE("Could not save model to file: %d (%s)", ret_val,
+ ml_strerror(ret_val));
+ return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
+ }
+
+ return PlatformResult();
+}
+
+// Loads a model from |path| (in |format|) into the model identified by |id|.
+// Fails with AbortError for an unknown id and NoModificationAllowedError
+// when the model instance is locked (e.g. training in progress).
+PlatformResult TrainerManager::ModelLoad(int id, const std::string& path,
+ ml_train_model_format_e format) {
+ ScopeLogger();
+
+ if (models_.find(id) == models_.end()) {
+ LoggerE("Could not find model with id: %d", id);
+ return PlatformResult(ErrorCode::ABORT_ERR, "Could not find model");
+ }
+
+ auto& model = models_[id];
+ // Non-blocking lock attempt: loading must not run concurrently with
+ // training or any other operation holding the instance lock.
+ bool available = model->instanceLock.try_lock();
+ if (!available) {
+ LoggerE("Model locked - probably training in progress");
+ return PlatformResult(ErrorCode::NO_MODIFICATION_ALLOWED_ERR,
+ "Model training in progress - cannot load");
+ }
+
+ auto tmpString = path;
+ if (tmpString.substr(0, FILE_PATH_PREFIX.length()) == FILE_PATH_PREFIX) {
+ // remove 'file://' prefix from path before passing to native api
+ tmpString.erase(0, FILE_PATH_PREFIX.length());
+ }
+
+ LoggerI("Loading model from file: %s", tmpString.c_str());
+ int ret_val =
+ ml_train_model_load(model->getNative(), tmpString.c_str(), format);
+
+ model->instanceLock.unlock();
+
+ if (ret_val != ML_ERROR_NONE) {
+ LoggerE("Could not load model from file: %d (%s)", ret_val,
+ ml_strerror(ret_val));
return PlatformResult(ErrorCode::ABORT_ERR, ml_strerror(ret_val));
}
std::string& summary);
PlatformResult CheckMetrics(int id, double train_loss, double valid_loss,
double valid_accuracy, bool* result);
- PlatformResult ModelSave(int id,
- const std::string& path,
+ PlatformResult ModelSave(int id, const std::string& path,
+ ml_train_model_format_e format);
+ PlatformResult ModelLoad(int id, const std::string& path,
ml_train_model_format_e format);
PlatformResult ModelDispose(int id);