mv_machine_learning: add async API for face detection task group
[platform/core/api/mediavision.git] / mv_machine_learning / object_detection / src / mv_face_detection.cpp
1 /**
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "mv_private.h"
18 #include "itask.h"
19 #include "mv_face_detection_internal.h"
20 #include "face_detection_adapter.h"
21 #include "machine_learning_exception.h"
22 #include "object_detection_type.h"
23 #include "context.h"
24
25 #include <new>
26 #include <unistd.h>
27 #include <string>
28 #include <algorithm>
29 #include <mutex>
30 #include <iostream>
31
32 using namespace std;
33 using namespace mediavision::inference;
34 using namespace mediavision::common;
35 using namespace mediavision::machine_learning;
36 using namespace MediaVision::Common;
37 using namespace mediavision::machine_learning::exception;
38 using FaceDetectionTask = ITask<ObjectDetectionInput, ObjectDetectionResult>;
39
40 static mutex g_face_detection_mutex;
41
42 int mv_face_detection_create(mv_face_detection_h *handle)
43 {
44         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
45         MEDIA_VISION_NULL_ARG_CHECK(handle);
46
47         MEDIA_VISION_FUNCTION_ENTER();
48
49         Context *context = nullptr;
50         FaceDetectionTask *task = nullptr;
51
52         try {
53                 context = new Context();
54                 task = new FaceDetectionAdapter<ObjectDetectionInput, ObjectDetectionResult>();
55                 context->__tasks.insert(make_pair("face_detection", task));
56                 *handle = static_cast<mv_face_detection_h>(context);
57         } catch (const BaseException &e) {
58                 delete task;
59                 delete context;
60                 return e.getError();
61         }
62
63         MEDIA_VISION_FUNCTION_LEAVE();
64
65         return MEDIA_VISION_ERROR_NONE;
66 }
67
68 int mv_face_detection_destroy(mv_face_detection_h handle)
69 {
70         // TODO. find proper solution later.
71         // For thread safety, lock is needed here but if async API is used then dead lock occurs
72         // because mv_face_detection_destroy_open function acquires a lock and,
73         // while waiting for the thread loop to finish, the same lock is also acquired
74         // within functions - mv_face_detection_get_result_open and mv_face_detection_get_label_open
75         // - called to obtain results from the thread loop.
76
77         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
78         MEDIA_VISION_INSTANCE_CHECK(handle);
79
80         MEDIA_VISION_FUNCTION_ENTER();
81
82         auto context = static_cast<Context *>(handle);
83
84         for (auto &m : context->__tasks)
85                 delete static_cast<FaceDetectionTask *>(m.second);
86
87         delete context;
88
89         MEDIA_VISION_FUNCTION_LEAVE();
90
91         return MEDIA_VISION_ERROR_NONE;
92 }
93
94 int mv_face_detection_set_model(mv_face_detection_h handle, const char *model_name, const char *model_file,
95                                                                 const char *meta_file, const char *label_file)
96 {
97         lock_guard<mutex> lock(g_face_detection_mutex);
98
99         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
100
101         MEDIA_VISION_INSTANCE_CHECK(handle);
102         MEDIA_VISION_INSTANCE_CHECK(model_name);
103         MEDIA_VISION_NULL_ARG_CHECK(model_file);
104         MEDIA_VISION_NULL_ARG_CHECK(meta_file);
105         MEDIA_VISION_NULL_ARG_CHECK(label_file);
106
107         MEDIA_VISION_FUNCTION_ENTER();
108
109         try {
110                 auto context = static_cast<Context *>(handle);
111                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
112
113                 task->setModelInfo(model_file, meta_file, label_file, model_name);
114         } catch (const BaseException &e) {
115                 LOGE("%s", e.what());
116                 return e.getError();
117         }
118
119         MEDIA_VISION_FUNCTION_LEAVE();
120
121         return MEDIA_VISION_ERROR_NONE;
122 }
123
124 int mv_face_detection_set_engine(mv_face_detection_h handle, const char *backend_type, const char *device_type)
125 {
126         lock_guard<mutex> lock(g_face_detection_mutex);
127
128         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
129
130         MEDIA_VISION_INSTANCE_CHECK(handle);
131         MEDIA_VISION_NULL_ARG_CHECK(backend_type);
132         MEDIA_VISION_NULL_ARG_CHECK(device_type);
133
134         MEDIA_VISION_FUNCTION_ENTER();
135
136         try {
137                 auto context = static_cast<Context *>(handle);
138                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
139
140                 task->setEngineInfo(backend_type, device_type);
141         } catch (const BaseException &e) {
142                 LOGE("%s", e.what());
143                 return e.getError();
144         }
145
146         MEDIA_VISION_FUNCTION_LEAVE();
147
148         return MEDIA_VISION_ERROR_NONE;
149 }
150
151 int mv_face_detection_get_engine_count(mv_face_detection_h handle, unsigned int *engine_count)
152 {
153         lock_guard<mutex> lock(g_face_detection_mutex);
154
155         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
156
157         MEDIA_VISION_INSTANCE_CHECK(handle);
158         MEDIA_VISION_NULL_ARG_CHECK(engine_count);
159
160         MEDIA_VISION_FUNCTION_ENTER();
161
162         try {
163                 auto context = static_cast<Context *>(handle);
164                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
165
166                 task->getNumberOfEngines(engine_count);
167                 // TODO.
168         } catch (const BaseException &e) {
169                 LOGE("%s", e.what());
170                 return e.getError();
171         }
172
173         MEDIA_VISION_FUNCTION_LEAVE();
174
175         return MEDIA_VISION_ERROR_NONE;
176 }
177
178 int mv_face_detection_get_engine_type(mv_face_detection_h handle, const unsigned int engine_index, char **engine_type)
179 {
180         lock_guard<mutex> lock(g_face_detection_mutex);
181
182         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
183
184         MEDIA_VISION_INSTANCE_CHECK(handle);
185         MEDIA_VISION_NULL_ARG_CHECK(engine_type);
186
187         MEDIA_VISION_FUNCTION_ENTER();
188
189         try {
190                 auto context = static_cast<Context *>(handle);
191                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
192
193                 task->getEngineType(engine_index, engine_type);
194                 // TODO.
195         } catch (const BaseException &e) {
196                 LOGE("%s", e.what());
197                 return e.getError();
198         }
199
200         MEDIA_VISION_FUNCTION_LEAVE();
201
202         return MEDIA_VISION_ERROR_NONE;
203 }
204
205 int mv_face_detection_get_device_count(mv_face_detection_h handle, const char *engine_type, unsigned int *device_count)
206 {
207         lock_guard<mutex> lock(g_face_detection_mutex);
208
209         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
210
211         MEDIA_VISION_INSTANCE_CHECK(handle);
212         MEDIA_VISION_NULL_ARG_CHECK(device_count);
213
214         MEDIA_VISION_FUNCTION_ENTER();
215
216         try {
217                 auto context = static_cast<Context *>(handle);
218                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
219
220                 task->getNumberOfDevices(engine_type, device_count);
221                 // TODO.
222         } catch (const BaseException &e) {
223                 LOGE("%s", e.what());
224                 return e.getError();
225         }
226
227         MEDIA_VISION_FUNCTION_LEAVE();
228
229         return MEDIA_VISION_ERROR_NONE;
230 }
231
232 int mv_face_detection_get_device_type(mv_face_detection_h handle, const char *engine_type,
233                                                                           const unsigned int device_index, char **device_type)
234 {
235         lock_guard<mutex> lock(g_face_detection_mutex);
236
237         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_face_check_system_info_feature_supported());
238
239         MEDIA_VISION_INSTANCE_CHECK(handle);
240         MEDIA_VISION_NULL_ARG_CHECK(engine_type);
241         MEDIA_VISION_NULL_ARG_CHECK(device_type);
242
243         MEDIA_VISION_FUNCTION_ENTER();
244
245         try {
246                 auto context = static_cast<Context *>(handle);
247                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
248
249                 task->getDeviceType(engine_type, device_index, device_type);
250                 // TODO.
251         } catch (const BaseException &e) {
252                 LOGE("%s", e.what());
253                 return e.getError();
254         }
255
256         MEDIA_VISION_FUNCTION_LEAVE();
257
258         return MEDIA_VISION_ERROR_NONE;
259 }
260
261 int mv_face_detection_configure(mv_face_detection_h handle)
262 {
263         lock_guard<mutex> lock(g_face_detection_mutex);
264
265         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
266         MEDIA_VISION_INSTANCE_CHECK(handle);
267
268         MEDIA_VISION_FUNCTION_ENTER();
269
270         try {
271                 auto context = static_cast<Context *>(handle);
272                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
273
274                 task->configure();
275         } catch (const BaseException &e) {
276                 LOGE("%s", e.what());
277                 return e.getError();
278         }
279
280         MEDIA_VISION_FUNCTION_LEAVE();
281
282         return MEDIA_VISION_ERROR_NONE;
283 }
284
285 int mv_face_detection_prepare(mv_face_detection_h handle)
286 {
287         lock_guard<mutex> lock(g_face_detection_mutex);
288
289         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_check_system_info_feature_supported());
290         MEDIA_VISION_INSTANCE_CHECK(handle);
291
292         MEDIA_VISION_FUNCTION_ENTER();
293
294         try {
295                 auto context = static_cast<Context *>(handle);
296                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
297
298                 task->prepare();
299         } catch (const BaseException &e) {
300                 LOGE("%s", e.what());
301                 return e.getError();
302         }
303
304         MEDIA_VISION_FUNCTION_LEAVE();
305
306         return MEDIA_VISION_ERROR_NONE;
307 }
308
309 int mv_face_detection_inference(mv_face_detection_h handle, mv_source_h source)
310 {
311         lock_guard<mutex> lock(g_face_detection_mutex);
312
313         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
314         MEDIA_VISION_INSTANCE_CHECK(source);
315         MEDIA_VISION_INSTANCE_CHECK(handle);
316
317         MEDIA_VISION_FUNCTION_ENTER();
318
319         try {
320                 auto context = static_cast<Context *>(handle);
321                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
322
323                 ObjectDetectionInput input = { .inference_src = source };
324
325                 task->setInput(input);
326                 task->perform();
327         } catch (const BaseException &e) {
328                 LOGE("%s", e.what());
329                 return e.getError();
330         }
331
332         MEDIA_VISION_FUNCTION_LEAVE();
333
334         return MEDIA_VISION_ERROR_NONE;
335 }
336
337 int mv_face_detection_inference_async(mv_face_detection_h handle, mv_source_h source, mv_completion_cb completion_cb,
338                                                                           void *user_data)
339 {
340         LOGD("ENTER");
341
342         lock_guard<mutex> lock(g_face_detection_mutex);
343
344         if (!handle) {
345                 LOGE("Handle is NULL.");
346                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
347         }
348
349         try {
350                 auto context = static_cast<Context *>(handle);
351                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
352
353                 ObjectDetectionInput input = { handle, source, completion_cb, user_data };
354
355                 task->performAsync(input);
356         } catch (const BaseException &e) {
357                 LOGE("%s", e.what());
358                 return e.getError();
359         }
360
361         LOGD("LEAVE");
362
363         return MEDIA_VISION_ERROR_NONE;
364 }
365
366 int mv_face_detection_get_result(mv_face_detection_h handle, unsigned int *number_of_objects,
367                                                                  const unsigned int **indices, const float **confidences, const int **left,
368                                                                  const int **top, const int **right, const int **bottom)
369 {
370         lock_guard<mutex> lock(g_face_detection_mutex);
371
372         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
373         MEDIA_VISION_INSTANCE_CHECK(handle);
374         MEDIA_VISION_INSTANCE_CHECK(number_of_objects);
375         MEDIA_VISION_INSTANCE_CHECK(indices);
376         MEDIA_VISION_INSTANCE_CHECK(confidences);
377         MEDIA_VISION_INSTANCE_CHECK(left);
378         MEDIA_VISION_INSTANCE_CHECK(top);
379         MEDIA_VISION_INSTANCE_CHECK(right);
380         MEDIA_VISION_INSTANCE_CHECK(bottom);
381
382         MEDIA_VISION_FUNCTION_ENTER();
383
384         try {
385                 auto context = static_cast<Context *>(handle);
386                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
387
388                 ObjectDetectionResult &result = task->getOutput();
389                 *number_of_objects = result.number_of_objects;
390                 *indices = result.indices.data();
391                 *confidences = result.confidences.data();
392                 *left = result.left.data();
393                 *top = result.top.data();
394                 *right = result.right.data();
395                 *bottom = result.bottom.data();
396         } catch (const BaseException &e) {
397                 LOGE("%s", e.what());
398                 return e.getError();
399         }
400
401         MEDIA_VISION_FUNCTION_LEAVE();
402
403         return MEDIA_VISION_ERROR_NONE;
404 }
405
406 int mv_face_detection_get_label(mv_face_detection_h handle, const unsigned int index, const char **out_label)
407 {
408         lock_guard<mutex> lock(g_face_detection_mutex);
409
410         MEDIA_VISION_SUPPORT_CHECK(_mv_inference_image_check_system_info_feature_supported());
411         MEDIA_VISION_INSTANCE_CHECK(handle);
412         MEDIA_VISION_INSTANCE_CHECK(out_label);
413
414         MEDIA_VISION_FUNCTION_ENTER();
415
416         try {
417                 auto context = static_cast<Context *>(handle);
418                 auto task = static_cast<FaceDetectionTask *>(context->__tasks.at("face_detection"));
419
420                 ObjectDetectionResult &result = task->getOutput();
421
422                 if (result.number_of_objects <= index)
423                         throw InvalidParameter("Invalid index range.");
424
425                 *out_label = result.names[index].c_str();
426         } catch (const BaseException &e) {
427                 LOGE("%s", e.what());
428                 return e.getError();
429         }
430
431         MEDIA_VISION_FUNCTION_LEAVE();
432
433         return MEDIA_VISION_ERROR_NONE;
434 }