mv_machine_learning: use DEFAULT_MODEL_NAME for landmark detection task group
[platform/core/api/mediavision.git] / mv_machine_learning / landmark_detection / src / pose_landmark_adapter.cpp
1 /**
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "machine_learning_exception.h"
18 #include "pose_landmark_adapter.h"
19 #include "mv_landmark_detection_config.h"
20
21 using namespace std;
22 using namespace MediaVision::Common;
23 using namespace mediavision::machine_learning;
24 using namespace mediavision::machine_learning::exception;
25
26 namespace mediavision
27 {
28 namespace machine_learning
29 {
30 template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter() : _source()
31 {
32         auto config = make_unique<EngineConfig>(MV_CONFIG_PATH + _config_file_name);
33
34         string defaultModelName;
35
36         int ret = config->getStringAttribute(MV_LANDMARK_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
37         if (ret != MEDIA_VISION_ERROR_NONE)
38                 throw InvalidOperation("Fail to get default model name.");
39
40         create(convertToTaskType(defaultModelName));
41 }
42
43 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
44 {
45         _landmark_detection->preDestroy();
46 }
47
48 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
49 {
50         // If a concrete class object created already exists, reset the object
51         // so that other concrete class object can be created again according to a given task_type.
52         if (_landmark_detection) {
53                 // If default task type is same as a given one then skip.
54                 if (_landmark_detection->getTaskType() == task_type)
55                         return;
56         }
57
58         if (task_type == LandmarkDetectionTaskType::PLD_CPM)
59                 _landmark_detection = make_unique<PldCpm>(task_type);
60 }
61
62 template<typename T, typename V>
63 LandmarkDetectionTaskType PoseLandmarkAdapter<T, V>::convertToTaskType(string model_name)
64 {
65         if (model_name.empty())
66                 throw InvalidParameter("model name is empty.");
67
68         transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
69
70         LandmarkDetectionTaskType task_type = LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE;
71
72         if (model_name == "PLD_CPM")
73                 return LandmarkDetectionTaskType::PLD_CPM;
74
75         throw InvalidParameter("Invalid pose landmark model name.");
76 }
77
78 template<typename T, typename V>
79 void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
80                                                                                          const char *model_name)
81 {
82         try {
83                 create(convertToTaskType(model_name));
84         } catch (const BaseException &e) {
85                 LOGW("A given model name is invalid so default task type will be used.");
86         }
87
88         if (model_file)
89                 _model_file = model_file;
90         if (meta_file)
91                 _meta_file = meta_file;
92         if (label_file)
93                 _label_file = label_file;
94
95         if (_model_file.empty() && _meta_file.empty()) {
96                 LOGW("Given model info is invalid so default model info will be used instead.");
97                 return;
98         }
99
100         _landmark_detection->setUserModel(_model_file, _meta_file, _label_file);
101 }
102
103 template<typename T, typename V>
104 void PoseLandmarkAdapter<T, V>::setEngineInfo(const char *engine_type, const char *device_type)
105 {
106         _landmark_detection->setEngineInfo(string(engine_type), string(device_type));
107 }
108
109 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::configure()
110 {
111         _landmark_detection->configure(_config_file_name);
112 }
113
114 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
115 {
116         _landmark_detection->getNumberOfEngines(number_of_engines);
117 }
118
119 template<typename T, typename V>
120 void PoseLandmarkAdapter<T, V>::getEngineType(unsigned int engine_index, char **engine_type)
121 {
122         _landmark_detection->getEngineType(engine_index, engine_type);
123 }
124
125 template<typename T, typename V>
126 void PoseLandmarkAdapter<T, V>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
127 {
128         _landmark_detection->getNumberOfDevices(engine_type, number_of_devices);
129 }
130
131 template<typename T, typename V>
132 void PoseLandmarkAdapter<T, V>::getDeviceType(const char *engine_type, unsigned int device_index, char **device_type)
133 {
134         _landmark_detection->getDeviceType(engine_type, device_index, device_type);
135 }
136
137 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::prepare()
138 {
139         _landmark_detection->prepare();
140 }
141
142 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::setInput(T &t)
143 {
144         _source = t;
145 }
146
147 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::perform()
148 {
149         _landmark_detection->perform(_source.inference_src);
150 }
151
152 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::performAsync(T &t)
153 {
154         _landmark_detection->performAsync(t);
155 }
156
157 template<typename T, typename V> V &PoseLandmarkAdapter<T, V>::getOutput()
158 {
159         return _landmark_detection->getOutput();
160 }
161
162 template<typename T, typename V> V &PoseLandmarkAdapter<T, V>::getOutputCache()
163 {
164         throw InvalidOperation("Not support yet.");
165 }
166
167 template class PoseLandmarkAdapter<LandmarkDetectionInput, LandmarkDetectionResult>;
168 }
169 }