mv_machine_learning: embed AsyncManager into each task group
[platform/core/api/mediavision.git] mv_machine_learning/object_detection/src/object_detection.cpp
/**
 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string.h>
#include <fstream>
#include <map>
#include <memory>

#include "machine_learning_exception.h"
#include "mv_machine_learning_common.h"
#include "mv_object_detection_config.h"
#include "object_detection.h"

using namespace std;
using namespace mediavision::inference;
using namespace MediaVision::Common;
using namespace mediavision::common;
using namespace mediavision::machine_learning::exception;

namespace mediavision
{
namespace machine_learning
{
ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
        : _task_type(task_type), _backendType(), _targetDeviceType()
{
        _inference = make_unique<Inference>();
        _parser = make_unique<ObjectDetectionParser>();
}

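// Called before this task object is destroyed. If an async manager was
// created by performAsync(), stop it so its worker thread finishes first.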
void ObjectDetection::preDestroy()
{
        if (!_async_manager)
                return;

        _async_manager->stop();
}

ObjectDetectionTaskType ObjectDetection::getTaskType()
{
        return _task_type;
}

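// Collect the inference backends reported by the inference engine layer
// into _valid_backends. Only tflite is accepted for now (see TODO below).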
void ObjectDetection::getEngineList()
{
        for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
                auto backend = _inference->getSupportedInferenceBackend(idx);
                // TODO. Describe which inference engines each Task API supports and,
                //       based on that, check whether a given engine type is supported
                //       by this Task API. As of now, tflite only.
                if (backend.second == true && backend.first.compare("tflite") == 0)
                        _valid_backends.push_back(backend.first);
        }
}

void ObjectDetection::getDeviceList(const char *engine_type)
{
        // TODO. Add device types available for a given engine type later.
        //       By default, cpu and gpu only.
        _valid_devices.push_back("cpu");
        _valid_devices.push_back("gpu");
}

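// Map the user-given engine and device type strings, case-insensitively,
// to the corresponding backend and target device type values.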
void ObjectDetection::setEngineInfo(std::string engine_type, std::string device_type)
{
        if (engine_type.empty() || device_type.empty())
                throw InvalidParameter("Invalid engine info.");

        transform(engine_type.begin(), engine_type.end(), engine_type.begin(), ::toupper);
        transform(device_type.begin(), device_type.end(), device_type.begin(), ::toupper);

        _backendType = GetBackendType(engine_type);
        _targetDeviceType = GetDeviceType(device_type);

        LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type.c_str(), GetBackendType(engine_type),
                 device_type.c_str(), GetDeviceType(device_type));

        if (_backendType == MEDIA_VISION_ERROR_INVALID_PARAMETER ||
                _targetDeviceType == MEDIA_VISION_ERROR_INVALID_PARAMETER)
                throw InvalidParameter("backend or target device type not found.");
}

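// The four getters below lazily populate _valid_backends / _valid_devices
// on first use and then answer count and index queries from the cache.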
void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
{
        if (!_valid_backends.empty()) {
                *number_of_engines = _valid_backends.size();
                return;
        }

        getEngineList();
        *number_of_engines = _valid_backends.size();
}

void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_type)
{
        if (!_valid_backends.empty()) {
                if (_valid_backends.size() <= engine_index)
                        throw InvalidParameter("Invalid engine index.");

                *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
                return;
        }

        getEngineList();

        if (_valid_backends.size() <= engine_index)
                throw InvalidParameter("Invalid engine index.");

        *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
}

void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
{
        if (!_valid_devices.empty()) {
                *number_of_devices = _valid_devices.size();
                return;
        }

        getDeviceList(engine_type);
        *number_of_devices = _valid_devices.size();
}

void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
{
        if (!_valid_devices.empty()) {
                if (_valid_devices.size() <= device_index)
                        throw InvalidParameter("Invalid device index.");

                *device_type = const_cast<char *>(_valid_devices[device_index].data());
                return;
        }

        getDeviceList(engine_type);

        if (_valid_devices.size() <= device_index)
                throw InvalidParameter("Invalid device index.");

        *device_type = const_cast<char *>(_valid_devices[device_index].data());
}

void ObjectDetection::setUserModel(string model_file, string meta_file, string label_file)
{
        _modelFilePath = model_file;
        _modelMetaFilePath = meta_file;
        _modelLabelFilePath = label_file;
}

static bool IsJsonFile(const string &fileName)
{
        return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
}

void ObjectDetection::loadLabel()
{
        ifstream readFile;

        _labels.clear();
        readFile.open(_modelLabelFilePath.c_str());

        if (readFile.fail())
                throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");

        string line;

        while (getline(readFile, line))
                _labels.push_back(line);

        readFile.close();
}

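// Read the backend type, target device type and the model/meta/label file
// paths from the given meta file, load the meta info through the parser,
// and finally load the label file.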
void ObjectDetection::parseMetaFile(string meta_file_name)
{
        _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + meta_file_name);

        int ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_BACKEND_TYPE), &_backendType);
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to get backend engine type.");

        ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE), &_targetDeviceType);
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to get target device type.");

        ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to get model default path");

        if (_modelFilePath.empty()) {
                ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
                if (ret != MEDIA_VISION_ERROR_NONE)
                        throw InvalidOperation("Fail to get model file path");
        }

        _modelFilePath = _modelDefaultPath + _modelFilePath;
        LOGI("model file path = %s", _modelFilePath.c_str());

        if (_modelMetaFilePath.empty()) {
                ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
                if (ret != MEDIA_VISION_ERROR_NONE)
                        throw InvalidOperation("Fail to get model meta file path");

                if (_modelMetaFilePath.empty())
                        throw InvalidOperation("Model meta file doesn't exist.");

                if (!IsJsonFile(_modelMetaFilePath))
                        throw InvalidOperation("Model meta file should be json");
        }

        _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
        LOGI("meta file path = %s", _modelMetaFilePath.c_str());

        _parser->setTaskType(static_cast<int>(_task_type));
        _parser->load(_modelMetaFilePath);

        if (_modelLabelFilePath.empty()) {
                ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
                if (ret != MEDIA_VISION_ERROR_NONE)
                        throw InvalidOperation("Fail to get label file path");

                if (_modelLabelFilePath.empty())
                        throw InvalidOperation("Model label file doesn't exist.");
        }

        _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
        LOGI("label file path = %s", _modelLabelFilePath.c_str());

        loadLabel();
}

void ObjectDetection::configure(string configFile)
{
        parseMetaFile(configFile);

        int ret = _inference->bind(_backendType, _targetDeviceType);
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to bind a backend engine.");
}

void ObjectDetection::prepare()
{
        int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure input tensor info from meta file.");

        ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to configure output tensor info from meta file.");

        _inference->configureModelFiles("", _modelFilePath, "");

        // Request to load model files to a backend engine.
        ret = _inference->load();
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to load model files.");
}

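// Return the meta information for the model's single input tensor,
// looked up by the tensor name reported by the inference engine.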
shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
{
        TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
        IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();

        // TODO. consider using multiple tensors later.
        if (tensor_info_map.size() != 1)
                throw InvalidOperation("Input tensor count is invalid.");

        auto tensor_buffer_iter = tensor_info_map.begin();

        // Get the meta information corresponding to a given input tensor name.
        return _parser->getInputMetaMap()[tensor_buffer_iter->first];
}

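// Build a PreprocessConfig from the input meta info, including optional
// normalization and quantization parameters, and convert the given media
// source into a flat input vector of type T.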
template<typename T>
void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
{
        LOGI("ENTER");

        PreprocessConfig config = { false,
                                    metaInfo->colorSpace,
                                    metaInfo->dataType,
                                    metaInfo->getChannel(),
                                    metaInfo->getWidth(),
                                    metaInfo->getHeight() };

        auto normalization = static_pointer_cast<DecodingNormal>(metaInfo->decodingTypeMap.at(DecodingType::NORMAL));
        if (normalization) {
                config.normalize = normalization->use;
                config.mean = normalization->mean;
                config.std = normalization->std;
        }

        auto quantization =
                        static_pointer_cast<DecodingQuantization>(metaInfo->decodingTypeMap.at(DecodingType::QUANTIZATION));
        if (quantization) {
                config.quantize = quantization->use;
                config.scale = quantization->scale;
                config.zeropoint = quantization->zeropoint;
        }

        _preprocess.setConfig(config);
        _preprocess.run<T>(mv_src, inputVector);

        LOGI("LEAVE");
}

template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVectors)
{
        LOGI("ENTER");

        int ret = _inference->run<T>(inputVectors);
        if (ret != MEDIA_VISION_ERROR_NONE)
                throw InvalidOperation("Fail to run inference");

        LOGI("LEAVE");
}

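// Synchronous path: preprocess a single source and run inference on it.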
template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
{
        vector<T> inputVector;

        preprocess<T>(mv_src, metaInfo, inputVector);

        vector<vector<T> > inputVectors = { inputVector };

        inference<T>(inputVectors);

        // TODO. Update operation status here.
}

void ObjectDetection::perform(mv_source_h &mv_src)
{
        shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
        if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
                perform<unsigned char>(mv_src, metaInfo);
        else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
                perform<float>(mv_src, metaInfo);
        else
                throw InvalidOperation("Invalid model data type.");
}

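// Asynchronous path. The AsyncManager is created lazily on the first
// request, so each task group owns its own instance. Its callback pops a
// preprocessed input from the input queue, runs inference, fetches the
// result via result(), tags it with the frame number, and pushes it to
// the output queue.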
template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
{
        if (!_async_manager) {
                _async_manager = make_unique<AsyncManager<ObjectDetectionResult> >([this]() {
                        AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();

                        inference<T>(inputQueue.inputs);

                        ObjectDetectionResult &resultQueue = result();

                        resultQueue.frame_number = inputQueue.frame_number;
                        _async_manager->pushToOutput(resultQueue);
                });
        }

        vector<T> inputVector;

        preprocess<T>(input.inference_src, metaInfo, inputVector);

        vector<vector<T> > inputVectors = { inputVector };

        _async_manager->push(inputVectors);
}

void ObjectDetection::performAsync(ObjectDetectionInput &input)
{
        shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();

        if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
                performAsync<unsigned char>(input, metaInfo);
        } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
                performAsync<float>(input, metaInfo);
                // TODO
        } else {
                throw InvalidOperation("Invalid model data type.");
        }
}

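// In the async case, take the next completed result from the async
// manager's output queue; otherwise decode the result of the last
// synchronous inference directly.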
ObjectDetectionResult &ObjectDetection::getOutput()
{
        if (_async_manager) {
                if (!_async_manager->isWorking())
                        throw InvalidOperation("Object detection has already been destroyed, so this operation is invalid.");

                _current_result = _async_manager->pop();
        } else {
                // TODO. Check if inference request is completed or not here.
                //       If not then throw an exception.
                _current_result = result();
        }

        return _current_result;
}

ObjectDetectionResult &ObjectDetection::getOutputCache()
{
        return _current_result;
}

void ObjectDetection::getOutputNames(vector<string> &names)
{
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
        IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();

        for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
                names.push_back(it->first);
}

void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
{
        TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();

        inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
        if (!tensor_buffer)
                throw InvalidOperation("Fail to get tensor buffer.");

        auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);

        copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
}

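// Explicit template instantiations for the two model data types handled
// by perform() and performAsync(): float and unsigned char.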
template void ObjectDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
                                                 vector<float> &inputVector);
template void ObjectDetection::inference<float>(vector<vector<float> > &inputVectors);
template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);

template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
                                                         vector<unsigned char> &inputVector);
template void ObjectDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
template void ObjectDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);

}
}