mv_machine_learning: convert ObjectDetection class into template class
[platform/core/api/mediavision.git] / mv_machine_learning / object_detection / src / object_detection_adapter.cpp
1 /**
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "machine_learning_exception.h"
18 #include "object_detection_adapter.h"
19 #include "object_detection_external.h"
20 #include "mv_object_detection_config.h"
21
22 using namespace std;
23 using namespace MediaVision::Common;
24 using namespace mediavision::machine_learning;
25 using namespace mediavision::machine_learning::exception;
26
27 namespace mediavision
28 {
29 namespace machine_learning
30 {
31 template<typename T, typename V> ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source()
32 {
33         _config = make_shared<MachineLearningConfig>();
34         _config->parseConfigFile(_config_file_name);
35
36         ObjectDetectionTaskType model_type = convertToTaskType(_config->getDefaultModelName());
37         create(model_type);
38 }
39
40 template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionAdapter()
41 {
42         _object_detection->preDestroy();
43 }
44
45 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
46 {
47         _config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(task_type)));
48         mv_inference_data_type_e dataType = _config->getInputMetaMap().begin()->second->dataType;
49
50         switch (task_type) {
51         case ObjectDetectionTaskType::MOBILENET_V1_SSD:
52                 if (dataType == MV_INFERENCE_DATA_UINT8)
53                         _object_detection = make_unique<MobilenetV1Ssd<unsigned char> >(task_type, _config);
54                 else if (dataType == MV_INFERENCE_DATA_FLOAT32)
55                         _object_detection = make_unique<MobilenetV1Ssd<float> >(task_type, _config);
56                 else
57                         throw InvalidOperation("Invalid model data type.");
58                 break;
59         case ObjectDetectionTaskType::MOBILENET_V2_SSD:
60                 if (dataType == MV_INFERENCE_DATA_UINT8)
61                         _object_detection = make_unique<MobilenetV2Ssd<unsigned char> >(task_type, _config);
62                 else if (dataType == MV_INFERENCE_DATA_FLOAT32)
63                         _object_detection = make_unique<MobilenetV2Ssd<float> >(task_type, _config);
64                 else
65                         throw InvalidOperation("Invalid model data type.");
66                 break;
67         case ObjectDetectionTaskType::OD_PLUGIN:
68                 _object_detection = make_unique<ObjectDetectionExternal>(task_type);
69                 break;
70         case ObjectDetectionTaskType::FD_PLUGIN:
71                 _object_detection = make_unique<ObjectDetectionExternal>(task_type);
72                 break;
73         default:
74                 throw InvalidOperation("Invalid object detection task type.");
75         }
76         // TODO.
77 }
78
79 template<typename T, typename V>
80 ObjectDetectionTaskType ObjectDetectionAdapter<T, V>::convertToTaskType(string model_name)
81 {
82         if (model_name.empty())
83                 throw InvalidParameter("model name is empty.");
84
85         transform(model_name.begin(), model_name.end(), model_name.begin(), ::toupper);
86
87         if (model_name == "OD_PLUGIN")
88                 return ObjectDetectionTaskType::OD_PLUGIN;
89         else if (model_name == "FD_PLUGIN")
90                 return ObjectDetectionTaskType::FD_PLUGIN;
91         else if (model_name == "MOBILENET_V1_SSD")
92                 return ObjectDetectionTaskType::MOBILENET_V1_SSD;
93         else if (model_name == "MOBILENET_V2_SSD")
94                 return ObjectDetectionTaskType::MOBILENET_V2_SSD;
95         // TODO.
96
97         throw InvalidParameter("Invalid object detection model name.");
98 }
99
100 template<typename T, typename V>
101 void ObjectDetectionAdapter<T, V>::setModelInfo(const char *model_file, const char *meta_file, const char *label_file,
102                                                                                                 const char *model_name)
103 {
104         try {
105                 _config->setUserModel(model_file, meta_file, label_file);
106
107                 ObjectDetectionTaskType model_type = convertToTaskType(model_name);
108                 create(model_type);
109         } catch (const BaseException &e) {
110                 LOGW("A given model name is invalid so default task type will be used.");
111         }
112
113         if (!model_file && !meta_file) {
114                 LOGW("Given model info is invalid so default model info will be used instead.");
115                 return;
116         }
117 }
118
119 template<typename T, typename V>
120 void ObjectDetectionAdapter<T, V>::setEngineInfo(const char *engine_type, const char *device_type)
121 {
122         _object_detection->setEngineInfo(string(engine_type), string(device_type));
123 }
124
125 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::configure()
126 {
127         _object_detection->configure();
128 }
129
130 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)
131 {
132         _object_detection->getNumberOfEngines(number_of_engines);
133 }
134
135 template<typename T, typename V>
136 void ObjectDetectionAdapter<T, V>::getEngineType(unsigned int engine_index, char **engine_type)
137 {
138         _object_detection->getEngineType(engine_index, engine_type);
139 }
140
141 template<typename T, typename V>
142 void ObjectDetectionAdapter<T, V>::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
143 {
144         _object_detection->getNumberOfDevices(engine_type, number_of_devices);
145 }
146
147 template<typename T, typename V>
148 void ObjectDetectionAdapter<T, V>::getDeviceType(const char *engine_type, unsigned int device_index, char **device_type)
149 {
150         _object_detection->getDeviceType(engine_type, device_index, device_type);
151 }
152
153 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::prepare()
154 {
155         _object_detection->prepare();
156 }
157
158 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::setInput(T &t)
159 {
160         _source = t;
161 }
162
163 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::perform()
164 {
165         _object_detection->perform(_source.inference_src);
166 }
167
168 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutput()
169 {
170         return _object_detection->getOutput();
171 }
172
173 template<typename T, typename V> V &ObjectDetectionAdapter<T, V>::getOutputCache()
174 {
175         return _object_detection->getOutputCache();
176 }
177
178 template<typename T, typename V> void ObjectDetectionAdapter<T, V>::performAsync(T &t)
179 {
180         _object_detection->performAsync(t);
181 }
182
183 template class ObjectDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>;
184 }
185 }