// mv_machine_learning/object_detection/src/object_detection.cpp
/**
 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string.h>
#include <fstream>
#include <map>
#include <memory>
#include <algorithm>

#include "machine_learning_exception.h"
#include "mv_machine_learning_common.h"
#include "mv_object_detection_config.h"
#include "object_detection.h"

using namespace std;
using namespace std::chrono_literals;
using namespace mediavision::inference;
using namespace MediaVision::Common;
using namespace mediavision::common;
using namespace mediavision::machine_learning::exception;
namespace mediavision
{
namespace machine_learning
{
ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
		: _task_type(task_type), _backendType(), _targetDeviceType()
{
	_inference = make_unique<Inference>();
	_parser = make_unique<ObjectDetectionParser>();
}

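// Stop the asynchronous worker, if performAsync() started one, so that it
// exits cleanly before this instance is destroyed.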
void ObjectDetection::preDestroy()
{
	if (!_async_manager)
		return;

	_async_manager->stop();
}

ObjectDetectionTaskType ObjectDetection::getTaskType()
{
	return _task_type;
}

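// Collect the inference backends this task can use. A backend is kept only
// if the inference engine reports it as supported and it is allowed for
// this Task API; as of now that is tflite only.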
void ObjectDetection::getEngineList()
{
	for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
		auto backend = _inference->getSupportedInferenceBackend(idx);
		// TODO. Describe which inference engines each Task API supports and,
		//       based on that, check whether a given engine type is supported
		//       by this Task API. As of now, tflite only.
		if (backend.second && backend.first == "tflite")
			_valid_backends.push_back(backend.first);
	}
}

void ObjectDetection::getDeviceList(const char *engine_type)
{
	// TODO. Add the device types available for a given engine type later.
	//       By default, cpu and gpu only.
	_valid_devices.push_back("cpu");
	_valid_devices.push_back("gpu");
}

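// Map the user-supplied engine and device strings (case-insensitively) to
// the corresponding backend and target device type values, and validate them.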
void ObjectDetection::setEngineInfo(std::string engine_type, std::string device_type)
{
	if (engine_type.empty() || device_type.empty())
		throw InvalidParameter("Invalid engine info.");

	transform(engine_type.begin(), engine_type.end(), engine_type.begin(), ::toupper);
	transform(device_type.begin(), device_type.end(), device_type.begin(), ::toupper);

	_backendType = GetBackendType(engine_type);
	_targetDeviceType = GetDeviceType(device_type);

	LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type.c_str(), _backendType, device_type.c_str(),
		 _targetDeviceType);

	if (_backendType == MEDIA_VISION_ERROR_INVALID_PARAMETER ||
		_targetDeviceType == MEDIA_VISION_ERROR_INVALID_PARAMETER)
		throw InvalidParameter("Backend or target device type not found.");
}

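// The getters below fill _valid_backends and _valid_devices lazily on first
// use. Note that the strings returned through engine_type/device_type point
// into those internal vectors, so they stay valid only as long as this
// object does.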
void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
{
	if (_valid_backends.empty())
		getEngineList();

	*number_of_engines = _valid_backends.size();
}

void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_type)
{
	if (_valid_backends.empty())
		getEngineList();

	if (_valid_backends.size() <= engine_index)
		throw InvalidParameter("Invalid engine index.");

	*engine_type = const_cast<char *>(_valid_backends[engine_index].data());
}

void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
{
	if (_valid_devices.empty())
		getDeviceList(engine_type);

	*number_of_devices = _valid_devices.size();
}

void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
{
	if (_valid_devices.empty())
		getDeviceList(engine_type);

	if (_valid_devices.size() <= device_index)
		throw InvalidParameter("Invalid device index.");

	*device_type = const_cast<char *>(_valid_devices[device_index].data());
}

void ObjectDetection::setUserModel(string model_file, string meta_file, string label_file)
{
	_modelFilePath = model_file;
	_modelMetaFilePath = meta_file;
	_modelLabelFilePath = label_file;
}

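// Return true if the given file name has a "json" extension.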
static bool IsJsonFile(const string &fileName)
{
	return fileName.substr(fileName.find_last_of('.') + 1) == "json";
}

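// Read the label file into _labels, one label per line.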
void ObjectDetection::loadLabel()
{
	ifstream readFile;

	_labels.clear();
	readFile.open(_modelLabelFilePath.c_str());

	if (readFile.fail())
		throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");

	string line;

	while (getline(readFile, line))
		_labels.push_back(line);

	readFile.close();
}

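// Read the backend, target device and model file attributes from the given
// meta (config) file. Paths already set through setUserModel() take
// precedence over the ones from the config file, and every file name is
// prefixed with the model default path before use.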
void ObjectDetection::parseMetaFile(string meta_file_name)
{
	_config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + meta_file_name);

	int ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_BACKEND_TYPE), &_backendType);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to get backend engine type.");

	ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE), &_targetDeviceType);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to get target device type.");

	ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to get model default path.");

	if (_modelFilePath.empty()) {
		ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
		if (ret != MEDIA_VISION_ERROR_NONE)
			throw InvalidOperation("Fail to get model file path.");
	}

	_modelFilePath = _modelDefaultPath + _modelFilePath;
	LOGI("model file path = %s", _modelFilePath.c_str());

	if (_modelMetaFilePath.empty()) {
		ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
		if (ret != MEDIA_VISION_ERROR_NONE)
			throw InvalidOperation("Fail to get model meta file path.");

		if (_modelMetaFilePath.empty())
			throw InvalidOperation("Model meta file doesn't exist.");

		if (!IsJsonFile(_modelMetaFilePath))
			throw InvalidOperation("Model meta file should be a json file.");
	}

	_modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
	LOGI("meta file path = %s", _modelMetaFilePath.c_str());

	_parser->setTaskType(static_cast<int>(_task_type));
	_parser->load(_modelMetaFilePath);

	if (_modelLabelFilePath.empty()) {
		ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
		if (ret != MEDIA_VISION_ERROR_NONE)
			throw InvalidOperation("Fail to get label file path.");

		if (_modelLabelFilePath.empty())
			throw InvalidOperation("Model label file doesn't exist.");
	}

	_modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
	LOGI("label file path = %s", _modelLabelFilePath.c_str());

	loadLabel();
}

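// Parse the given configuration file and bind the requested backend engine
// to the requested target device.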
void ObjectDetection::configure(string configFile)
{
	parseMetaFile(configFile);

	int ret = _inference->bind(_backendType, _targetDeviceType);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to bind a backend engine.");
}

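// Configure the input and output tensor information from the parsed meta
// data, then load the model file into the bound backend engine.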
void ObjectDetection::prepare()
{
	int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to configure input tensor info from meta file.");

	ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to configure output tensor info from meta file.");

	_inference->configureModelFiles("", _modelFilePath, "");

	// Request to load model files to a backend engine.
	ret = _inference->load();
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to load model files.");
}

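// A minimal synchronous usage sketch, for illustration only. The config
// file name is hypothetical, `od` stands for a concrete ObjectDetection
// instance, and error handling is omitted:
//
//   ObjectDetection &od = ...; // a concrete task instance
//   od.configure("object_detection.json"); // hypothetical config file
//   od.prepare();
//   od.perform(mv_src);
//   ObjectDetectionResult &result = od.getOutput();

// Return the meta information that corresponds to the single input tensor
// of the loaded model.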
shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
{
	TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
	IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();

	// TODO. Consider using multiple tensors later.
	if (tensor_info_map.size() != 1)
		throw InvalidOperation("Invalid input tensor count.");

	auto tensor_buffer_iter = tensor_info_map.begin();

	// Get the meta information corresponding to a given input tensor name.
	return _parser->getInputMetaMap()[tensor_buffer_iter->first];
}

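// Convert the given mv_source into a flat input vector according to the
// input meta information: target color space, data type and tensor shape,
// plus optional normalization (mean/std) and quantization (scale/zero
// point) settings from the meta file.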
template<typename T>
void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
{
	LOGI("ENTER");

	PreprocessConfig config = { false,
								metaInfo->colorSpace,
								metaInfo->dataType,
								metaInfo->getChannel(),
								metaInfo->getWidth(),
								metaInfo->getHeight() };

	auto normalization = static_pointer_cast<DecodingNormal>(metaInfo->decodingTypeMap.at(DecodingType::NORMAL));
	if (normalization) {
		config.normalize = normalization->use;
		config.mean = normalization->mean;
		config.std = normalization->std;
	}

	auto quantization =
			static_pointer_cast<DecodingQuantization>(metaInfo->decodingTypeMap.at(DecodingType::QUANTIZATION));
	if (quantization) {
		config.quantize = quantization->use;
		config.scale = quantization->scale;
		config.zeropoint = quantization->zeropoint;
	}

	_preprocess.setConfig(config);
	_preprocess.run<T>(mv_src, inputVector);

	LOGI("LEAVE");
}

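// Run a single inference over the given batch of input vectors.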
template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVectors)
{
	LOGI("ENTER");

	int ret = _inference->run<T>(inputVectors);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to run inference.");

	LOGI("LEAVE");
}

template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
{
	vector<T> inputVector;

	preprocess<T>(mv_src, metaInfo, inputVector);

	vector<vector<T> > inputVectors = { inputVector };

	inference<T>(inputVectors);

	// TODO. Update operation status here.
}

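// Dispatch to the typed perform() according to the model's input data type.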
void ObjectDetection::perform(mv_source_h &mv_src)
{
	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
	if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
		perform<unsigned char>(mv_src, metaInfo);
	else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
		perform<float>(mv_src, metaInfo);
	else
		throw InvalidOperation("Invalid model data type.");
}

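// On the first call, start an asynchronous worker which repeatedly pops a
// preprocessed input from the input queue, runs inference on it and pushes
// the decoded result, tagged with its frame number, to the output queue.
// Every call then preprocesses the given source and enqueues it for the
// worker.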
template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
{
	if (!_async_manager) {
		_async_manager = make_unique<AsyncManager<ObjectDetectionResult> >([this]() {
			AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();

			inference<T>(inputQueue.inputs);

			ObjectDetectionResult &resultQueue = result();

			resultQueue.frame_number = inputQueue.frame_number;
			_async_manager->pushToOutput(resultQueue);
		});
	}

	vector<T> inputVector;

	preprocess<T>(input.inference_src, metaInfo, inputVector);

	vector<vector<T> > inputVectors = { inputVector };

	_async_manager->push(inputVectors);
}

void ObjectDetection::performAsync(ObjectDetectionInput &input)
{
	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();

	if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
		performAsync<unsigned char>(input, metaInfo);
	} else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
		performAsync<float>(input, metaInfo);
		// TODO
	} else {
		throw InvalidOperation("Invalid model data type.");
	}
}

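// In asynchronous mode, block until the worker has produced a result and
// return it; in synchronous mode, decode the result directly from the
// output tensors of the last inference.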
ObjectDetectionResult &ObjectDetection::getOutput()
{
	if (_async_manager) {
		if (!_async_manager->isWorking())
			throw InvalidOperation("Object detection was already destroyed, so this operation is invalid.");

		_current_result = _async_manager->pop();
	} else {
		// TODO. Check here whether the inference request is completed.
		//       If not, then throw an exception.
		_current_result = result();
	}

	return _current_result;
}

ObjectDetectionResult &ObjectDetection::getOutputCache()
{
	return _current_result;
}

void ObjectDetection::getOutputNames(vector<string> &names)
{
	TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
	IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();

	for (auto &tensor : ie_tensor_buffer)
		names.push_back(tensor.first);
}

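// Copy the raw float output tensor with the given name into the given
// vector.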
void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
{
	TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();

	inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
	if (!tensor_buffer)
		throw InvalidOperation("Fail to get tensor buffer.");

	auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);

	copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
}

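// Explicit instantiations for the two input data types supported by
// perform() and performAsync(): float and unsigned char.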
template void ObjectDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
												 vector<float> &inputVector);
template void ObjectDetection::inference<float>(vector<vector<float> > &inputVectors);
template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);

template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
														 vector<unsigned char> &inputVector);
template void ObjectDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
template void ObjectDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);

}
}