2 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
22 #include "machine_learning_exception.h"
23 #include "mv_machine_learning_common.h"
24 #include "mv_object_detection_config.h"
25 #include "object_detection.h"
28 using namespace mediavision::inference;
29 using namespace MediaVision::Common;
30 using namespace mediavision::common;
31 using namespace mediavision::machine_learning::exception;
35 namespace machine_learning
37 ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type, shared_ptr<MachineLearningConfig> config)
38 : _task_type(task_type), _config(config)
40 _inference = make_unique<Inference>();
43 void ObjectDetection::preDestroy()
48 _async_manager->stop();
51 ObjectDetectionTaskType ObjectDetection::getTaskType()
56 void ObjectDetection::getEngineList()
58 for (auto idx = MV_INFERENCE_BACKEND_NONE + 1; idx < MV_INFERENCE_BACKEND_MAX; ++idx) {
59 auto backend = _inference->getSupportedInferenceBackend(idx);
60 // TODO. we need to describe what inference engines are supported by each Task API,
61 // and based on it, below inference engine types should be checked
62 // if a given type is supported by this Task API later. As of now, tflite only.
63 if (backend.second == true && backend.first.compare("tflite") == 0)
64 _valid_backends.push_back(backend.first);
68 void ObjectDetection::getDeviceList(const char *engine_type)
70 // TODO. add device types available for a given engine type later.
71 // In default, cpu and gpu only.
72 _valid_devices.push_back("cpu");
73 _valid_devices.push_back("gpu");
76 void ObjectDetection::setEngineInfo(std::string engine_type_name, std::string device_type_name)
78 if (engine_type_name.empty() || device_type_name.empty())
79 throw InvalidParameter("Invalid engine info.");
81 transform(engine_type_name.begin(), engine_type_name.end(), engine_type_name.begin(), ::toupper);
82 transform(device_type_name.begin(), device_type_name.end(), device_type_name.begin(), ::toupper);
84 int engine_type = GetBackendType(engine_type_name);
85 int device_type = GetDeviceType(device_type_name);
87 if (engine_type == MEDIA_VISION_ERROR_INVALID_PARAMETER || device_type == MEDIA_VISION_ERROR_INVALID_PARAMETER)
88 throw InvalidParameter("backend or target device type not found.");
90 _config->setBackendType(engine_type);
91 _config->setTargetDeviceType(device_type);
93 LOGI("Engine type : %s => %d, Device type : %s => %d", engine_type_name.c_str(), engine_type,
94 device_type_name.c_str(), device_type);
97 void ObjectDetection::getNumberOfEngines(unsigned int *number_of_engines)
99 if (!_valid_backends.empty()) {
100 *number_of_engines = _valid_backends.size();
105 *number_of_engines = _valid_backends.size();
108 void ObjectDetection::getEngineType(unsigned int engine_index, char **engine_type)
110 if (!_valid_backends.empty()) {
111 if (_valid_backends.size() <= engine_index)
112 throw InvalidParameter("Invalid engine index.");
114 *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
120 if (_valid_backends.size() <= engine_index)
121 throw InvalidParameter("Invalid engine index.");
123 *engine_type = const_cast<char *>(_valid_backends[engine_index].data());
126 void ObjectDetection::getNumberOfDevices(const char *engine_type, unsigned int *number_of_devices)
128 if (!_valid_devices.empty()) {
129 *number_of_devices = _valid_devices.size();
133 getDeviceList(engine_type);
134 *number_of_devices = _valid_devices.size();
137 void ObjectDetection::getDeviceType(const char *engine_type, const unsigned int device_index, char **device_type)
139 if (!_valid_devices.empty()) {
140 if (_valid_devices.size() <= device_index)
141 throw InvalidParameter("Invalid device index.");
143 *device_type = const_cast<char *>(_valid_devices[device_index].data());
147 getDeviceList(engine_type);
149 if (_valid_devices.size() <= device_index)
150 throw InvalidParameter("Invalid device index.");
152 *device_type = const_cast<char *>(_valid_devices[device_index].data());
155 void ObjectDetection::loadLabel()
157 if (_config->getLabelFilePath().empty())
163 readFile.open(_config->getLabelFilePath().c_str());
166 throw InvalidOperation("Fail to open " + _config->getLabelFilePath() + " file.");
170 while (getline(readFile, line))
171 _labels.push_back(line);
176 void ObjectDetection::configure()
178 _config->loadMetaFile(make_unique<ObjectDetectionParser>(static_cast<int>(_task_type)));
181 int ret = _inference->bind(_config->getBackendType(), _config->getTargetDeviceType());
182 if (ret != MEDIA_VISION_ERROR_NONE)
183 throw InvalidOperation("Fail to bind a backend engine.");
186 void ObjectDetection::prepare()
188 int ret = _inference->configureInputMetaInfo(_config->getInputMetaMap());
189 if (ret != MEDIA_VISION_ERROR_NONE)
190 throw InvalidOperation("Fail to configure input tensor info from meta file.");
192 ret = _inference->configureOutputMetaInfo(_config->getOutputMetaMap());
193 if (ret != MEDIA_VISION_ERROR_NONE)
194 throw InvalidOperation("Fail to configure output tensor info from meta file.");
196 _inference->configureModelFiles("", _config->getModelFilePath(), "");
198 // Request to load model files to a backend engine.
199 ret = _inference->load();
200 if (ret != MEDIA_VISION_ERROR_NONE)
201 throw InvalidOperation("Fail to load model files.");
204 shared_ptr<MetaInfo> ObjectDetection::getInputMetaInfo()
206 TensorBuffer &tensor_buffer = _inference->getInputTensorBuffer();
207 IETensorBuffer &tensor_info_map = tensor_buffer.getIETensorBuffer();
209 // TODO. consider using multiple tensors later.
210 if (tensor_info_map.size() != 1)
211 throw InvalidOperation("Input tensor count not invalid.");
213 auto tensor_buffer_iter = tensor_info_map.begin();
215 // Get the meta information corresponding to a given input tensor name.
216 return _config->getInputMetaMap()[tensor_buffer_iter->first];
// Converts a media source into the model's input layout and writes the result
// into inputVector. Colour space, size, normalization and quantization
// parameters all come from the input tensor's MetaInfo.
// NOTE(review): this view of the file is missing several lines of this
// function (the template header, brace lines, one PreprocessConfig field and
// the `auto quantization =` left-hand side); the comments below annotate the
// visible statements only.
void ObjectDetection::preprocess(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo, vector<T> &inputVector)
// Base preprocessing parameters taken from the input tensor meta data.
PreprocessConfig config = { false,
metaInfo->colorSpace,
metaInfo->getChannel(),
metaInfo->getWidth(),
metaInfo->getHeight() };
// Mean/std normalization settings declared by the meta file.
// NOTE(review): .at() throws std::out_of_range if the NORMAL entry is absent
// — presumably the meta parser always inserts it; verify against the parser.
auto normalization = static_pointer_cast<DecodingNormal>(metaInfo->decodingTypeMap.at(DecodingType::NORMAL));
config.normalize = normalization->use;
config.mean = normalization->mean;
config.std = normalization->std;
// Quantization (scale/zero-point) settings for integer models.
static_pointer_cast<DecodingQuantization>(metaInfo->decodingTypeMap.at(DecodingType::QUANTIZATION));
config.quantize = quantization->use;
config.scale = quantization->scale;
config.zeropoint = quantization->zeropoint;
// Run the actual conversion with the assembled configuration.
_preprocess.setConfig(config);
_preprocess.run<T>(mv_src, inputVector);
252 template<typename T> void ObjectDetection::inference(vector<vector<T> > &inputVectors)
256 int ret = _inference->run<T>(inputVectors);
257 if (ret != MEDIA_VISION_ERROR_NONE)
258 throw InvalidOperation("Fail to run inference");
263 template<typename T> void ObjectDetection::perform(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo)
265 vector<T> inputVector;
267 preprocess<T>(mv_src, metaInfo, inputVector);
269 vector<vector<T> > inputVectors = { inputVector };
271 inference<T>(inputVectors);
273 // TODO. Update operation status here.
276 void ObjectDetection::perform(mv_source_h &mv_src)
278 shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
279 if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8)
280 perform<unsigned char>(mv_src, metaInfo);
281 else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32)
282 perform<float>(mv_src, metaInfo);
284 throw InvalidOperation("Invalid model data type.");
287 template<typename T> void ObjectDetection::performAsync(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo)
289 if (!_async_manager) {
290 _async_manager = make_unique<AsyncManager<ObjectDetectionResult> >([this]() {
291 AsyncInputQueue<T> inputQueue = _async_manager->popFromInput<T>();
293 inference<T>(inputQueue.inputs);
295 ObjectDetectionResult &resultQueue = result();
297 resultQueue.frame_number = inputQueue.frame_number;
298 _async_manager->pushToOutput(resultQueue);
302 vector<T> inputVector;
304 preprocess<T>(input.inference_src, metaInfo, inputVector);
306 vector<vector<T> > inputVectors = { inputVector };
308 _async_manager->push(inputVectors);
311 void ObjectDetection::performAsync(ObjectDetectionInput &input)
313 shared_ptr<MetaInfo> metaInfo = getInputMetaInfo();
315 if (metaInfo->dataType == MV_INFERENCE_DATA_UINT8) {
316 performAsync<unsigned char>(input, metaInfo);
317 } else if (metaInfo->dataType == MV_INFERENCE_DATA_FLOAT32) {
318 performAsync<float>(input, metaInfo);
321 throw InvalidOperation("Invalid model data type.");
325 ObjectDetectionResult &ObjectDetection::getOutput()
327 if (_async_manager) {
328 if (!_async_manager->isWorking())
329 throw InvalidOperation("Object detection has been already destroyed so invalid operation.");
331 _current_result = _async_manager->pop();
333 // TODO. Check if inference request is completed or not here.
334 // If not then throw an exception.
335 _current_result = result();
338 return _current_result;
341 ObjectDetectionResult &ObjectDetection::getOutputCache()
343 return _current_result;
346 void ObjectDetection::getOutputNames(vector<string> &names)
348 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
349 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
351 for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
352 names.push_back(it->first);
355 void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
357 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
359 inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
361 throw InvalidOperation("Fail to get tensor buffer.");
363 auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
365 copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
// Explicit template instantiations: the template definitions above live in
// this translation unit, so every element type used by callers must be
// instantiated here. Only float and uint8 input models are supported
// (matching the dataType dispatch in perform()/performAsync()).
template void ObjectDetection::preprocess<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
												 vector<float> &inputVector);
template void ObjectDetection::inference<float>(vector<vector<float> > &inputVectors);
template void ObjectDetection::perform<float>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<float>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);

template void ObjectDetection::preprocess<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo,
														 vector<unsigned char> &inputVector);
template void ObjectDetection::inference<unsigned char>(vector<vector<unsigned char> > &inputVectors);
template void ObjectDetection::perform<unsigned char>(mv_source_h &mv_src, shared_ptr<MetaInfo> metaInfo);
template void ObjectDetection::performAsync<unsigned char>(ObjectDetectionInput &input, shared_ptr<MetaInfo> metaInfo);