2 * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
22 #include "machine_learning_exception.h"
23 #include "mv_object_detection_config.h"
24 #include "mobilenet_v2_ssd.h"
25 #include "Postprocess.h"
28 using namespace mediavision::inference;
29 using namespace mediavision::machine_learning::exception;
33 namespace machine_learning
35 MobilenetV2Ssd::MobilenetV2Ssd(ObjectDetectionTaskType task_type, std::shared_ptr<MachineLearningConfig> config)
36 : ObjectDetection(task_type, config), _result()
39 MobilenetV2Ssd::~MobilenetV2Ssd()
42 static bool compareScore(Box box0, Box box1)
44 return box0.score > box1.score;
47 static float calcIntersectionOverUnion(Box box0, Box box1)
49 float area0 = box0.location.width * box0.location.height;
50 float area1 = box1.location.width * box1.location.height;
52 if (area0 <= 0.0f || area1 <= 0.0f)
55 float sx0 = box0.location.x - box0.location.width * 0.5f;
56 float sy0 = box0.location.y - box0.location.height * 0.5f;
57 float ex0 = box0.location.x + box0.location.width * 0.5f;
58 float ey0 = box0.location.y + box0.location.height * 0.5f;
59 float sx1 = box1.location.x - box1.location.width * 0.5f;
60 float sy1 = box1.location.y - box1.location.height * 0.5f;
61 float ex1 = box1.location.x + box1.location.width * 0.5f;
62 float ey1 = box1.location.y + box1.location.height * 0.5f;
64 float xmin0 = min(sx0, ex0);
65 float ymin0 = min(sy0, ey0);
66 float xmax0 = max(sx0, ex0);
67 float ymax0 = max(sy0, ey0);
68 float xmin1 = min(sx1, ex1);
69 float ymin1 = min(sy1, ey1);
70 float xmax1 = max(sx1, ex1);
71 float ymax1 = max(sy1, ey1);
73 float intersectXmin = max(xmin0, xmin1);
74 float intersectYmin = max(ymin0, ymin1);
75 float intersectXmax = min(xmax0, xmax1);
76 float intersectYmax = min(ymax0, ymax1);
78 float intersectArea = max((intersectYmax - intersectYmin), 0.0f) * max((intersectXmax - intersectXmin), 0.0f);
80 return intersectArea / (area0 + area1 - intersectArea);
// Non-maximum suppression over a set of box lists.
//
// box_lists  : one list of candidate boxes per entry; each list is sorted in place.
// mode       : only BoxNmsMode::STANDARD continues to the suppression loop below.
// threshold  : IoU value compared against each pair of boxes.
// box_vector : output; boxes kept from each list are inserted at its front.
//
// NOTE(review): this listing has gaps (brace-only lines and a few statements are
// not shown), so comments below describe only the statements that are visible.
83 void MobilenetV2Ssd::ApplyNms(vector<vector<Box> > &box_lists, BoxNmsMode mode, float threshold,
84 vector<Box> &box_vector)
// Any mode other than STANDARD takes a separate branch; its body is not visible here.
88 if (mode != BoxNmsMode::STANDARD) {
93 LOGI("threshold: %.3f", threshold);
// Scratch state reused across lists: a per-box flag and the boxes accepted so far.
95 bool isIgnore = false;
96 vector<Box> candidate_box_vec;
98 for (auto &box_list : box_lists) {
// size() is unsigned, so this is an emptiness check.
99 if (box_list.size() <= 0)
// Highest score first (see compareScore); this reorders the caller's list.
102 sort(box_list.begin(), box_list.end(), compareScore);
// Accepted boxes are tracked per list, so boxes from different lists are never compared.
103 candidate_box_vec.clear();
105 for (auto &decoded_box : box_list) {
// Walk the already accepted boxes from the most recently added one backwards.
108 for (auto candidate_box = candidate_box_vec.rbegin(); candidate_box != candidate_box_vec.rend();
110 // compare decoded_box with previous one
111 float iouValue = calcIntersectionOverUnion(decoded_box, (*candidate_box));
113 LOGI("iouValue: %.3f", iouValue);
// Overlap at or above the threshold; presumably marks decoded_box to be dropped -
// the body of this branch is not visible here, TODO confirm.
115 if (iouValue >= threshold) {
122 candidate_box_vec.push_back(decoded_box);
// Kept boxes are inserted at the front of box_vector, so boxes from later lists
// end up ahead of boxes from earlier lists.
125 if (candidate_box_vec.size() > 0)
126 box_vector.insert(box_vector.begin(), candidate_box_vec.begin(), candidate_box_vec.end());
132 Box MobilenetV2Ssd::decodeBox(const DecodingBox *decodingBox, vector<float> &bb_tensor, int idx, float score, int label,
135 // assume type is (cx,cy,w,h)
137 float cx = bb_tensor[idx * box_offset + decodingBox->order[0]];
139 float cy = bb_tensor[idx * box_offset + decodingBox->order[1]];
141 float cWidth = bb_tensor[idx * box_offset + decodingBox->order[2]];
143 float cHeight = bb_tensor[idx * box_offset + decodingBox->order[3]];
145 LOGI("cx:%.2f, cy:%.2f, cW:%.2f, cH:%.2f", cx, cy, cWidth, cHeight);
147 Box box = { .index = label, .score = score, .location = cv::Rect2f(cx, cy, cWidth, cHeight) };
152 Box MobilenetV2Ssd::decodeBoxWithAnchor(const BoxAnchorParam *boxAnchorParam, Box &box, cv::Rect2f &anchor)
154 if (boxAnchorParam->isFixedAnchorSize) {
155 box.location.x += anchor.x;
156 box.location.y += anchor.y;
158 box.location.x = box.location.x / boxAnchorParam->xScale * anchor.width + anchor.x;
159 box.location.y = box.location.y / boxAnchorParam->yScale * anchor.height + anchor.y;
162 if (boxAnchorParam->isExponentialBoxScale) {
163 box.location.width = anchor.width * exp(box.location.width / boxAnchorParam->wScale);
164 box.location.height = anchor.height * exp(box.location.height / boxAnchorParam->hScale);
166 box.location.width = anchor.width * box.location.width / boxAnchorParam->wScale;
167 box.location.height = anchor.height * box.location.height / boxAnchorParam->hScale;
173 ObjectDetectionResult &MobilenetV2Ssd::result()
175 // Clear _result object because result() function can be called every time user wants
176 // so make sure to clear existing result data before getting the data again.
177 memset(reinterpret_cast<void *>(&_result), 0, sizeof(_result));
179 vector<string> names;
181 ObjectDetection::getOutputNames(names);
183 vector<float> score_tensor;
185 // raw_outputs/class_predictions
186 ObjectDetection::getOutputTensor(names[1], score_tensor);
188 auto scoreMetaInfo = _config->getOutputMetaMap().at(names[1]);
189 auto decodingScore = static_pointer_cast<DecodingScore>(scoreMetaInfo->decodingTypeMap[DecodingType::SCORE]);
191 auto boxMetaInfo = _config->getOutputMetaMap().at(names[0]);
192 auto decodingBox = static_pointer_cast<DecodingBox>(boxMetaInfo->decodingTypeMap[DecodingType::BOX]);
193 auto anchorParam = static_pointer_cast<BoxAnchorParam>(decodingBox->decodingInfoMap[BoxDecodingType::SSD_ANCHOR]);
194 unsigned int number_of_objects = scoreMetaInfo->dims[2]; // Shape is 1 x 2034 x 91
196 vector<float> bb_tensor;
198 // raw_outputs/box_encodings
199 ObjectDetection::getOutputTensor(names[0], bb_tensor);
202 vector<vector<Box> > box_list_vec;
203 int box_offset = boxMetaInfo->dims[2]; // Shape is 1 x 2034 x 4
205 for (unsigned int object_idx = 0; object_idx < number_of_objects; ++object_idx) {
210 for (auto &anchor : anchorParam->anchorBox) {
213 float score = score_tensor[anchor_idx * number_of_objects + object_idx];
214 if (score <= decodingScore->threshold)
217 Box box = decodeBox(decodingBox.get(), bb_tensor, anchor_idx, score, object_idx, box_offset);
218 box_vec.push_back(decodeBoxWithAnchor(anchorParam.get(), box, anchor));
221 box_list_vec.push_back(box_vec);
224 vector<Box> result_box_vec;
226 if (!box_list_vec.empty()) {
227 auto anchorNmsParam = static_pointer_cast<BoxNmsParam>(decodingBox->decodingInfoMap[BoxDecodingType::NMS]);
228 ApplyNms(box_list_vec, anchorNmsParam->mode, anchorNmsParam->iouThreshold, result_box_vec);
231 for (auto &box : result_box_vec) {
232 _result.number_of_objects++;
233 _result.names.push_back(_labels[box.index]);
234 _result.indices.push_back(_result.number_of_objects - 1);
235 _result.confidences.push_back(box.score);
237 auto src_width = static_cast<double>(_preprocess.getImageWidth()[0]);
238 auto src_height = static_cast<double>(_preprocess.getImageHeight()[0]);
239 auto half_width = box.location.x - box.location.width * 0.5f;
240 auto half_height = box.location.y - box.location.height * 0.5f;
242 _result.left.push_back(static_cast<int>(half_width * src_width));
243 _result.top.push_back(static_cast<int>(half_height * src_height));
244 _result.right.push_back(static_cast<int>(half_width * src_width) +
245 static_cast<int>(box.location.width * src_width));
246 _result.bottom.push_back(static_cast<int>(half_height * src_height) +
247 static_cast<int>(box.location.height * src_height));
249 LOGI("idx = %d, name = %s, score = %f, %dx%d, %dx%d", box.index,
250 _result.names[_result.number_of_objects - 1].c_str(), _result.confidences[_result.number_of_objects - 1],
251 _result.left[_result.number_of_objects - 1], _result.top[_result.number_of_objects - 1],
252 _result.right[_result.number_of_objects - 1], _result.bottom[_result.number_of_objects - 1]);
254 if (decodingScore->topNumber == _result.number_of_objects)