mv_machine_learning: introduce engine configuration task API
[platform/core/api/mediavision.git] / mv_machine_learning / object_detection / src / object_detection.cpp
1 /**
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include <string.h>
#include <algorithm>
#include <fstream>
#include <map>
#include <memory>
#include <utility>
22
23 #include "machine_learning_exception.h"
24 #include "mv_object_detection_config.h"
25 #include "object_detection.h"
26
27 using namespace std;
28 using namespace mediavision::inference;
29 using namespace MediaVision::Common;
30 using namespace mediavision::machine_learning::exception;
31
32 namespace mediavision
33 {
34 namespace machine_learning
35 {
36 ObjectDetection::ObjectDetection(ObjectDetectionTaskType task_type)
37                 : _task_type(task_type), _backendType(), _targetDeviceType()
38 {
39         _inference = make_unique<Inference>();
40         _parser = make_unique<ObjectDetectionParser>();
41 }
42
43 void ObjectDetection::setUserModel(string model_file, string meta_file, string label_file)
44 {
45         _modelFilePath = model_file;
46         _modelMetaFilePath = meta_file;
47         _modelLabelFilePath = label_file;
48 }
49
/**
 * Returns true when fileName has a ".json" extension (case-sensitive).
 *
 * Only the text after the last '.' is compared. "model.json" and
 * "dir/model.json" match; "model.txt", "model.json.bak", "json" (no
 * extension at all) and "" do not.
 */
static bool IsJsonFile(const std::string &fileName)
{
	const std::string::size_type dot = fileName.rfind('.');

	// No dot means no extension. Without this guard npos + 1 wraps to 0 and the
	// whole name is compared, so a file literally named "json" would match.
	if (dot == std::string::npos)
		return false;

	// Compare the suffix in place instead of building a substring.
	return fileName.compare(dot + 1, std::string::npos, "json") == 0;
}
54
55 void ObjectDetection::loadLabel()
56 {
57         ifstream readFile;
58
59         _labels.clear();
60         readFile.open(_modelLabelFilePath.c_str());
61
62         if (readFile.fail())
63                 throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
64
65         string line;
66
67         while (getline(readFile, line))
68                 _labels.push_back(line);
69
70         readFile.close();
71 }
72
73 void ObjectDetection::parseMetaFile()
74 {
75         _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + string(MV_OBJECT_DETECTION_META_FILE_NAME));
76
77         int ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_BACKEND_TYPE), &_backendType);
78         if (ret != MEDIA_VISION_ERROR_NONE)
79                 throw InvalidOperation("Fail to get backend engine type.");
80
81         ret = _config->getIntegerAttribute(string(MV_OBJECT_DETECTION_TARGET_DEVICE_TYPE), &_targetDeviceType);
82         if (ret != MEDIA_VISION_ERROR_NONE)
83                 throw InvalidOperation("Fail to get target device type.");
84
85         ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_DEFAULT_PATH, &_modelDefaultPath);
86         if (ret != MEDIA_VISION_ERROR_NONE)
87                 throw InvalidOperation("Fail to get model default path");
88
89         if (_modelFilePath.empty()) {
90                 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_FILE_PATH, &_modelFilePath);
91                 if (ret != MEDIA_VISION_ERROR_NONE)
92                         throw InvalidOperation("Fail to get model file path");
93         }
94
95         _modelFilePath = _modelDefaultPath + _modelFilePath;
96         LOGI("model file path = %s", _modelFilePath.c_str());
97
98         if (_modelMetaFilePath.empty()) {
99                 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_MODEL_META_FILE_PATH, &_modelMetaFilePath);
100                 if (ret != MEDIA_VISION_ERROR_NONE)
101                         throw InvalidOperation("Fail to get model meta file path");
102
103                 if (_modelMetaFilePath.empty())
104                         throw InvalidOperation("Model meta file doesn't exist.");
105
106                 if (!IsJsonFile(_modelMetaFilePath))
107                         throw InvalidOperation("Model meta file should be json");
108         }
109
110         _modelMetaFilePath = _modelDefaultPath + _modelMetaFilePath;
111         LOGI("meta file path = %s", _modelMetaFilePath.c_str());
112
113         _parser->setTaskType(static_cast<int>(_task_type));
114         _parser->load(_modelMetaFilePath);
115
116         if (_modelLabelFilePath.empty()) {
117                 ret = _config->getStringAttribute(MV_OBJECT_DETECTION_LABEL_FILE_NAME, &_modelLabelFilePath);
118                 if (ret != MEDIA_VISION_ERROR_NONE)
119                         throw InvalidOperation("Fail to get label file path");
120
121                 if (_modelLabelFilePath.empty())
122                         throw InvalidOperation("Model label file doesn't exist.");
123         }
124
125         _modelLabelFilePath = _modelDefaultPath + _modelLabelFilePath;
126         LOGI("label file path = %s", _modelLabelFilePath.c_str());
127
128         loadLabel();
129 }
130
131 void ObjectDetection::configure()
132 {
133         int ret = _inference->bind(_backendType, _targetDeviceType);
134         if (ret != MEDIA_VISION_ERROR_NONE)
135                 throw InvalidOperation("Fail to bind a backend engine.");
136 }
137
138 void ObjectDetection::prepare()
139 {
140         int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
141         if (ret != MEDIA_VISION_ERROR_NONE)
142                 throw InvalidOperation("Fail to configure input tensor info from meta file.");
143
144         ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
145         if (ret != MEDIA_VISION_ERROR_NONE)
146                 throw InvalidOperation("Fail to configure output tensor info from meta file.");
147
148         _inference->configureModelFiles("", _modelFilePath, "");
149
150         // Request to load model files to a backend engine.
151         ret = _inference->load();
152         if (ret != MEDIA_VISION_ERROR_NONE)
153                 throw InvalidOperation("Fail to load model files.");
154 }
155
156 void ObjectDetection::preprocess(mv_source_h &mv_src)
157 {
158         LOGI("ENTER");
159
160         TensorBuffer &tensor_buffer_obj = _inference->getInputTensorBuffer();
161         IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
162         vector<mv_source_h> mv_srcs = { mv_src };
163
164         _preprocess.run(mv_srcs, _parser->getInputMetaMap(), ie_tensor_buffer);
165
166         LOGI("LEAVE");
167 }
168
169 void ObjectDetection::inference(mv_source_h source)
170 {
171         LOGI("ENTER");
172
173         vector<mv_source_h> sources;
174
175         sources.push_back(source);
176
177         int ret = _inference->run();
178         if (ret != MEDIA_VISION_ERROR_NONE)
179                 throw InvalidOperation("Fail to run inference");
180
181         LOGI("LEAVE");
182 }
183
184 void ObjectDetection::getOutputNames(vector<string> &names)
185 {
186         TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
187         IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
188
189         for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
190                 names.push_back(it->first);
191 }
192
193 void ObjectDetection::getOutputTensor(string target_name, vector<float> &tensor)
194 {
195         TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
196
197         inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
198         if (!tensor_buffer)
199                 throw InvalidOperation("Fail to get tensor buffer.");
200
201         auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
202
203         copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
204 }
205
206 }
207 }