/**
 * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "machine_learning_exception.h"
#include "mv_image_classification_config.h"
#include "image_classification.h"

using namespace std;
using namespace mediavision::inference;
using namespace MediaVision::Common;
using namespace mediavision::machine_learning::exception;
34 namespace machine_learning
36 ImageClassification::ImageClassification() : _backendType(), _targetDeviceType()
38 _inference = make_unique<Inference>();
39 _parser = make_unique<ImageClassificationParser>();
42 void ImageClassification::configure()
44 int ret = _inference->bind(_backendType, _targetDeviceType);
45 if (ret != MEDIA_VISION_ERROR_NONE)
46 throw InvalidOperation("Fail to bind a backend engine.");
/**
 * @brief Tells whether @a fileName has a ".json" extension.
 *
 * Only the text after the last '.' is compared, case-sensitively, with "json".
 * A name that contains no '.' at all is never treated as a JSON file.
 *
 * @param fileName file name or path to check.
 * @return true if the extension is exactly "json", false otherwise.
 */
static bool IsJsonFile(const std::string &fileName)
{
	const std::string::size_type dotPos = fileName.find_last_of('.');

	// Without this guard npos + 1 wraps around to 0 and the whole name is compared,
	// so an extension-less file called "json" would wrongly be accepted.
	if (dotPos == std::string::npos)
		return false;

	return fileName.compare(dotPos + 1, std::string::npos, "json") == 0;
}
54 void ImageClassification::loadLabel()
59 readFile.open(_modelLabelFilePath.c_str());
62 throw InvalidOperation("Fail to open " + _modelLabelFilePath + " file.");
66 while (getline(readFile, line))
67 _labels.push_back(line);
72 void ImageClassification::setUserModel(string model_file, string meta_file, string label_file)
74 _modelFilePath = model_file;
75 _modelMetaFilePath = meta_file;
76 _modelLabelFilePath = label_file;
/**
 * @brief Resolves the engine settings and model file paths from the image classification
 *        config file, then loads the model meta file into the parser (_parser->load()).
 *
 * Reads the backend type, the target device type and the model default path from the
 * config. Any of _modelFilePath, _modelMetaFilePath and _modelLabelFilePath that is still
 * empty (for example because setUserModel() was not called) is read from the config as
 * well. Each resolved path is then prefixed with the model default path.
 *
 * @throws InvalidOperation if a config attribute cannot be read, if the meta or label file
 *         name taken from the config is empty, or if the meta file name does not end in
 *         ".json".
 */
79 void ImageClassification::parseMetaFile()
// Config file location = MV_CONFIG_PATH + MV_IMAGE_CLASSIFICATION_CONFIG_FILE_NAME.
81 _config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + string(MV_IMAGE_CLASSIFICATION_CONFIG_FILE_NAME));
83 int ret = _config->getIntegerAttribute(string(MV_IMAGE_CLASSIFICATION_BACKEND_TYPE), &_backendType);
84 if (ret != MEDIA_VISION_ERROR_NONE)
85 throw InvalidOperation("Fail to get backend engine type.");
87 ret = _config->getIntegerAttribute(string(MV_IMAGE_CLASSIFICATION_TARGET_DEVICE_TYPE), &_targetDeviceType);
88 if (ret != MEDIA_VISION_ERROR_NONE)
89 throw InvalidOperation("Fail to get target device type.");
91 string modelDefaultPath;
93 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_DEFAULT_PATH, &modelDefaultPath);
94 if (ret != MEDIA_VISION_ERROR_NONE)
95 throw InvalidOperation("Fail to get model default path");
// Model file: fall back to the config value when no name has been set.
97 if (_modelFilePath.empty()) {
98 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_FILE_NAME, &_modelFilePath);
99 if (ret != MEDIA_VISION_ERROR_NONE)
100 throw InvalidOperation("Fail to get model file path");
// modelDefaultPath is concatenated as-is, so it presumably already ends with a path
// separator -- TODO confirm against the config file.
103 _modelFilePath = modelDefaultPath + _modelFilePath;
104 LOGI("model file path = %s", _modelFilePath.c_str());
// Meta file: fall back to the config value when no name has been set.
106 if (_modelMetaFilePath.empty()) {
107 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_META_FILE_NAME, &_modelMetaFilePath);
108 if (ret != MEDIA_VISION_ERROR_NONE)
109 throw InvalidOperation("Fail to get model meta file path");
// Only rejects an empty name; the file itself is not checked for existence here.
111 if (_modelMetaFilePath.empty())
112 throw InvalidOperation("Model meta file doesn't exist.");
// Only the file name extension is checked, not the file contents.
114 if (!IsJsonFile(_modelMetaFilePath))
115 throw InvalidOperation("Model meta file should be json.");
118 _modelMetaFilePath = modelDefaultPath + _modelMetaFilePath;
119 LOGI("meta file path = %s", _modelMetaFilePath.c_str());
// Label file: fall back to the config value when no name has been set.
121 if (_modelLabelFilePath.empty()) {
122 ret = _config->getStringAttribute(MV_IMAGE_CLASSIFICATION_MODEL_LABEL_FILE_NAME, &_modelLabelFilePath);
123 if (ret != MEDIA_VISION_ERROR_NONE)
124 throw InvalidOperation("Fail to get model label file path");
// Only rejects an empty name; the file itself is not checked for existence here.
126 if (_modelLabelFilePath.empty())
127 throw InvalidOperation("Model label file doesn't exist.");
130 _modelLabelFilePath = modelDefaultPath + _modelLabelFilePath;
131 LOGI("label file path = %s", _modelLabelFilePath.c_str());
// Parse the resolved meta file; prepare() later reads the input/output meta maps from _parser.
134 _parser->load(_modelMetaFilePath);
137 void ImageClassification::prepare()
139 int ret = _inference->configureInputMetaInfo(_parser->getInputMetaMap());
140 if (ret != MEDIA_VISION_ERROR_NONE)
141 throw InvalidOperation("Fail to configure input tensor info from meta file.");
143 ret = _inference->configureOutputMetaInfo(_parser->getOutputMetaMap());
144 if (ret != MEDIA_VISION_ERROR_NONE)
145 throw InvalidOperation("Fail to configure output tensor info from meta file.");
147 _inference->configureModelFiles("", _modelFilePath, "");
149 // Request to load model files to a backend engine.
150 ret = _inference->load();
151 if (ret != MEDIA_VISION_ERROR_NONE)
152 throw InvalidOperation("Fail to load model files.");
155 void ImageClassification::preprocess(mv_source_h &mv_src)
159 TensorBuffer &tensor_buffer_obj = _inference->getInputTensorBuffer();
160 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
161 vector<mv_source_h> mv_srcs = { mv_src };
163 _preprocess.run(mv_srcs, _parser->getInputMetaMap(), ie_tensor_buffer);
168 void ImageClassification::inference(mv_source_h source)
172 vector<mv_source_h> sources;
174 sources.push_back(source);
176 int ret = _inference->run();
177 if (ret != MEDIA_VISION_ERROR_NONE)
178 throw InvalidOperation("Fail to run inference");
183 void ImageClassification::getOutputNames(vector<string> &names)
185 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
186 IETensorBuffer &ie_tensor_buffer = tensor_buffer_obj.getIETensorBuffer();
188 for (IETensorBuffer::iterator it = ie_tensor_buffer.begin(); it != ie_tensor_buffer.end(); it++)
189 names.push_back(it->first);
192 void ImageClassification::getOutpuTensor(string &target_name, vector<float> &tensor)
196 TensorBuffer &tensor_buffer_obj = _inference->getOutputTensorBuffer();
198 inference_engine_tensor_buffer *tensor_buffer = tensor_buffer_obj.getTensorBuffer(target_name);
200 throw InvalidOperation("Fail to get tensor buffer.");
202 auto raw_buffer = static_cast<float *>(tensor_buffer->buffer);
204 copy(&raw_buffer[0], &raw_buffer[tensor_buffer->size / sizeof(float)], back_inserter(tensor));