/**
 * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "machine_learning_exception.h"
#include "image_classification_default.h"
#include "Postprocess.h"

using namespace mediavision::inference;
using namespace mediavision::machine_learning::exception;
31 namespace machine_learning
34 ImageClassificationDefault<T>::ImageClassificationDefault(shared_ptr<MachineLearningConfig> config)
35 : ImageClassification<T>(config), _result()
38 template<typename T> ImageClassificationDefault<T>::~ImageClassificationDefault()
41 template<typename T> ImageClassificationResult &ImageClassificationDefault<T>::result()
45 ImageClassification<T>::getOutputNames(names);
47 vector<float> output_vec;
49 // In case of image classification model, only one output tensor is used.
50 ImageClassification<T>::getOutpuTensor(names[0], output_vec);
52 auto metaInfo = _config->getOutputMetaMap().at(names[0]);
53 auto decodingScore = static_pointer_cast<DecodingScore>(metaInfo->decodingTypeMap.at(DecodingType::SCORE));
55 if (decodingScore->type == ScoreType::SIGMOID) {
56 for (size_t idx = 0; idx < output_vec.size(); ++idx)
57 output_vec[idx] = PostProcess::sigmoid(output_vec[idx]);
60 _result.label = _labels[max_element(output_vec.begin(), output_vec.end()) - output_vec.begin()];
61 LOGI("Label = %s", _result.label.c_str());
66 template class ImageClassificationDefault<unsigned char>;
67 template class ImageClassificationDefault<float>;