2 * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
23 #include "machine_learning_exception.h"
24 #include "mv_object_detection_config.h"
25 #include "object_detection.h"
28 using namespace mediavision::inference;
29 using namespace MediaVision::Common;
30 using namespace mediavision::machine_learning::exception;
34 namespace machine_learning
36 ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
37 : _task_type(task_type), _backendType(), _targetDeviceType()
39 _inference = make_unique<Inference>();
40 _parser = make_unique<ObjectDetectionParser>();
43 void ObjectDetection::setUserModel(string model_file, string meta_file, string label_file)
45 _modelFilePath = model_file;
46 _modelMetaFilePath = meta_file;
47 _modelLabelFilePath = label_file;
// Return true only when fileName ends with a real ".json" extension.
//
// The previous implementation used find_last_of(".") + 1: for a name with
// no dot, find_last_of() returns npos and npos + 1 wraps to 0, so substr()
// yielded the whole name and a file literally named "json" was accepted.
static bool IsJsonFile(const std::string &fileName)
{
	const auto dot_pos = fileName.find_last_of('.');
	if (dot_pos == std::string::npos)
		return false; // no extension at all

	// Compare everything after the last dot against "json".
	return fileName.compare(dot_pos + 1, std::string::npos, "json") == 0;
}
55 void ObjectDetection::loadLabel()
60 readFile.open(_modelLabelFilePath.c_str());
63 throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
67 while (getline(readFile, line))
68 _labels.push_back(line);
73 void ObjectDetection::parseMetaFile()
75 _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + string(MV_OBJECT_DETECTION_META_FILE_NAME));
77 int ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_BACKEND_TYPE), &_backendType);
78 if (ret != MEDIA_VISION_ERROR_NONE)
79 throw InvalidOperation("Fail to get backend engine type.");
81 ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE), &_targetDeviceType);
82 if (ret != MEDIA_VISION_ERROR_NONE)
83 throw InvalidOperation("Fail to get target device type.");
85 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
86 if (ret != MEDIA_VISION_ERROR_NONE)
87 throw InvalidOperation("Fail to get model default path");
89 if (_modelFilePath.empty()) {
90 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
91 if (ret != MEDIA_VISION_ERROR_NONE)
92 throw InvalidOperation("Fail to get model file path");
95 _modelFilePath = _modelDefaultPath + _modelFilePath;
96 LOGI("model file path = %s", _modelFilePath.c_str());
98 if (_modelMetaFilePath.empty()) {
99 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
100 if (ret != MEDIA_VISION_ERROR_NONE)
101 throw InvalidOperation("Fail to get model meta file path");
103 if (_modelMetaFilePath.empty())
104 throw InvalidOperation("Model meta file doesn't exist.");
106 if (!IsJsonFile(_modelMetaFilePath))
107 throw InvalidOperation("Model meta file should be json");
110 _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
111 LOGI("meta file path = %s", _modelMetaFilePath.c_str());
113 _parser->setTaskType(static_cast<int>(_task_type));
114 _parser->load(_modelMetaFilePath);
116 if (_modelLabelFilePath.empty()) {
117 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
118 if (ret != MEDIA_VISION_ERROR_NONE)
119 throw InvalidOperation("Fail to get label file path");
121 if (_modelLabelFilePath.empty())
122 throw InvalidOperation("Model label file doesn't exist.");
125 _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
126 LOGI("label file path = %s", _modelLabelFilePath.c_str());
131 void ObjectDetection::configure()
133 int ret = _inference->bind(_backendType, _targetDeviceType);
134 if (ret != MEDIA_VISION_ERROR_NONE)
135 throw InvalidOperation("Fail to bind a backend engine.");
138 void ObjectDetection::prepare()
140 int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
141 if (ret != MEDIA_VISION_ERROR_NONE)
142 throw InvalidOperation("Fail to configure input tensor info from meta file.");
144 ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
145 if (ret != MEDIA_VISION_ERROR_NONE)
146 throw InvalidOperation("Fail to configure output tensor info from meta file.");
148 _inference->configureModelFiles("", _modelFilePath, "");
150 // Request to load model files to a backend engine.
151 ret = _inference->load();
152 if (ret != MEDIA_VISION_ERROR_NONE)
153 throw InvalidOperation("Fail to load model files.");
156 void ObjectDetection::preprocess(mv_source_h &mv_src)
160 TensorBuffer &tensor_buffer_obj = _inference->getInputTensorBuffer();
161 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
162 vector<mv_source_h> mv_srcs = { mv_src };
164 _preprocess.run(mv_srcs, _parser->getInputMetaMap(), ie_tensor_buffer);
169 void ObjectDetection::inference(mv_source_h source)
173 vector<mv_source_h> sources;
175 sources.push_back(source);
177 int ret = _inference->run();
178 if (ret != MEDIA_VISION_ERROR_NONE)
179 throw InvalidOperation("Fail to run inference");
184 void ObjectDetection::getOutputNames(vector<string> &names)
186 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
187 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
189 for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
190 names.push_back(it->first);
193 void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
195 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
197 inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
199 throw InvalidOperation("Fail to get tensor buffer.");
201 auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
203 copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));