mv_machine_learning: reallocate parser on changing meta file
[platform/core/api/mediavision.git] / mv_machine_learning / landmark_detection / src / pose_landmark_adapter.cpp
1 /**
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "machine_learning_exception.h"
18 #include "pose_landmark_adapter.h"
19 #include "mv_landmark_detection_config.h"
20
21 using namespace std;
22 using namespace MediaVision::Common;
23 using namespace mediavision::machine_learning;
24 using namespace mediavision::machine_learning::exception;
25
26 namespace mediavision
27 {
28 namespace machine_learning
29 {
30 template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter() : _source()
31 {
32         _config = make_shared<LandmarkDetectionConfig>();
33         _config->parseConfigFile(_config_file_name);
34
35         LandmarkDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
36         create(model_type);
37 }
38
39 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
40 {
41         _landmark_detection->preDestroy();
42 }
43
44 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
45 {
46         // If a concrete class object created already exists, reset the object
47         // so that other concrete class object can be created again according to a given task_type.
48         if (_landmark_detection) {
49                 // If default task type is same as a given one then skip.
50                 if (_landmark_detection->getTaskType() == task_type)
51                         return;
52         }
53
54         // if model name is changed by user then reallocate the parser and reload the meta file corresponding to the model name.
55         _config->loadMetaFile(task_type);
56
57         if (task_type == LandmarkDetectionTaskType::PLD_CPM)
58                 _landmark_detection = make_unique<PldCpm>(task_type, _config);
59 }
60
61 template<typename T, typename V>
62 LandmarkDetectionTaskType PoseLandmarkAdapter<T, V>::convertToTaskType(string model_name)
63 {
64         if (model_name.empty())
65                 throw InvalidParameter("model name is empty.");
66
67         transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
68
69         if (model_name == "PLD_CPM")
70                 return LandmarkDetectionTaskType::PLD_CPM;
71
72         throw InvalidParameter("Invalid pose landmark model name.");
73 }
74
75 template<typename T, typename V>
76 void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
77                                                                                          const char *model_name)
78 {
79         try {
80                 _config->setUserModel(model_file, meta_file, label_file);
81
82                 LandmarkDetectionTaskType model_type = convertToTaskType(model_name);
83                 create(model_type);
84         } catch (const BaseException &e) {
85                 LOGW("A given model name is invalid so default task type will be used.");
86         }
87
88         if (!model_file && !meta_file) {
89                 LOGW("Given model info is invalid so default model info will be used instead.");
90                 return;
91         }
92 }
93
94 template<typename T, typename V>
95 void PoseLandmarkAdapter<T, V>::setEngineInfo(const char *engine_type, const char *device_type)
96 {
97         _landmark_detection->setEngineInfo(string(engine_type), string(device_type));
98 }
99
100 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::configure()
101 {
102         _landmark_detection->configure();
103 }
104
105 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
106 {
107         _landmark_detection->getNumberOfEngines(number_of_engines);
108 }
109
110 template<typename T, typename V>
111 void PoseLandmarkAdapter<T, V>::getEngineType(unsigned int engine_index, char **engine_type)
112 {
113         _landmark_detection->getEngineType(engine_index, engine_type);
114 }
115
116 template<typename T, typename V>
117 void PoseLandmarkAdapter<T, V>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
118 {
119         _landmark_detection->getNumberOfDevices(engine_type, number_of_devices);
120 }
121
122 template<typename T, typename V>
123 void PoseLandmarkAdapter<T, V>::getDeviceType(const char *engine_type, unsigned int device_index, char **device_type)
124 {
125         _landmark_detection->getDeviceType(engine_type, device_index, device_type);
126 }
127
128 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::prepare()
129 {
130         _landmark_detection->prepare();
131 }
132
133 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::setInput(T &t)
134 {
135         _source = t;
136 }
137
138 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::perform()
139 {
140         _landmark_detection->perform(_source.inference_src);
141 }
142
143 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::performAsync(T &t)
144 {
145         _landmark_detection->performAsync(t);
146 }
147
148 template<typename T, typename V> V &PoseLandmarkAdapter<T, V>::getOutput()
149 {
150         return _landmark_detection->getOutput();
151 }
152
153 template<typename T, typename V> V &PoseLandmarkAdapter<T, V>::getOutputCache()
154 {
155         throw InvalidOperation("Not support yet.");
156 }
157
158 template class PoseLandmarkAdapter<LandmarkDetectionInput, LandmarkDetectionResult>;
159 }
160 }