ee49b9ee8707fe90fdfba2c7402fd3a543d34f05
[platform/core/api/mediavision.git] / mv_machine_learning / training / include / training_model.h
1 /**
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __TRAINING_MODEL_H__
18 #define __TRAINING_MODEL_H__
19
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
22
23 #include <mv_inference_type.h>
24
25 #include "training_engine_error.h"
26 #include "training_engine_common_impl.h"
27 #include "inference_engine_common_impl.h"
28 #include "data_set_manager.h"
29 #include "feature_vector_manager.h"
30
31 typedef struct {
32         std::vector<std::string> input_layer_names;
33         std::vector<std::string> output_layer_names;
34         std::vector<inference_engine_tensor_info> input_tensor_info;
35         std::vector<inference_engine_tensor_info> output_tensor_info;
36         training_engine_optimizer_property optimizer_property;
37         training_engine_compile_property compile_property;
38 } TrainingEngineBackendInfo;
39
40 class TrainingModel
41 {
42 private:
43         virtual void SaveModel(const std::string file_path) = 0;
44         virtual void RemoveModel(const std::string file_path) = 0;
45
46 protected:
47         std::unique_ptr<TrainingEngineInterface::Common::TrainingEngineCommon> _training;
48         std::unique_ptr<training_engine_model> _model;
49         std::unique_ptr<training_engine_dataset> _data_set;
50         std::string _internal_model_file;
51
52 public:
53         TrainingModel(const training_backend_type_e backend_type, const training_target_type_e target_type,
54                                   const std::vector<size_t> input_tensor_shape, const std::string internal_model_file);
55         virtual ~TrainingModel();
56
57         void ApplyDataSet(std::unique_ptr<DataSetManager> &data_set);
58         void ClearDataSet(std::unique_ptr<DataSetManager> &data_set);
59         void Compile();
60         void Train();
61         void RemoveModel();
62         void getWeights(float **weights, size_t *size, std::string name);
63
64         virtual void ConfigureModel(int num_of_class) = 0;
65         virtual TrainingEngineBackendInfo &GetTrainingEngineInfo() = 0;
66 };
67
68 #endif