 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
22 #include "machine_learning_exception.h"
23 #include "mv_object_detection_config.h"
24 #include "object_detection.h"
27 using namespace mediavision::inference;
28 using namespace MediaVision::Common;
29 using namespace mediavision::machine_learning::exception;
33 namespace machine_learning
35 ObjectDetection::ObjectDetection() : _backendType(), _targetDeviceType()
37 _inference = make_unique<Inference>();
38 _parser = make_unique<ObjectDetectionParser>();
/**
 * @brief Check whether @a fileName carries a "json" extension.
 *
 * Everything after the last '.' is compared against "json". Note that a
 * name containing no dot is compared as a whole, so the bare name "json"
 * also matches — same behavior as the original comparison.
 */
static bool IsJsonFile(const string &fileName)
{
	const string extension = fileName.substr(fileName.find_last_of('.') + 1);
	return extension == "json";
}
46 void ObjectDetection::parseMetaFile()
48 _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + string(MV_OBJECT_DETECTION_META_FILE_NAME));
50 int ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_BACKEND_TYPE), &_backendType);
51 if (ret != MEDIA_VISION_ERROR_NONE)
52 throw InvalidOperation("Fail to get backend engine type.");
54 ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE), &_targetDeviceType);
55 if (ret != MEDIA_VISION_ERROR_NONE)
56 throw InvalidOperation("Fail to get target device type.");
58 string modelDefaultPath;
60 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &modelDefaultPath);
61 if (ret != MEDIA_VISION_ERROR_NONE)
62 throw InvalidOperation("Fail to get model default path");
64 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
65 if (ret != MEDIA_VISION_ERROR_NONE)
66 throw InvalidOperation("Fail to get model file path");
68 _modelFilePath = modelDefaultPath + _modelFilePath;
70 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
71 if (ret != MEDIA_VISION_ERROR_NONE)
72 throw InvalidOperation("Fail to get model meta file path");
74 if (_modelMetaFilePath.empty())
75 throw InvalidOperation("Model meta file doesn't exist.");
77 if (!IsJsonFile(_modelMetaFilePath))
78 throw InvalidOperation("Model meta file should be json");
80 _modelMetaFilePath = modelDefaultPath + _modelMetaFilePath;
82 _parser->load(_modelMetaFilePath);
85 void ObjectDetection::configure()
87 int ret = _inference->bind(_backendType, _targetDeviceType);
88 if (ret != MEDIA_VISION_ERROR_NONE)
89 throw InvalidOperation("Fail to bind a backend engine.");
92 void ObjectDetection::prepare()
94 int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
95 if (ret != MEDIA_VISION_ERROR_NONE)
96 throw InvalidOperation("Fail to configure input tensor info from meta file.");
98 ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
99 if (ret != MEDIA_VISION_ERROR_NONE)
100 throw InvalidOperation("Fail to configure output tensor info from meta file.");
102 _inference->configureModelFiles("", _modelFilePath, "");
104 // Request to load model files to a backend engine.
105 ret = _inference->load();
106 if (ret != MEDIA_VISION_ERROR_NONE)
107 throw InvalidOperation("Fail to load model files.");
109 void ObjectDetection::preprocess(mv_source_h &mv_src)
113 TensorBuffer &tensor_buffer_obj = _inference->getInputTensorBuffer();
114 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
115 vector<mv_source_h> mv_srcs = { mv_src };
117 _preprocess.run(mv_srcs, _parser->getInputMetaMap(), ie_tensor_buffer);
122 void ObjectDetection::inference(mv_source_h source)
126 vector<mv_source_h> sources;
128 sources.push_back(source);
130 int ret = _inference->run();
131 if (ret != MEDIA_VISION_ERROR_NONE)
132 throw InvalidOperation("Fail to run inference");
137 void ObjectDetection::getOutputNames(vector<string> &names)
139 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
140 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
142 for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
143 names.push_back(it->first);
146 void ObjectDetection::getOutputTensor(string &target_name, vector<float> &tensor)
150 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
152 inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
154 throw InvalidOperation("Fail to get tensor buffer.");
156 auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
158 copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));