2 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __TRAINING_MODEL_H__
18 #define __TRAINING_MODEL_H__
23 #include <mv_inference_type.h>
25 #include "training_engine_error.h"
26 #include "training_engine_common_impl.h"
27 #include "inference_engine_common_impl.h"
28 #include "data_set_manager.h"
29 #include "feature_vector_manager.h"
/**
 * @brief Plain aggregate describing a training-engine backend configuration:
 *        input/output layer names, their tensor descriptions, and the
 *        optimizer and compile properties.
 *
 * NOTE(review): the layer-name vectors and the tensor-info vectors are
 * presumably index-aligned (names[i] describes tensor_info[i]) — confirm
 * against the code that fills this struct.
 */
31 struct TrainingEngineBackendInfo {
/* Names of the model's input layers. */
32 std::vector<std::string> input_layer_names;
/* Names of the model's output layers. */
33 std::vector<std::string> output_layer_names;
/* Tensor descriptions (inference_engine_tensor_info) for the inputs. */
34 std::vector<inference_engine_tensor_info> input_tensor_info;
/* Tensor descriptions (inference_engine_tensor_info) for the outputs. */
35 std::vector<inference_engine_tensor_info> output_tensor_info;
/* Optimizer settings handed to the training engine; field semantics are defined by training_engine_optimizer_property. */
36 training_engine_optimizer_property optimizer_property;
/* Compile settings handed to the training engine; field semantics are defined by training_engine_compile_property. */
37 training_engine_compile_property compile_property;
/**
 * @brief Pure-virtual hook that every concrete training model must implement.
 *        Presumably persists the model to @p file_path — confirm against the
 *        derived-class implementations.
 * @param file_path Target path. Passed as a by-value const std::string, so it
 *        is copied on each call.
 *
 * NOTE(review): const std::string& would avoid the copy, but changing it here
 * would stop existing overrides from overriding; update all overriders together.
 */
43 virtual void saveModel(const std::string file_path) = 0;
/**
 * @brief Pure-virtual hook that every concrete training model must implement.
 *        Presumably deletes the model stored at @p file_path — confirm
 *        against the derived-class implementations.
 * @param file_path Path to operate on (by-value const std::string, copied per call).
 *
 * NOTE(review): same by-value parameter concern as saveModel; changing it
 * requires updating every override in lockstep.
 */
44 virtual void removeModel(const std::string file_path) = 0;
/* Exclusively owned (std::unique_ptr) training-engine front end; released with this object. */
47 std::unique_ptr<TrainingEngineInterface::Common::TrainingEngineCommon> _training;
/* Exclusively owned training_engine_model handle. */
48 std::unique_ptr<training_engine_model> _model;
/* Exclusively owned training_engine_dataset handle. */
49 std::unique_ptr<training_engine_dataset> _data_set;
/* Model file name/path; presumably the internal_model_file given to the constructor — confirm in the .cpp. */
50 std::string _internal_model_file;
/**
 * @brief Constructs the training model for a given backend/target and input shape.
 * @param backend_type        Training backend selector (training_backend_type_e).
 * @param target_type         Target selector (training_target_type_e).
 * @param input_tensor_shape  Input tensor dimensions; passed by value, so the vector is copied.
 * @param internal_model_file Model file name/path; passed by value, so the string is copied.
 *
 * NOTE(review): the vector and string parameters could be const& to avoid
 * copies. The matching definition is outside this header, so both must change
 * together.
 */
53 TrainingModel(const training_backend_type_e backend_type, const training_target_type_e target_type,
54 const std::vector<size_t> input_tensor_shape, const std::string internal_model_file);
/* Virtual so that deleting a derived model through a TrainingModel pointer runs the derived destructor. */
55 virtual ~TrainingModel();
/**
 * @brief Feeds a data set into the model (presumably populating _data_set —
 *        confirm in the implementation).
 * @param data_set Taken as a non-const reference to the caller's unique_ptr,
 *        so the callee is able to read, reset or move from it.
 *        NOTE(review): whether ownership is transferred is not visible here.
 */
57 void applyDataSet(std::unique_ptr<DataSetManager> &data_set);
/**
 * @brief Clears data-set state (exact effect not visible from this declaration).
 * @param data_set Non-const reference to the caller's unique_ptr; the callee
 *        may modify or reset it. NOTE(review): confirm the post-call state.
 */
58 void clearDataSet(std::unique_ptr<DataSetManager> &data_set);
/**
 * @brief Retrieves weights identified by @p name via output parameters.
 * @param weights Out-param receiving a float* (written through float**).
 * @param size    Out-param receiving a size_t count.
 * @param name    Lookup key; passed by value, so the string is copied.
 *
 * NOTE(review): this header does not show who owns the *weights buffer, how
 * long it lives, or whether *size counts elements or bytes. Confirm in the
 * implementation before freeing or indexing.
 */
62 void getWeights(float **weights, size_t *size, std::string name);
/**
 * @brief Pure-virtual hook for derived models to set up their network.
 * @param num_of_class Presumably the number of output classes — confirm in the overrides.
 */
64 virtual void configureModel(int num_of_class) = 0;
/**
 * @brief Pure-virtual accessor for the derived model's backend description.
 * @return Non-const reference, so callers can modify the returned struct in
 *         place. NOTE(review): the referenced object's owner and lifetime are
 *         defined by the overrides — confirm it outlives the reference.
 */
65 virtual TrainingEngineBackendInfo &getTrainingEngineInfo() = 0;