/**
 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "machine_learning_exception.h"
18 #include "object_detection_adapter.h"
19 #include "object_detection_external.h"
20 #include "mv_object_detection_config.h"
23 using namespace MediaVision::Common;
24 using namespace mediavision::machine_learning;
25 using namespace mediavision::machine_learning::exception;
29 namespace machine_learning
31 template<typename T, typename V> ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source()
33 _config = make_shared<MachineLearningConfig>();
34 _config->parseConfigFile(_config_file_name);
36 ObjectDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
40 template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionAdapter()
42 _object_detection->preDestroy();
45 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
47 // If a concrete class object created already exists, reset the object
48 // so that other concrete class object can be created again according to a given task_type.
49 if (_object_detection) {
50 // If default task type is same as a given one then skip.
51 if (_object_detection->getTaskType() == task_type)
55 if (task_type == ObjectDetectionTaskType::MOBILENET_V1_SSD)
56 _object_detection = make_unique<MobilenetV1Ssd>(task_type, _config);
57 else if (task_type == ObjectDetectionTaskType::MOBILENET_V2_SSD)
58 _object_detection = make_unique<MobilenetV2Ssd>(task_type, _config);
59 else if (task_type == ObjectDetectionTaskType::OD_PLUGIN || task_type == ObjectDetectionTaskType::FD_PLUGIN)
60 _object_detection = make_unique<ObjectDetectionExternal>(task_type);
64 template<typename T, typename V>
65 ObjectDetectionTaskType ObjectDetectionAdapter<T, V>::convertToTaskType(string model_name)
67 if (model_name.empty())
68 throw InvalidParameter("model name is empty.");
70 transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
72 if (model_name == "OD_PLUGIN")
73 return ObjectDetectionTaskType::OD_PLUGIN;
74 else if (model_name == "FD_PLUGIN")
75 return ObjectDetectionTaskType::FD_PLUGIN;
76 else if (model_name == "MOBILENET_V1_SSD")
77 return ObjectDetectionTaskType::MOBILENET_V1_SSD;
78 else if (model_name == "MOBILENET_V2_SSD")
79 return ObjectDetectionTaskType::MOBILENET_V2_SSD;
82 throw InvalidParameter("Invalid object detection model name.");
85 template<typename T, typename V>
86 void ObjectDetectionAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
87 const char *model_name)
90 _config->setUserModel(model_file, meta_file, label_file);
92 ObjectDetectionTaskType model_type = convertToTaskType(model_name);
94 } catch (const BaseException &e) {
95 LOGW("A given model name is invalid so default task type will be used.");
98 if (!model_file && !meta_file) {
99 LOGW("Given model info is invalid so default model info will be used instead.");
104 template<typename T, typename V>
105 void ObjectDetectionAdapter<T, V>::setEngineInfo(const char *engine_type, const char *device_type)
107 _object_detection->setEngineInfo(string(engine_type), string(device_type));
110 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::configure()
112 _object_detection->configure();
115 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
117 _object_detection->getNumberOfEngines(number_of_engines);
120 template<typename T, typename V>
121 void ObjectDetectionAdapter<T, V>::getEngineType(unsigned int engine_index, char **engine_type)
123 _object_detection->getEngineType(engine_index, engine_type);
126 template<typename T, typename V>
127 void ObjectDetectionAdapter<T, V>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
129 _object_detection->getNumberOfDevices(engine_type, number_of_devices);
132 template<typename T, typename V>
133 void ObjectDetectionAdapter<T, V>::getDeviceType(const char *engine_type, unsigned int device_index, char **device_type)
135 _object_detection->getDeviceType(engine_type, device_index, device_type);
138 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::prepare()
140 _object_detection->prepare();
143 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::setInput(T &t)
148 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::perform()
150 _object_detection->perform(_source.inference_src);
153 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutput()
155 return _object_detection->getOutput();
158 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutputCache()
160 return _object_detection->getOutputCache();
163 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::performAsync(T &t)
165 _object_detection->performAsync(t);
// Explicit instantiation definition: emits every member defined in this file for
// the ObjectDetectionInput/ObjectDetectionResult pair, so other translation units
// can use this specialization without seeing the template definitions.
template class ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>;