/**
 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "machine_learning_exception.h"
18 #include "object_detection_adapter.h"
19 #include "object_detection_external.h"
20 #include "mv_object_detection_config.h"
23 using namespace MediaVision::Common;
24 using namespace mediavision::machine_learning;
25 using namespace mediavision::machine_learning::exception;
29 namespace machine_learning
// Constructs the adapter: creates a shared MachineLearningConfig (later handed to the
// concrete detectors built by create()), parses the configuration file named by
// _config_file_name, and maps the configured default model name to an
// ObjectDetectionTaskType.
// NOTE(review): model_type is presumably passed to create() right after it is computed;
// otherwise the default task type would be resolved but never used — confirm.
31 template<typename T, typename V> ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source()
33 _config = make_shared<MachineLearningConfig>();
34 _config->parseConfigFile(_config_file_name);
// convertToTaskType() throws InvalidParameter when the name is empty or not a supported model.
36 ObjectDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
40 template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionAdapter()
42 _object_detection->preDestroy();
45 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
47 _config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(task_type)));
48 mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
51 case ObjectDetectionTaskType::MOBILENET_V1_SSD:
52 if (dataType == MV_INFERENCE_DATA_UINT8)
53 _object_detection = make_unique<MobilenetV1Ssd<unsigned char> >(task_type, _config);
54 else if (dataType == MV_INFERENCE_DATA_FLOAT32)
55 _object_detection = make_unique<MobilenetV1Ssd<float> >(task_type, _config);
57 throw InvalidOperation("Invalid model data type.");
59 case ObjectDetectionTaskType::MOBILENET_V2_SSD:
60 if (dataType == MV_INFERENCE_DATA_UINT8)
61 _object_detection = make_unique<MobilenetV2Ssd<unsigned char> >(task_type, _config);
62 else if (dataType == MV_INFERENCE_DATA_FLOAT32)
63 _object_detection = make_unique<MobilenetV2Ssd<float> >(task_type, _config);
65 throw InvalidOperation("Invalid model data type.");
67 case ObjectDetectionTaskType::OD_PLUGIN:
68 _object_detection = make_unique<ObjectDetectionExternal>(task_type);
70 case ObjectDetectionTaskType::FD_PLUGIN:
71 _object_detection = make_unique<ObjectDetectionExternal>(task_type);
74 throw InvalidOperation("Invalid object detection task type.");
79 template<typename T, typename V>
80 ObjectDetectionTaskType ObjectDetectionAdapter<T, V>::convertToTaskType(string model_name)
82 if (model_name.empty())
83 throw InvalidParameter("model name is empty.");
85 transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
87 if (model_name == "OD_PLUGIN")
88 return ObjectDetectionTaskType::OD_PLUGIN;
89 else if (model_name == "FD_PLUGIN")
90 return ObjectDetectionTaskType::FD_PLUGIN;
91 else if (model_name == "MOBILENET_V1_SSD")
92 return ObjectDetectionTaskType::MOBILENET_V1_SSD;
93 else if (model_name == "MOBILENET_V2_SSD")
94 return ObjectDetectionTaskType::MOBILENET_V2_SSD;
97 throw InvalidParameter("Invalid object detection model name.");
// Applies user-supplied model information: the model, meta and label file paths are
// passed to _config->setUserModel(), and model_name is mapped to a task type with
// convertToTaskType(). A BaseException raised in the guarded region is caught and only
// logged as a warning, so those failures are not reported to the caller.
// NOTE(review): model_type is presumably handed to create() inside the guarded region —
// confirm.
// NOTE(review): model_name is implicitly converted to std::string for convertToTaskType().
// A null model_name makes that conversion undefined behaviour (libstdc++ throws
// std::logic_error, which the BaseException handler below does not catch). Consider
// guarding it.
100 template<typename T, typename V>
101 void ObjectDetectionAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
102 const char *model_name)
105 _config->setUserModel(model_file, meta_file, label_file);
107 ObjectDetectionTaskType model_type = convertToTaskType(model_name);
109 } catch (const BaseException &e) {
// NOTE(review): this message blames the model name, but any BaseException from the guarded
// region lands here (possibly also one from setUserModel(), if it sits in the same try
// block). Confirm the wording.
110 LOGW("A given model name is invalid so default task type will be used.");
// NOTE(review): this check only runs after setUserModel() has already received model_file
// and meta_file, so it can only warn. Confirm whether it should run before they are applied.
113 if (!model_file && !meta_file) {
114 LOGW("Given model info is invalid so default model info will be used instead.");
119 template<typename T, typename V>
120 void ObjectDetectionAdapter<T, V>::setEngineInfo(const char *engine_type, const char *device_type)
122 _object_detection->setEngineInfo(string(engine_type), string(device_type));
125 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::configure()
127 _object_detection->configure();
130 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
132 _object_detection->getNumberOfEngines(number_of_engines);
135 template<typename T, typename V>
136 void ObjectDetectionAdapter<T, V>::getEngineType(unsigned int engine_index, char **engine_type)
138 _object_detection->getEngineType(engine_index, engine_type);
141 template<typename T, typename V>
142 void ObjectDetectionAdapter<T, V>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
144 _object_detection->getNumberOfDevices(engine_type, number_of_devices);
147 template<typename T, typename V>
148 void ObjectDetectionAdapter<T, V>::getDeviceType(const char *engine_type, unsigned int device_index, char **device_type)
150 _object_detection->getDeviceType(engine_type, device_index, device_type);
153 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::prepare()
155 _object_detection->prepare();
// Supplies the input for the synchronous perform() path.
// NOTE(review): presumably stores t in _source, whose inference_src member is what
// perform() passes to the detector — confirm.
158 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::setInput(T &t)
163 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::perform()
165 _object_detection->perform(_source.inference_src);
168 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutput()
170 return _object_detection->getOutput();
173 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutputCache()
175 return _object_detection->getOutputCache();
178 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::performAsync(T &t)
180 _object_detection->performAsync(t);
// Explicit instantiation for ObjectDetectionInput / ObjectDetectionResult. The member
// templates are defined in this source file, so other translation units can only link
// against specializations that are explicitly instantiated here.
183 template class ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>;