mv_machine_learning: use camel-case notation
[platform/core/api/mediavision.git] / mv_machine_learning / image_classification / src / image_classification.cpp
1 /**
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include <string.h>
#include <algorithm>
#include <fstream>
#include <map>
#include <memory>
#include <utility>
22
23 #include "machine_learning_exception.h"
24 #include "mv_image_classification_config.h"
25 #include "image_classification.h"
26
27 using namespace std;
28 using namespace mediavision::inference;
29 using namespace MediaVision::Common;
30 using namespace mediavision::machine_learning::exception;
31
32 namespace mediavision
33 {
34 namespace machine_learning
35 {
36 ImageClassification::ImageClassification() : _backendType(), _targetDeviceType()
37 {
38         _inference = make_unique<Inference>();
39         _parser = make_unique<ImageClassificationParser>();
40 }
41
42 void ImageClassification::configure()
43 {
44         int ret = _inference->bind(_backendType, _targetDeviceType);
45         if (ret != MEDIA_VISION_ERROR_NONE)
46                 throw InvalidOperation("Fail to bind a backend engine.");
47 }
48
// Returns true when fileName ends with a ".json" extension (case-sensitive).
// A name without any '.' has no extension and is therefore rejected; previously
// find_last_of() returned npos there, npos + 1 wrapped to 0, and a bare "json"
// file name was wrongly accepted.
static bool IsJsonFile(const std::string &fileName)
{
	const std::string::size_type dot = fileName.find_last_of('.');
	if (dot == std::string::npos)
		return false;

	// Compare the text after the last '.' against "json" without building a substring.
	return fileName.compare(dot + 1, std::string::npos, "json") == 0;
}
53
54 void ImageClassification::loadLabel()
55 {
56         ifstream readFile;
57
58         _labels.clear();
59         readFile.open(_modelLabelFilePath.c_str());
60
61         if (readFile.fail())
62                 throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
63
64         string line;
65
66         while (getline(readFile, line))
67                 _labels.push_back(line);
68
69         readFile.close();
70 }
71
72 void ImageClassification::setUserModel(string model_file, string meta_file, string label_file)
73 {
74         _modelFilePath = model_file;
75         _modelMetaFilePath = meta_file;
76         _modelLabelFilePath = label_file;
77 }
78
79 void ImageClassification::parseMetaFile()
80 {
81         _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + string(MV_IMAGE_CLASSIFICATION_CONFIG_FILE_NAME));
82
83         int ret = _config->getIntegerAttribute(string(MV_IMAGE_CLASSIFICATION_BACKEND_TYPE), &_backendType);
84         if (ret != MEDIA_VISION_ERROR_NONE)
85                 throw InvalidOperation("Fail to get backend engine type.");
86
87         ret = _config->getIntegerAttribute(string(MV_IMAGE_CLASSIFICATION_TARGET_DEVICE_TYPE), &_targetDeviceType);
88         if (ret != MEDIA_VISION_ERROR_NONE)
89                 throw InvalidOperation("Fail to get target device type.");
90
91         string modelDefaultPath;
92
93         ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_DEFAULT_PATH, &modelDefaultPath);
94         if (ret != MEDIA_VISION_ERROR_NONE)
95                 throw InvalidOperation("Fail to get model default path");
96
97         if (_modelFilePath.empty()) {
98                 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_FILE_NAME, &_modelFilePath);
99                 if (ret != MEDIA_VISION_ERROR_NONE)
100                         throw InvalidOperation("Fail to get model file path");
101         }
102
103         _modelFilePath = modelDefaultPath + _modelFilePath;
104         LOGI("model file path = %s", _modelFilePath.c_str());
105
106         if (_modelMetaFilePath.empty()) {
107                 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_META_FILE_NAME, &_modelMetaFilePath);
108                 if (ret != MEDIA_VISION_ERROR_NONE)
109                         throw InvalidOperation("Fail to get model meta file path");
110
111                 if (_modelMetaFilePath.empty())
112                         throw InvalidOperation("Model meta file doesn't exist.");
113
114                 if (!IsJsonFile(_modelMetaFilePath))
115                         throw InvalidOperation("Model meta file should be json.");
116         }
117
118         _modelMetaFilePath = modelDefaultPath + _modelMetaFilePath;
119         LOGI("meta file path = %s", _modelMetaFilePath.c_str());
120
121         if (_modelLabelFilePath.empty()) {
122                 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_LABEL_FILE_NAME, &_modelLabelFilePath);
123                 if (ret != MEDIA_VISION_ERROR_NONE)
124                         throw InvalidOperation("Fail to get model label file path");
125
126                 if (_modelLabelFilePath.empty())
127                         throw InvalidOperation("Model label file doesn't exist.");
128         }
129
130         _modelLabelFilePath = modelDefaultPath + _modelLabelFilePath;
131         LOGI("label file path = %s", _modelLabelFilePath.c_str());
132
133         loadLabel();
134         _parser->load(_modelMetaFilePath);
135 }
136
137 void ImageClassification::prepare()
138 {
139         int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
140         if (ret != MEDIA_VISION_ERROR_NONE)
141                 throw InvalidOperation("Fail to configure input tensor info from meta file.");
142
143         ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
144         if (ret != MEDIA_VISION_ERROR_NONE)
145                 throw InvalidOperation("Fail to configure output tensor info from meta file.");
146
147         _inference->configureModelFiles("", _modelFilePath, "");
148
149         // Request to load model files to a backend engine.
150         ret = _inference->load();
151         if (ret != MEDIA_VISION_ERROR_NONE)
152                 throw InvalidOperation("Fail to load model files.");
153 }
154
155 void ImageClassification::preprocess(mv_source_h &mv_src)
156 {
157         LOGI("ENTER");
158
159         TensorBuffer &tensor_buffer_obj = _inference->getInputTensorBuffer();
160         IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
161         vector<mv_source_h> mv_srcs = { mv_src };
162
163         _preprocess.run(mv_srcs, _parser->getInputMetaMap(), ie_tensor_buffer);
164
165         LOGI("LEAVE");
166 }
167
168 void ImageClassification::inference(mv_source_h source)
169 {
170         LOGI("ENTER");
171
172         vector<mv_source_h> sources;
173
174         sources.push_back(source);
175
176         int ret = _inference->run();
177         if (ret != MEDIA_VISION_ERROR_NONE)
178                 throw InvalidOperation("Fail to run inference");
179
180         LOGI("LEAVE");
181 }
182
183 void ImageClassification::getOutputNames(vector<string> &names)
184 {
185         TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
186         IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
187
188         for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
189                 names.push_back(it->first);
190 }
191
192 void ImageClassification::getOutpuTensor(string &target_name, vector<float> &tensor)
193 {
194         LOGI("ENTER");
195
196         TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
197
198         inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
199         if (!tensor_buffer)
200                 throw InvalidOperation("Fail to get tensor buffer.");
201
202         auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
203
204         copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));
205
206         LOGI("LEAVE");
207 }
208
209 }
210 }