/**
 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <fstream>
#include <memory>
#include <string>

#include "machine_learning_exception.h"
#include "mv_machine_learning_common.h"
#include "mv_object_detection_config.h"
#include "object_detection.h"

using namespace std;
using namespace std::chrono_literals;
using namespace mediavision::inference;
using namespace MediaVision::Common;
using namespace mediavision::common;
using namespace mediavision::machine_learning::exception;
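
// This file implements the ObjectDetection task: meta-file parsing, engine/device
// enumeration, model loading, input preprocessing, and synchronous/asynchronous
// inference for the media vision object detection API.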

namespace mediavision
{
namespace machine_learning
{
ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
		: _task_type(task_type), _backendType(), _targetDeviceType()
{
	_inference = make_unique<Inference>();
	_parser = make_unique<ObjectDetectionParser>();
}
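
// Called just before the task instance is destroyed. If performAsync() created an
// async manager, stop its worker thread so it does not outlive this object.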
void ObjectDetection::preDestroy()
{
	if (!_async_manager)
		return;

	_async_manager->stop();
}

ObjectDetectionTaskType ObjectDetection::getTaskType()
{
	return _task_type;
}

void ObjectDetection::getEngineList()
{
	for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
		auto backend = _inference->getSupportedInferenceBackend(idx);
		// TODO. We need to describe which inference engines are supported by each Task API
		//       and, based on that, check whether a given engine type is supported by this
		//       Task API. As of now, tflite only.
		if (backend.second && backend.first == "tflite")
			_valid_backends.push_back(backend.first);
	}
}

void ObjectDetection::getDeviceList(const char *engine_type)
{
	// TODO. Add device types available for a given engine type later.
	//       By default, cpu and gpu only.
	_valid_devices.push_back("cpu");
	_valid_devices.push_back("gpu");
}
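
// Engine and device names are case-insensitive; they are upper-cased here before
// being mapped to the corresponding backend and device type enum values.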
void ObjectDetection::setEngineInfo(std::string engine_type, std::string device_type)
{
	if (engine_type.empty() || device_type.empty())
		throw InvalidParameter("Invalid engine info.");

	transform(engine_type.begin(), engine_type.end(), engine_type.begin(), ::toupper);
	transform(device_type.begin(), device_type.end(), device_type.begin(), ::toupper);

	_backendType = GetBackendType(engine_type);
	_targetDeviceType = GetDeviceType(device_type);

	LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type.c_str(), GetBackendType(engine_type),
		 device_type.c_str(), GetDeviceType(device_type));

	if (_backendType == MEDIA_VISION_ERROR_INVALID_PARAMETER ||
		_targetDeviceType == MEDIA_VISION_ERROR_INVALID_PARAMETER)
		throw InvalidParameter("backend or target device type not found.");
}
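
// The engine/device getters below enumerate lazily: the lists are built on first
// use and cached in _valid_backends / _valid_devices for subsequent calls.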
void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
{
	if (!_valid_backends.empty()) {
		*number_of_engines = _valid_backends.size();
		return;
	}

	getEngineList();
	*number_of_engines = _valid_backends.size();
}

void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_type)
{
	if (!_valid_backends.empty()) {
		if (_valid_backends.size() <= engine_index)
			throw InvalidParameter("Invalid engine index.");

		*engine_type = const_cast<char *>(_valid_backends[engine_index].data());
		return;
	}

	getEngineList();

	if (_valid_backends.size() <= engine_index)
		throw InvalidParameter("Invalid engine index.");

	*engine_type = const_cast<char *>(_valid_backends[engine_index].data());
}

void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
{
	if (!_valid_devices.empty()) {
		*number_of_devices = _valid_devices.size();
		return;
	}

	getDeviceList(engine_type);
	*number_of_devices = _valid_devices.size();
}

void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
{
	if (!_valid_devices.empty()) {
		if (_valid_devices.size() <= device_index)
			throw InvalidParameter("Invalid device index.");

		*device_type = const_cast<char *>(_valid_devices[device_index].data());
		return;
	}

	getDeviceList(engine_type);

	if (_valid_devices.size() <= device_index)
		throw InvalidParameter("Invalid device index.");

	*device_type = const_cast<char *>(_valid_devices[device_index].data());
}

void ObjectDetection::setUserModel(string model_file, string meta_file, string label_file)
{
	_modelFilePath = model_file;
	_modelMetaFilePath = meta_file;
	_modelLabelFilePath = label_file;
}

static bool IsJsonFile(const string &fileName)
{
	return (!fileName.substr(fileName.find_last_of(".") + 1).compare("json"));
}
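
// Read the label file into _labels, one class name per line.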
void ObjectDetection::loadLabel()
{
	ifstream readFile;

	_labels.clear();
	readFile.open(_modelLabelFilePath.c_str());
	if (readFile.fail())
		throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");

	string line;

	while (getline(readFile, line))
		_labels.push_back(line);

	readFile.close();
}

void ObjectDetection::parseMetaFile(string meta_file_name)
{
	_config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + meta_file_name);

	int ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_BACKEND_TYPE), &_backendType);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to get backend engine type.");

	ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE), &_targetDeviceType);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to get target device type.");

	ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to get model default path");

	if (_modelFilePath.empty()) {
		ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
		if (ret != MEDIA_VISION_ERROR_NONE)
			throw InvalidOperation("Fail to get model file path");
	}

	_modelFilePath = _modelDefaultPath + _modelFilePath;
	LOGI("model file path = %s", _modelFilePath.c_str());

	if (_modelMetaFilePath.empty()) {
		ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
		if (ret != MEDIA_VISION_ERROR_NONE)
			throw InvalidOperation("Fail to get model meta file path");

		if (_modelMetaFilePath.empty())
			throw InvalidOperation("Model meta file doesn't exist.");

		if (!IsJsonFile(_modelMetaFilePath))
			throw InvalidOperation("Model meta file should be json");
	}

	_modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
	LOGI("meta file path = %s", _modelMetaFilePath.c_str());

	_parser->setTaskType(static_cast<int>(_task_type));
	_parser->load(_modelMetaFilePath);

	if (_modelLabelFilePath.empty()) {
		ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
		if (ret != MEDIA_VISION_ERROR_NONE)
			throw InvalidOperation("Fail to get label file path");

		if (_modelLabelFilePath.empty())
			throw InvalidOperation("Model label file doesn't exist.");
	}

	_modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
	LOGI("label file path = %s", _modelLabelFilePath.c_str());

	loadLabel();
}
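
// Typical call sequence for this task: configure() parses the meta file and binds a
// backend engine, prepare() configures tensor info and loads the model, then
// perform()/performAsync() run inference and getOutput() retrieves the result.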
void ObjectDetection::configure(string configFile)
{
	parseMetaFile(configFile);

	int ret = _inference->bind(_backendType, _targetDeviceType);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to bind a backend engine.");
}

void ObjectDetection::prepare()
{
	int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to configure input tensor info from meta file.");

	ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to configure output tensor info from meta file.");

	_inference->configureModelFiles("", _modelFilePath, "");

	// Request to load model files to a backend engine.
	ret = _inference->load();
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to load model files.");
}

shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
{
	TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
	IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();

	// TODO. Consider using multiple tensors later.
	if (tensor_info_map.size() != 1)
		throw InvalidOperation("Invalid input tensor count.");

	auto tensor_buffer_iter = tensor_info_map.begin();

	// Get the meta information corresponding to a given input tensor name.
	return _parser->getInputMetaMap()[tensor_buffer_iter->first];
}
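
// Build a preprocessing configuration from the model meta information (color space,
// data type, tensor shape, and optional normalization/quantization parameters) and
// convert the given media source into the model's input vector.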
template<typename T>
void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
{
	PreprocessConfig config = { false,
								metaInfo->colorSpace,
								metaInfo->dataType,
								metaInfo->getChannel(),
								metaInfo->getWidth(),
								metaInfo->getHeight() };

	auto normalization = static_pointer_cast<DecodingNormal>(metaInfo->decodingTypeMap.at(DecodingType::NORMAL));
	if (normalization) {
		config.normalize = normalization->use;
		config.mean = normalization->mean;
		config.std = normalization->std;
	}

	auto quantization =
			static_pointer_cast<DecodingQuantization>(metaInfo->decodingTypeMap.at(DecodingType::QUANTIZATION));
	if (quantization) {
		config.quantize = quantization->use;
		config.scale = quantization->scale;
		config.zeropoint = quantization->zeropoint;
	}

	_preprocess.setConfig(config);
	_preprocess.run<T>(mv_src, inputVector);
}

template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVectors)
{
	int ret = _inference->run<T>(inputVectors);
	if (ret != MEDIA_VISION_ERROR_NONE)
		throw InvalidOperation("Fail to run inference");
}

template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
{
	vector<T> inputVector;

	preprocess<T>(mv_src, metaInfo, inputVector);

	vector<vector<T> > inputVectors = { inputVector };

	inference<T>(inputVectors);

	// TODO. Update operation status here.
}

void ObjectDetection::perform(mv_source_h &mv_src)
{
	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();

	if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
		perform<unsigned char>(mv_src, metaInfo);
	else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
		perform<float>(mv_src, metaInfo);
	else
		throw InvalidOperation("Invalid model data type.");
}
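
// Asynchronous path: on first use, create an AsyncManager whose worker callback pops
// a preprocessed input from the input queue, runs inference, stamps the result with
// the input's frame number, and pushes it to the output queue. Subsequent calls only
// preprocess and enqueue.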
template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
{
	if (!_async_manager) {
		_async_manager = make_unique<AsyncManager<ObjectDetectionResult> >([this]() {
			AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();

			inference<T>(inputQueue.inputs);

			ObjectDetectionResult &resultQueue = result();

			resultQueue.frame_number = inputQueue.frame_number;
			_async_manager->pushToOutput(resultQueue);
		});
	}

	vector<T> inputVector;

	preprocess<T>(input.inference_src, metaInfo, inputVector);

	vector<vector<T> > inputVectors = { inputVector };

	_async_manager->push(inputVectors);
}

void ObjectDetection::performAsync(ObjectDetectionInput &input)
{
	shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();

	if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
		performAsync<unsigned char>(input, metaInfo);
	} else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
		performAsync<float>(input, metaInfo);
	} else {
		throw InvalidOperation("Invalid model data type.");
	}
}
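
// In the asynchronous case this pops (and may block on) the next result from the
// async manager; in the synchronous case it returns the result of the last perform().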
ObjectDetectionResult &ObjectDetection::getOutput()
{
	if (_async_manager) {
		if (!_async_manager->isWorking())
			throw InvalidOperation("Object detection has already been destroyed, so this is an invalid operation.");

		_current_result = _async_manager->pop();
	} else {
		// TODO. Check here whether the inference request is completed or not.
		//       If not, throw an exception.
		_current_result = result();
	}

	return _current_result;
}

ObjectDetectionResult &ObjectDetection::getOutputCache()
{
	return _current_result;
}

void ObjectDetection::getOutputNames(vector<string> &names)
{
	TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
	IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();

	for (auto &tensor : ie_tensor_buffer)
		names.push_back(tensor.first);
}

void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
{
	TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();

	inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
	if (!tensor_buffer)
		throw InvalidOperation("Fail to get tensor buffer.");

	auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);

	copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
}
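
// Explicit instantiations for the two input data types this task supports
// (MV_INFERENCE_DATA_FLOAT32 and MV_INFERENCE_DATA_UINT8).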
template void ObjectDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
												 vector<float> &inputVector);
template void ObjectDetection::inference<float>(vector<vector<float> > &inputVectors);
template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);

template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
														  vector<unsigned char> &inputVector);
template void ObjectDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
template void ObjectDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);
} // namespace machine_learning
} // namespace mediavision