Faster-RCNN models support
[platform/upstream/opencv.git] / modules / dnn / src / layers / proposal_layer.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2017, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 #include "../precomp.hpp"
8 #include "layers_common.hpp"
9
10 namespace cv { namespace dnn {
11
12 class ProposalLayerImpl : public ProposalLayer
13 {
14 public:
15     ProposalLayerImpl(const LayerParams& params)
16     {
17         setParamsFrom(params);
18
19         uint32_t featStride = params.get<uint32_t>("feat_stride", 16);
20         uint32_t baseSize = params.get<uint32_t>("base_size", 16);
21         // uint32_t minSize = params.get<uint32_t>("min_size", 16);
22         uint32_t keepTopBeforeNMS = params.get<uint32_t>("pre_nms_topn", 6000);
23         keepTopAfterNMS = params.get<uint32_t>("post_nms_topn", 300);
24         float nmsThreshold = params.get<float>("nms_thresh", 0.7);
25         DictValue ratios = params.get("ratio");
26         DictValue scales = params.get("scale");
27
28         {
29             LayerParams lp;
30             lp.set("step", featStride);
31             lp.set("flip", false);
32             lp.set("clip", false);
33             lp.set("normalized_bbox", false);
34
35             // Unused values.
36             float variance[] = {0.1f, 0.1f, 0.2f, 0.2f};
37             lp.set("variance", DictValue::arrayReal<float*>(&variance[0], 4));
38
39             // Compute widths and heights explicitly.
40             std::vector<float> widths, heights;
41             widths.reserve(ratios.size() * scales.size());
42             heights.reserve(ratios.size() * scales.size());
43             for (int i = 0; i < ratios.size(); ++i)
44             {
45                 float ratio = ratios.get<float>(i);
46                 for (int j = 0; j < scales.size(); ++j)
47                 {
48                     float scale = scales.get<float>(j);
49                     float width = std::floor(baseSize / sqrt(ratio) + 0.5f);
50                     float height = std::floor(width * ratio + 0.5f);
51                     widths.push_back(scale * width);
52                     heights.push_back(scale * height);
53                 }
54             }
55             lp.set("width", DictValue::arrayReal<float*>(&widths[0], widths.size()));
56             lp.set("height", DictValue::arrayReal<float*>(&heights[0], heights.size()));
57
58             priorBoxLayer = PriorBoxLayer::create(lp);
59         }
60         {
61             int order[] = {0, 2, 3, 1};
62             LayerParams lp;
63             lp.set("order", DictValue::arrayInt<int*>(&order[0], 4));
64
65             deltasPermute = PermuteLayer::create(lp);
66             scoresPermute = PermuteLayer::create(lp);
67         }
68         {
69             LayerParams lp;
70             lp.set("code_type", "CENTER_SIZE");
71             lp.set("num_classes", 1);
72             lp.set("share_location", true);
73             lp.set("background_label_id", 1);  // We won't pass background scores so set it out of range [0, num_classes)
74             lp.set("variance_encoded_in_target", true);
75             lp.set("keep_top_k", keepTopAfterNMS);
76             lp.set("top_k", keepTopBeforeNMS);
77             lp.set("nms_threshold", nmsThreshold);
78             lp.set("normalized_bbox", false);
79             lp.set("clip", true);
80
81             detectionOutputLayer = DetectionOutputLayer::create(lp);
82         }
83     }
84
85     bool getMemoryShapes(const std::vector<MatShape> &inputs,
86                          const int requiredOutputs,
87                          std::vector<MatShape> &outputs,
88                          std::vector<MatShape> &internals) const
89     {
90         // We need to allocate the following blobs:
91         // - output priors from PriorBoxLayer
92         // - permuted priors
93         // - permuted scores
94         CV_Assert(inputs.size() == 3);
95
96         const MatShape& scores = inputs[0];
97         const MatShape& bboxDeltas = inputs[1];
98
99         std::vector<MatShape> layerInputs, layerOutputs, layerInternals;
100
101         // Prior boxes layer.
102         layerInputs.assign(1, scores);
103         priorBoxLayer->getMemoryShapes(layerInputs, 1, layerOutputs, layerInternals);
104         CV_Assert(layerOutputs.size() == 1);
105         CV_Assert(layerInternals.empty());
106         internals.push_back(layerOutputs[0]);
107
108         // Scores permute layer.
109         CV_Assert(scores.size() == 4);
110         MatShape objectScores = scores;
111         CV_Assert((scores[1] & 1) == 0);  // Number of channels is even.
112         objectScores[1] /= 2;
113         layerInputs.assign(1, objectScores);
114         scoresPermute->getMemoryShapes(layerInputs, 1, layerOutputs, layerInternals);
115         CV_Assert(layerOutputs.size() == 1);
116         CV_Assert(layerInternals.empty());
117         internals.push_back(layerOutputs[0]);
118
119         // BBox predictions permute layer.
120         layerInputs.assign(1, bboxDeltas);
121         deltasPermute->getMemoryShapes(layerInputs, 1, layerOutputs, layerInternals);
122         CV_Assert(layerOutputs.size() == 1);
123         CV_Assert(layerInternals.empty());
124         internals.push_back(layerOutputs[0]);
125
126         outputs.resize(1, shape(keepTopAfterNMS, 5));
127         return false;
128     }
129
130     void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
131     {
132         std::vector<Mat*> layerInputs;
133         std::vector<Mat> layerOutputs;
134
135         // Scores permute layer.
136         Mat scores = getObjectScores(*inputs[0]);
137         layerInputs.assign(1, &scores);
138         layerOutputs.assign(1, Mat(shape(scores.size[0], scores.size[2],
139                                          scores.size[3], scores.size[1]), CV_32FC1));
140         scoresPermute->finalize(layerInputs, layerOutputs);
141
142         // BBox predictions permute layer.
143         Mat* bboxDeltas = inputs[1];
144         CV_Assert(bboxDeltas->dims == 4);
145         layerInputs.assign(1, bboxDeltas);
146         layerOutputs.assign(1, Mat(shape(bboxDeltas->size[0], bboxDeltas->size[2],
147                                          bboxDeltas->size[3], bboxDeltas->size[1]), CV_32FC1));
148         deltasPermute->finalize(layerInputs, layerOutputs);
149     }
150
151     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr)
152     {
153         CV_TRACE_FUNCTION();
154         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
155
156         Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
157     }
158
159     void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
160     {
161         CV_TRACE_FUNCTION();
162         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
163
164         CV_Assert(inputs.size() == 3);
165         CV_Assert(internals.size() == 3);
166         const Mat& scores = *inputs[0];
167         const Mat& bboxDeltas = *inputs[1];
168         const Mat& imInfo = *inputs[2];
169         Mat& priorBoxes = internals[0];
170         Mat& permuttedScores = internals[1];
171         Mat& permuttedDeltas = internals[2];
172
173         CV_Assert(imInfo.total() >= 2);
174         // We've chosen the smallest data type because we need just a shape from it.
175         fakeImageBlob.create(shape(1, 1, imInfo.at<float>(0), imInfo.at<float>(1)), CV_8UC1);
176
177         // Generate prior boxes.
178         std::vector<Mat> layerInputs(2), layerOutputs(1, priorBoxes);
179         layerInputs[0] = scores;
180         layerInputs[1] = fakeImageBlob;
181         priorBoxLayer->forward(layerInputs, layerOutputs, internals);
182
183         // Permute scores.
184         layerInputs.assign(1, getObjectScores(scores));
185         layerOutputs.assign(1, permuttedScores);
186         scoresPermute->forward(layerInputs, layerOutputs, internals);
187
188         // Permute deltas.
189         layerInputs.assign(1, bboxDeltas);
190         layerOutputs.assign(1, permuttedDeltas);
191         deltasPermute->forward(layerInputs, layerOutputs, internals);
192
193         // Sort predictions by scores and apply NMS. DetectionOutputLayer allocates
194         // output internally because of different number of objects after NMS.
195         layerInputs.resize(4);
196         layerInputs[0] = permuttedDeltas;
197         layerInputs[1] = permuttedScores;
198         layerInputs[2] = priorBoxes;
199         layerInputs[3] = fakeImageBlob;
200
201         layerOutputs[0] = Mat();
202         detectionOutputLayer->forward(layerInputs, layerOutputs, internals);
203
204         // DetectionOutputLayer produces 1x1xNx7 output where N might be less or
205         // equal to keepTopAfterNMS. We fill the rest by zeros.
206         const int numDets = layerOutputs[0].total() / 7;
207         CV_Assert(numDets <= keepTopAfterNMS);
208
209         Mat src = layerOutputs[0].reshape(1, numDets).colRange(3, 7);
210         Mat dst = outputs[0].rowRange(0, numDets);
211         src.copyTo(dst.colRange(1, 5));
212         dst.col(0).setTo(0);  // First column are batch ids. Keep it zeros too.
213
214         if (numDets < keepTopAfterNMS)
215             outputs[0].rowRange(numDets, keepTopAfterNMS).setTo(0);
216     }
217
218 private:
219     // A first half of channels are background scores. We need only a second one.
220     static Mat getObjectScores(const Mat& m)
221     {
222         CV_Assert(m.dims == 4);
223         CV_Assert(m.size[0] == 1);
224         int channels = m.size[1];
225         CV_Assert((channels & 1) == 0);
226         return slice(m, Range::all(), Range(channels / 2, channels));
227     }
228
229     Ptr<PriorBoxLayer> priorBoxLayer;
230     Ptr<DetectionOutputLayer> detectionOutputLayer;
231
232     Ptr<PermuteLayer> deltasPermute;
233     Ptr<PermuteLayer> scoresPermute;
234     uint32_t keepTopAfterNMS;
235     Mat fakeImageBlob;
236 };
237
238
239 Ptr<ProposalLayer> ProposalLayer::create(const LayerParams& params)
240 {
241     return Ptr<ProposalLayer>(new ProposalLayerImpl(params));
242 }
243
244 }  // namespace dnn
245 }  // namespace cv