fc62be477e52dbe671697e4ffd36b67fed88514f
[platform/core/api/mediavision.git] / mv_machine_learning / inference / src / ObjectDecoder.cpp
1 /**
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "mv_private.h"
18 #include "ObjectDecoder.h"
19
20 #include <unistd.h>
21 #include <fstream>
22 #include <string>
23
24 namespace mediavision
25 {
26 namespace inference
27 {
28 int ObjectDecoder::init()
29 {
30         if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_BYPASS) {
31                 if (!mTensorBuffer.exist(mMeta.GetLabelName()) || !mTensorBuffer.exist(mMeta.GetNumberName())) {
32                         LOGE("buffer buffers named of %s or %s are NULL", mMeta.GetLabelName().c_str(),
33                                  mMeta.GetNumberName().c_str());
34
35                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
36                 }
37
38                 std::vector<int> indexes = mMeta.GetNumberDimInfo().GetValidIndexAll();
39                 if (indexes.size() != 1) {
40                         LOGE("Invalid dim size. It should be 1");
41                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
42                 }
43
44                 // mNumberOfObjects is set again if INFERENCE_BOX_DECODING_TYPE_BYPASS.
45                 // Otherwise it is set already within ctor.
46                 mNumberOfOjects = mTensorBuffer.getValue<int>(mMeta.GetNumberName(), indexes[0]);
47         } else if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_SSD_ANCHOR) {
48                 if (mMeta.GetBoxDecodeInfo().IsAnchorBoxEmpty()) {
49                         LOGE("Anchor boxes are required but empty.");
50                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
51                 }
52         } else {
53                 LOGI("YOLO_ANCHOR does nothing");
54         }
55
56         return MEDIA_VISION_ERROR_NONE;
57 }
58
59 float ObjectDecoder::decodeScore(int idx)
60 {
61         float score = mTensorBuffer.getValue<float>(mMeta.GetScoreName(), idx);
62         if (mMeta.GetScoreType() == INFERENCE_SCORE_TYPE_SIGMOID) {
63                 score = PostProcess::sigmoid(score);
64         }
65
66         return score < mMeta.GetScoreThreshold() ? 0.0f : score;
67 }
68
69 Box ObjectDecoder::decodeBox(int idx, float score, int label, int offset)
70 {
71         // assume type is (cx,cy,w,h)
72         // left or cx
73         float cx = mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[0]);
74         // top or cy
75         float cy = mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[1]);
76         // right or width
77         float cWidth =
78                         mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[2]);
79         // bottom or height
80         float cHeight =
81                         mTensorBuffer.getValue<float>(mMeta.GetBoxName(), idx * mBoxOffset + offset + mMeta.GetBoxOrder()[3]);
82
83         LOGI("cx:%.2f, cy:%.2f, cW:%.2f, cH:%.2f", cx, cy, cWidth, cHeight);
84         // convert type to ORIGIN_CENTER if ORIGIN_LEFTTOP
85         if (mMeta.GetBoxType() == INFERENCE_BOX_TYPE_ORIGIN_LEFTTOP) {
86                 float tmpCx = cx;
87                 float tmpCy = cy;
88                 cx = (cx + cWidth) * 0.5f; // (left + right)/2
89                 cy = (cy + cHeight) * 0.5f; // (top + bottom)/2
90                 cWidth = cWidth - tmpCx; // right - left
91                 cHeight = cHeight - tmpCy; // bottom - top
92         }
93
94         // convert coordinate to RATIO if PIXEL
95         if (mMeta.GetScoreCoordinate() == INFERENCE_BOX_COORDINATE_TYPE_PIXEL) {
96                 cx /= mScaleW;
97                 cy /= mScaleH;
98                 cWidth /= mScaleW;
99                 cHeight /= mScaleH;
100         }
101
102         Box box = { .index = mMeta.GetLabelName().empty() ? label : mTensorBuffer.getValue<int>(mMeta.GetLabelName(), idx),
103                                 .score = score,
104                                 .location = cv::Rect2f(cx, cy, cWidth, cHeight) };
105
106         return box;
107 }
108
109 Box ObjectDecoder::decodeBoxWithAnchor(int idx, int anchorIdx, float score, cv::Rect2f &anchor)
110 {
111         // location coordinate of box, the output of decodeBox(), is relative between 0 ~ 1
112         Box box = decodeBox(anchorIdx, score, idx);
113
114         if (mMeta.GetBoxDecodeInfo().IsFixedAnchorSize()) {
115                 box.location.x += anchor.x;
116                 box.location.y += anchor.y;
117         } else {
118                 box.location.x = box.location.x / mMeta.GetBoxDecodeInfo().GetAnchorXscale() * anchor.width + anchor.x;
119                 box.location.y = box.location.y / mMeta.GetBoxDecodeInfo().GetAnchorYscale() * anchor.height + anchor.y;
120         }
121
122         if (mMeta.GetBoxDecodeInfo().IsExponentialBoxScale()) {
123                 box.location.width = anchor.width * std::exp(box.location.width / mMeta.GetBoxDecodeInfo().GetAnchorWscale());
124                 box.location.height =
125                                 anchor.height * std::exp(box.location.height / mMeta.GetBoxDecodeInfo().GetAnchorHscale());
126         } else {
127                 box.location.width = anchor.width * box.location.width / mMeta.GetBoxDecodeInfo().GetAnchorWscale();
128                 box.location.height = anchor.height * box.location.height / mMeta.GetBoxDecodeInfo().GetAnchorHscale();
129         }
130
131         return box;
132 }
133
134 int ObjectDecoder::decode()
135 {
136         LOGI("ENTER");
137
138         BoxesList boxList;
139         Boxes boxes;
140         int ret = MEDIA_VISION_ERROR_NONE;
141         int totalIdx = mNumberOfOjects;
142
143         for (int idx = 0; idx < totalIdx; ++idx) {
144                 if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_BYPASS) {
145                         float score = decodeScore(idx);
146                         if (score <= 0.0f)
147                                 continue;
148
149                         Box box = decodeBox(idx, score);
150                         mResultBoxes.push_back(box);
151                 } else if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_SSD_ANCHOR) {
152                         int anchorIdx = -1;
153
154                         boxes.clear();
155                         for (auto &anchorBox : mMeta.GetBoxDecodeInfo().GetAnchorBoxAll()) {
156                                 anchorIdx++;
157
158                                 float score = decodeScore(anchorIdx * mNumberOfOjects + idx);
159
160                                 if (score <= 0.0f)
161                                         continue;
162
163                                 Box box = decodeBoxWithAnchor(idx, anchorIdx, score, anchorBox);
164                                 boxes.push_back(box);
165                         }
166                         boxList.push_back(boxes);
167                 }
168         }
169
170         if (mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_SSD_ANCHOR ||
171                 mMeta.GetBoxDecodingType() == INFERENCE_BOX_DECODING_TYPE_YOLO_ANCHOR)
172                 boxList.push_back(boxes);
173
174         if (!boxList.empty()) {
175                 PostProcess postProc;
176                 ret = postProc.Nms(boxList, mMeta.GetBoxDecodeInfo().GetNmsMode(),
177                                                    mMeta.GetBoxDecodeInfo().GetNmsIouThreshold(), mResultBoxes);
178                 if (ret != MEDIA_VISION_ERROR_NONE) {
179                         LOGE("Fail to non-maximum suppression[%d]", ret);
180                         return ret;
181                 }
182         } else {
183                 LOGW("boxlist empty!");
184         }
185
186         LOGI("LEAVE");
187
188         return ret;
189 }
190
191 Boxes &ObjectDecoder::getObjectAll()
192 {
193         return mResultBoxes;
194 }
195 }
196 }