Add yolo5s model on SNPE
[platform/core/api/mediavision.git] / mv_machine_learning / inference / src / ObjectDecoder.cpp
1 /**
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "mv_private.h"
18 #include "ObjectDecoder.h"
19
20 #include <unistd.h>
21 #include <fstream>
22 #include <string>
23
24 namespace mediavision
25 {
26 namespace inference
27 {
28 int ObjectDecoder::init()
29 {
30         if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_BYPASS) {
31                 if (!mTensorBuffer.exist(mMeta.GetLabelName()) || !mTensorBuffer.exist(mMeta.GetNumberName())) {
32                         LOGE("buffer buffers named of %s or %s are NULL", mMeta.GetLabelName().c_str(),
33                                  mMeta.GetNumberName().c_str());
34
35                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
36                 }
37
38                 std::vector<int> indexes = mMeta.GetNumberDimInfo().GetValidIndexAll();
39                 if (indexes.size() != 1) {
40                         LOGE("Invalid dim size. It should be 1");
41                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
42                 }
43
44                 // mNumberOfObjects is set again if INFERENCE_BOX_DECODING_TYPE_BYPASS.
45                 // Otherwise it is set already within ctor.
46                 mNumberOfOjects = mTensorBuffer.getValue<int>(mMeta.GetNumberName(), indexes[0]);
47         } else if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_SSD_ANCHOR) {
48                 if (mMeta.GetBoxDecodeInfo().IsAnchorBoxEmpty()) {
49                         LOGE("Anchor boxes are required but empty.");
50                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
51                 }
52         } else {
53                 LOGI("YOLO_ANCHOR does nothing");
54         }
55
56         return MEDIA_VISION_ERROR_NONE;
57 }
58
59 float ObjectDecoder::decodeScore(int idx)
60 {
61         float score = mTensorBuffer.getValue<float>(mMeta.GetScoreName(), idx);
62         if (mMeta.GetScoreType() == INFERENCE_SCORE_TYPE_SIGMOID) {
63                 score = PostProcess::sigmoid(score);
64         }
65
66         return score < mMeta.GetScoreThreshold() ? 0.0f : score;
67 }
68
69 Box ObjectDecoder::decodeBox(int idx, float score, int label, int offset)
70 {
71         // assume type is (cx,cy,w,h)
72         // left or cx
73         float cx = mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[0]);
74         // top or cy
75         float cy = mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[1]);
76         // right or width
77         float cWidth =
78                         mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[2]);
79         // bottom or height
80         float cHeight =
81                         mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[3]);
82
83         LOGI("cx:%.2f, cy:%.2f, cW:%.2f, cH:%.2f", cx, cy, cWidth, cHeight);
84         // convert type to ORIGIN_CENTER if ORIGIN_LEFTTOP
85         if (mMeta.GetBoxType() == INFERENCE_BOX_TYPE_ORIGIN_LEFTTOP) {
86                 float tmpCx = cx;
87                 float tmpCy = cy;
88                 cx = (cx + cWidth) * 0.5f; // (left + right)/2
89                 cy = (cy + cHeight) * 0.5f; // (top + bottom)/2
90                 cWidth = cWidth - tmpCx; // right - left
91                 cHeight = cHeight - tmpCy; // bottom - top
92         }
93
94         // convert coordinate to RATIO if PIXEL
95         if (mMeta.GetScoreCoordinate() == INFERENCE_BOX_COORDINATE_TYPE_PIXEL) {
96                 cx /= mScaleW;
97                 cy /= mScaleH;
98                 cWidth /= mScaleW;
99                 cHeight /= mScaleH;
100         }
101
102         Box box = { .index = mMeta.GetLabelName().empty() ? label : mTensorBuffer.getValue<int>(mMeta.GetLabelName(), idx),
103                                 .score = score,
104                                 .location = cv::Rect2f(cx, cy, cWidth, cHeight) };
105
106         return box;
107 }
108
109 Box ObjectDecoder::decodeBoxWithAnchor(int idx, int anchorIdx, float score, cv::Rect2f &anchor)
110 {
111         // location coordinate of box, the output of decodeBox(), is relative between 0 ~ 1
112         Box box = decodeBox(anchorIdx, score, idx);
113
114         if (mMeta.GetBoxDecodeInfo().IsFixedAnchorSize()) {
115                 box.location.x += anchor.x;
116                 box.location.y += anchor.y;
117         } else {
118                 box.location.x = box.location.x / mMeta.GetBoxDecodeInfo().GetAnchorXscale() * anchor.width + anchor.x;
119                 box.location.y = box.location.y / mMeta.GetBoxDecodeInfo().GetAnchorYscale() * anchor.height + anchor.y;
120         }
121
122         if (mMeta.GetBoxDecodeInfo().IsExponentialBoxScale()) {
123                 box.location.width = anchor.width * std::exp(box.location.width / mMeta.GetBoxDecodeInfo().GetAnchorWscale());
124                 box.location.height =
125                                 anchor.height * std::exp(box.location.height / mMeta.GetBoxDecodeInfo().GetAnchorHscale());
126         } else {
127                 box.location.width = anchor.width * box.location.width / mMeta.GetBoxDecodeInfo().GetAnchorWscale();
128                 box.location.height = anchor.height * box.location.height / mMeta.GetBoxDecodeInfo().GetAnchorHscale();
129         }
130
131         return box;
132 }
133
134 int ObjectDecoder::decode()
135 {
136         LOGI("ENTER");
137
138         BoxesList boxList;
139         Boxes boxes;
140         int ret = MEDIA_VISION_ERROR_NONE;
141         int totalIdx = mNumberOfOjects;
142
143         for (int idx = 0; idx < totalIdx; ++idx) {
144                 if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_BYPASS) {
145                         float score = decodeScore(idx);
146                         if (score <= 0.0f)
147                                 continue;
148
149                         Box box = decodeBox(idx, score);
150                         mResultBoxes.push_back(box);
151                 } else if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_SSD_ANCHOR) {
152                         int anchorIdx = -1;
153
154                         boxes.clear();
155                         for (auto &anchorBox : mMeta.GetBoxDecodeInfo().GetAnchorBoxAll()) {
156                                 anchorIdx++;
157
158                                 float score = decodeScore(anchorIdx * mNumberOfOjects + idx);
159
160                                 if (score <= 0.0f)
161                                         continue;
162
163                                 Box box = decodeBoxWithAnchor(idx, anchorIdx, score, anchorBox);
164                                 boxes.push_back(box);
165                         }
166                         boxList.push_back(boxes);
167                 }
168         }
169         if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_YOLO_ANCHOR)
170                 decodeYOLO(boxList);
171         else if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_SSD_ANCHOR)
172                 boxList.push_back(boxes);
173
174         if (!boxList.empty()) {
175                 PostProcess postProc;
176                 ret = postProc.Nms(boxList, mMeta.GetBoxDecodeInfo().GetNmsMode(),
177                                                    mMeta.GetBoxDecodeInfo().GetNmsIouThreshold(), mResultBoxes);
178                 if (ret != MEDIA_VISION_ERROR_NONE) {
179                         LOGE("Fail to non-maximum suppression[%d]", ret);
180                         return ret;
181                 }
182         } else {
183                 LOGW("boxlist empty!");
184         }
185
186         LOGI("LEAVE");
187
188         return ret;
189 }
190
191 Boxes &ObjectDecoder::getObjectAll()
192 {
193         return mResultBoxes;
194 }
195
196 float ObjectDecoder::decodeYOLOScore(int idx, int nameIdx)
197 {
198         auto it = mMeta._tensor_info.begin();
199         std::advance(it, nameIdx);
200
201         float score = mTensorBuffer.getValue<float>(it->first, idx);
202         if (mMeta.GetScoreType() == INFERENCE_SCORE_TYPE_SIGMOID) {
203                 score = PostProcess::sigmoid(score);
204         }
205
206         return score;
207 }
208 Box ObjectDecoder::decodeYOLOBox(int idx, float score, int label, int offset, int nameIdx)
209 {
210         auto it = mMeta._tensor_info.begin();
211         std::advance(it, nameIdx);
212
213         // assume type is (cx,cy,w,h)
214         // left or cx
215         float cx = mTensorBuffer.getValue<float>(it->first, idx * mBoxOffset + offset + mMeta.GetBoxOrder()[0]);
216         // top or cy
217         float cy = mTensorBuffer.getValue<float>(it->first, idx * mBoxOffset + offset + mMeta.GetBoxOrder()[1]);
218         // right or width
219         float cWidth = mTensorBuffer.getValue<float>(it->first, idx * mBoxOffset + offset + mMeta.GetBoxOrder()[2]);
220         // bottom or height
221         float cHeight = mTensorBuffer.getValue<float>(it->first, idx * mBoxOffset + offset + mMeta.GetBoxOrder()[3]);
222
223         if (mMeta.GetScoreType() == INFERENCE_SCORE_TYPE_SIGMOID) {
224                 cx = PostProcess::sigmoid(cx);
225                 cy = PostProcess::sigmoid(cy);
226                 cWidth = PostProcess::sigmoid(cWidth);
227                 cHeight = PostProcess::sigmoid(cHeight);
228         }
229
230         LOGI("cx:%.2f, cy:%.2f, cW:%.2f, cH:%.2f", cx, cy, cWidth, cHeight);
231         // convert type to ORIGIN_CENTER if ORIGIN_LEFTTOP
232         if (mMeta.GetBoxType() == INFERENCE_BOX_TYPE_ORIGIN_LEFTTOP) {
233                 float tmpCx = cx;
234                 float tmpCy = cy;
235                 cx = (cx + cWidth) * 0.5f; // (left + right)/2
236                 cy = (cy + cHeight) * 0.5f; // (top + bottom)/2
237                 cWidth = cWidth - tmpCx; // right - left
238                 cHeight = cHeight - tmpCy; // bottom - top
239         }
240
241         // convert coordinate to RATIO if PIXEL
242         if (mMeta.GetScoreCoordinate() == INFERENCE_BOX_COORDINATE_TYPE_PIXEL) {
243                 cx /= mScaleW;
244                 cy /= mScaleH;
245                 cWidth /= mScaleW;
246                 cHeight /= mScaleH;
247         }
248
249         Box box = { .index = label, .score = score, .location = cv::Rect2f(cx, cy, cWidth, cHeight) };
250
251         return box;
252 }
253 void ObjectDecoder::decodeYOLO(BoxesList &boxesList)
254 {
255         box::DecodeInfo &decodeInfo = mMeta.GetBoxDecodeInfo();
256         box::AnchorParam &yoloAnchor = decodeInfo.anchorParam;
257
258         //offsetAnchors is 3 which is number of BOX
259         mNumberOfOjects = mBoxOffset / yoloAnchor.offsetAnchors - 5;
260         boxesList.resize(mNumberOfOjects);
261
262         for (auto strideIdx = 0; strideIdx < yoloAnchor.offsetAnchors; strideIdx++) {
263                 auto &stride = yoloAnchor.strides[strideIdx];
264                 //for each stride
265                 int startAnchorIdx = 0;
266                 int endAnchorIdx = (static_cast<int>(mScaleW) / stride * static_cast<int>(mScaleH) / stride);
267
268                 for (int anchorIdx = startAnchorIdx; anchorIdx < endAnchorIdx; anchorIdx++) {
269                         // for each grid cell
270                         for (int offset = 0; offset < yoloAnchor.offsetAnchors; ++offset) {
271                                 //for each BOX
272                                 //handle order is (H,W,A)
273                                 float boxScore =
274                                                 decodeYOLOScore(anchorIdx * mBoxOffset + (mNumberOfOjects + 5) * offset + 4, strideIdx);
275
276                                 auto anchorBox = decodeInfo.vAnchorBoxes[strideIdx][anchorIdx * yoloAnchor.offsetAnchors + offset];
277
278                                 for (int objIdx = 0; objIdx < mNumberOfOjects; ++objIdx) { //each box to every object
279                                         float objScore = decodeYOLOScore(
280                                                         anchorIdx * mBoxOffset + (mNumberOfOjects + 5) * offset + 5 + objIdx, strideIdx);
281
282                                         if (boxScore * objScore < mMeta.GetScoreThreshold())
283                                                 continue;
284                                         Box box = decodeYOLOBox(anchorIdx, objScore, objIdx, (mNumberOfOjects + 5) * offset, strideIdx);
285
286                                         if (!decodeInfo.vAnchorBoxes.empty()) {
287                                                 box.location.x = (box.location.x * 2 + anchorBox.x) * stride / mScaleW;
288                                                 box.location.y = (box.location.y * 2 + anchorBox.y) * stride / mScaleH;
289                                                 box.location.width =
290                                                                 (box.location.width * 2) * (box.location.width * 2) * anchorBox.width / mScaleW;
291
292                                                 box.location.height =
293                                                                 (box.location.height * 2) * (box.location.height * 2) * anchorBox.height / mScaleH;
294                                         }
295                                         boxesList[objIdx].push_back(box);
296                                 }
297                         }
298                 }
299         }
300 }
301 }
302 }