2 * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include "machine_learning_exception.h"
22 #include "image_classification_default.h"
23 #include "Postprocess.h"
26 using namespace mediavision::inference;
27 using namespace mediavision::machine_learning::exception;
31 namespace machine_learning
// Default constructor.
// Value-initializes the `_result` member via `_result()` in the member
// initializer list, so it starts from a zeroed/default state before result()
// fills in the predicted label.
33 ImageClassificationDefault::ImageClassificationDefault() : _result()
// Destructor.
// NOTE(review): presumably no explicit cleanup beyond member/base destructors
// is needed here — confirm against the body.
36 ImageClassificationDefault::~ImageClassificationDefault()
39 image_classification_result_s &ImageClassificationDefault::result()
43 ImageClassification::getOutputNames(names);
45 vector<float> output_vec;
47 // In case of image classification model, only one output tensor is used.
48 ImageClassification::getOutpuTensor(names[0], output_vec);
50 auto metaInfo = _parser->getOutputMetaMap().at(names[0]);
51 auto decodingScore = static_pointer_cast<DecodingScore>(metaInfo->decodingTypeMap.at(DecodingType::SCORE));
53 if (decodingScore->type == ScoreType::SIGMOID) {
54 for (size_t idx = 0; idx < output_vec.size(); ++idx)
55 output_vec[idx] = PostProcess::sigmoid(output_vec[idx]);
58 _result.label = _labels[max_element(output_vec.begin(), output_vec.end()) - output_vec.begin()];
59 LOGI("Label = %s", _result.label.c_str());