/**
 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
27 #include <mv_private.h>
29 #include "machine_learning_exception.h"
30 #include "training_model.h"
33 using namespace TrainingEngineInterface::Common;
34 using namespace mediavision::machine_learning::exception;
36 TrainingModel::TrainingModel(const training_backend_type_e backend_type, const training_target_type_e target_type,
37 vector<size_t> input_tensor_shape, const string internal_model_file)
39 _internal_model_file = internal_model_file;
40 _training = make_unique<TrainingEngineInterface::Common::TrainingEngineCommon>();
42 training_engine_config config = { "", backend_type, target_type };
43 int ret = _training->BindBackend(&config);
44 if (ret != TRAINING_ENGINE_ERROR_NONE)
45 throw InvalidOperation("Fail to bind backend engine.");
47 training_engine_capacity capacity = { TRAINING_TENSOR_SHAPE_MIN };
48 ret = _training->GetBackendCapacity(capacity);
49 if (ret != TRAINING_ENGINE_ERROR_NONE)
50 throw InvalidOperation("Fail to get backend capacity.");
53 TrainingModel::~TrainingModel()
56 _training->UnbindBackend();
59 void TrainingModel::ApplyDataSet(unique_ptr<DataSetManager> &data_set)
61 auto &values = data_set->GetData();
62 auto &labels = data_set->GetLabel();
64 LOGD("Generating feature vectors for training");
66 _data_set = _training->CreateDataset();
68 throw InvalidOperation("Fail to create a dataset.");
72 for (size_t idx = 0; idx < values.size(); ++idx) {
73 ret = _training->AddDataToDataset(_data_set.get(), values[idx], labels[idx], TRAINING_DATASET_TYPE_TRAIN);
74 if (ret != TRAINING_ENGINE_ERROR_NONE)
75 throw InvalidOperation("Fail to add data to dataset.", ret);
78 ret = _training->SetDataset(_model.get(), _data_set.get());
79 if (ret != TRAINING_ENGINE_ERROR_NONE)
80 throw InvalidOperation("Fail to set dataset to model.", ret);
83 void TrainingModel::ClearDataSet(unique_ptr<DataSetManager> &data_set)
86 _training->DestroyDataset(_data_set.get());
89 void TrainingModel::Compile()
91 auto optimizer = _training->CreateOptimizer(TRAINING_OPTIMIZER_TYPE_SGD);
93 throw InvalidOperation("Fail to create a optimizer.");
95 int ret = _training->SetOptimizerProperty(optimizer.get(), GetTrainingEngineInfo().optimizer_property);
96 if (ret != TRAINING_ENGINE_ERROR_NONE)
97 throw InvalidOperation("Fail to set optimizer property.", ret);
99 ret = _training->AddOptimizer(_model.get(), optimizer.get());
100 if (ret != TRAINING_ENGINE_ERROR_NONE)
101 throw InvalidOperation("Fail to add optimizer to model.", ret);
103 ret = _training->CompileModel(_model.get(), GetTrainingEngineInfo().compile_property);
104 if (ret != TRAINING_ENGINE_ERROR_NONE)
105 throw InvalidOperation("Fail to compile model.", ret);
108 void TrainingModel::Train()
110 training_engine_model_property model_property;
111 int ret = _training->TrainModel(_model.get(), model_property);
112 if (ret != TRAINING_ENGINE_ERROR_NONE)
113 throw InvalidOperation("Fail to train model.", ret);
116 SaveModel(_internal_model_file);
119 void TrainingModel::getWeights(float **weights, size_t *size, std::string name)
124 void TrainingModel::RemoveModel()
126 RemoveModel(_internal_model_file);