/**
 * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "machine_learning_exception.h"
18 #include "pose_landmark_adapter.h"
19 #include "mv_landmark_detection_config.h"
22 using namespace MediaVision::Common;
23 using namespace mediavision::machine_learning;
24 using namespace mediavision::machine_learning::exception;
28 namespace machine_learning
30 template<typename T, typename V> PoseLandmarkAdapter<T, V>::PoseLandmarkAdapter() : _source()
32 auto config = make_unique<EngineConfig>(MV_CONFIG_PATH + _config_file_name);
34 string defaultModelName;
36 int ret = config->getStringAttribute(MV_LANDMARK_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
37 if (ret != MEDIA_VISION_ERROR_NONE)
38 throw InvalidOperation("Fail to get default model name.");
40 create(convertToTaskType(defaultModelName));
43 template<typename T, typename V> PoseLandmarkAdapter<T, V>::~PoseLandmarkAdapter()
45 _landmark_detection->preDestroy();
48 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::create(LandmarkDetectionTaskType task_type)
50 // If a concrete class object created already exists, reset the object
51 // so that other concrete class object can be created again according to a given task_type.
52 if (_landmark_detection) {
53 // If default task type is same as a given one then skip.
54 if (_landmark_detection->getTaskType() == task_type)
58 if (task_type == LandmarkDetectionTaskType::PLD_CPM)
59 _landmark_detection = make_unique<PldCpm>(task_type);
62 template<typename T, typename V>
63 LandmarkDetectionTaskType PoseLandmarkAdapter<T, V>::convertToTaskType(string model_name)
65 if (model_name.empty())
66 throw InvalidParameter("model name is empty.");
68 transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
70 LandmarkDetectionTaskType task_type = LandmarkDetectionTaskType::LANDMARK_DETECTION_TASK_NONE;
72 if (model_name == "PLD_CPM")
73 return LandmarkDetectionTaskType::PLD_CPM;
75 throw InvalidParameter("Invalid pose landmark model name.");
78 template<typename T, typename V>
79 void PoseLandmarkAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
80 const char *model_name)
83 create(convertToTaskType(model_name));
84 } catch (const BaseException &e) {
85 LOGW("A given model name is invalid so default task type will be used.");
89 _model_file = model_file;
91 _meta_file = meta_file;
93 _label_file = label_file;
95 if (_model_file.empty() && _meta_file.empty()) {
96 LOGW("Given model info is invalid so default model info will be used instead.");
100 _landmark_detection->setUserModel(_model_file, _meta_file, _label_file);
103 template<typename T, typename V>
104 void PoseLandmarkAdapter<T, V>::setEngineInfo(const char *engine_type, const char *device_type)
106 _landmark_detection->setEngineInfo(string(engine_type), string(device_type));
109 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::configure()
111 _landmark_detection->configure(_config_file_name);
114 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
116 _landmark_detection->getNumberOfEngines(number_of_engines);
119 template<typename T, typename V>
120 void PoseLandmarkAdapter<T, V>::getEngineType(unsigned int engine_index, char **engine_type)
122 _landmark_detection->getEngineType(engine_index, engine_type);
125 template<typename T, typename V>
126 void PoseLandmarkAdapter<T, V>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
128 _landmark_detection->getNumberOfDevices(engine_type, number_of_devices);
131 template<typename T, typename V>
132 void PoseLandmarkAdapter<T, V>::getDeviceType(const char *engine_type, unsigned int device_index, char **device_type)
134 _landmark_detection->getDeviceType(engine_type, device_index, device_type);
137 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::prepare()
139 _landmark_detection->prepare();
142 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::setInput(T &t)
147 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::perform()
149 _landmark_detection->perform(_source.inference_src);
152 template<typename T, typename V> void PoseLandmarkAdapter<T, V>::performAsync(T &t)
154 _landmark_detection->performAsync(t);
157 template<typename T, typename V> V &PoseLandmarkAdapter<T, V>::getOutput()
159 return _landmark_detection->getOutput();
162 template<typename T, typename V> V &PoseLandmarkAdapter<T, V>::getOutputCache()
164 throw InvalidOperation("Not support yet.");
167 template class PoseLandmarkAdapter<LandmarkDetectionInput, LandmarkDetectionResult>;