Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_simplernms_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include "tests_common.hpp"
16
17 using namespace ::testing;
18 using namespace std;
19 using namespace mkldnn;
20
21
22 struct simplernms_test_params {
23     struct {
24         size_t n;
25         size_t c;
26         size_t h;
27         size_t w;
28     } in_cls;
29
30     struct {
31         size_t n;
32         size_t c;
33         size_t h;
34         size_t w;
35     } in_delta;
36
37     struct {
38         size_t n;
39         size_t c;
40     } in_info;
41
42     struct {
43         size_t n;
44         size_t c;
45     } out;
46
47     size_t minBoxSize;
48     size_t featStride;
49     size_t preNmsTopn;
50     size_t postNmsTopn;
51     float iouThreshold;
52
53     size_t num_prim_desc;
54
55     MKLDNNPlugin::impl_desc_type selectedType;
56
57     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
58 };
59
// Axis-aligned anchor box described by its inclusive corner coordinates.
struct anchor {
    float start_x;
    float start_y;
    float end_x;
    float end_y;
};
61
/**
 * Region of interest with inclusive corners (x0, y0) .. (x1, y1).
 */
template <typename data_t>
struct simpler_nms_roi_t
{
    data_t x0, y0, x1, y1;

    /// Returns max(v_min, min(v, v_max)), i.e. v limited to [v_min, v_max].
    constexpr static inline const data_t clamp_v(const data_t v, const data_t v_min, const data_t v_max)
    {
        return std::max<data_t>(v_min, std::min<data_t>(v, v_max));
    }

    /// Area using the +1 inclusive-pixel convention; a negative extent counts as 0.
    data_t area() const
    {
        const data_t height = std::max<data_t>(0, y1 - y0 + 1);
        const data_t width  = std::max<data_t>(0, x1 - x0 + 1);
        return height * width;
    }

    /// Overlap of this box and `other` (inverted, i.e. zero area, when disjoint).
    simpler_nms_roi_t intersect (simpler_nms_roi_t other) const
    {
        simpler_nms_roi_t overlap;
        overlap.x0 = (std::max)(x0, other.x0);
        overlap.y0 = (std::max)(y0, other.y0);
        overlap.x1 = (std::min)(x1, other.x1);
        overlap.y1 = (std::min)(y1, other.y1);
        return overlap;
    }

    /// Clamps every coordinate into the bounds box `other`.
    simpler_nms_roi_t clamp (simpler_nms_roi_t other) const
    {
        simpler_nms_roi_t bounded;
        bounded.x0 = clamp_v(x0, other.x0, other.x1);
        bounded.y0 = clamp_v(y0, other.y0, other.y1);
        bounded.x1 = clamp_v(x1, other.x0, other.x1);
        bounded.y1 = clamp_v(y1, other.y0, other.y1);
        return bounded;
    }
};
95
96 template <typename data_t>
97 struct simpler_nms_proposal_t { simpler_nms_roi_t<data_t> roi; data_t confidence; size_t ord; };
// Regression predicted for one anchor: centre shifts and log-scale size factors.
template <typename data_t>
struct simpler_nms_delta_t {
    data_t shift_x;
    data_t shift_y;
    data_t log_w;
    data_t log_h;
};
100
101 template <typename data_t>
102 inline simpler_nms_roi_t<data_t> simpler_nms_gen_bbox(
103         const anchor& box,
104         const simpler_nms_delta_t<data_t>& delta,
105         int anchor_shift_x,
106         int anchor_shift_y)
107 {
108     auto anchor_w = box.end_x - box.start_x + 1;
109     auto anchor_h = box.end_y - box.start_y + 1;
110     auto center_x = box.start_x + anchor_w * .5f;
111     auto center_y = box.start_y + anchor_h *.5f;
112
113     data_t pred_center_x = delta.shift_x * anchor_w + center_x + anchor_shift_x;
114     data_t pred_center_y = delta.shift_y * anchor_h + center_y + anchor_shift_y;
115     data_t half_pred_w = exp(delta.log_w) * anchor_w * .5f;
116     data_t half_pred_h = exp(delta.log_h) * anchor_h * .5f;
117
118     return { pred_center_x - half_pred_w,
119              pred_center_y - half_pred_h,
120              pred_center_x + half_pred_w,
121              pred_center_y + half_pred_h };
122 }
123 template <typename data_t>
124 inline void sort_and_keep_at_most_top_n(std::vector<simpler_nms_proposal_t<data_t>>& proposals, size_t top_n)
125 {
126     const auto cmp_fn = [](const simpler_nms_proposal_t<data_t>& a,
127                            const simpler_nms_proposal_t<data_t>& b)
128     {
129         return a.confidence > b.confidence || (a.confidence == b.confidence && a.ord > b.ord);
130     };
131
132     if (proposals.size() > top_n) {
133         std::partial_sort(proposals.begin(), proposals.begin() + top_n, proposals.end(), cmp_fn);
134         proposals.resize(top_n);
135     }
136     else {
137         std::sort(proposals.begin(), proposals.end(), cmp_fn);
138     }
139 }
140
141 template <typename data_t>
142 std::vector<simpler_nms_roi_t<data_t>> simpler_nms_perform_nms(const std::vector<simpler_nms_proposal_t<data_t>>& proposals,
143                                                        float iou_threshold, size_t top_n) {
144     //TODO(ruv): can I mark the 1st arg, proposals as const? ifndef DONT_PRECALC_AREA, i can
145     //TODO(ruv): is it better to do the precalc or not? since we need to fetch the floats from memory anyway for -
146     //           intersect calc, it's only a question of whether it's faster to do (f-f)*(f-f) or fetch another val
147 #define DONT_PRECALC_AREA
148
149 #ifndef DONT_PRECALC_AREA
150     std::vector<Dtype> areas;
151     areas.reserve(proposals.size());
152     std::transform(proposals.begin(), proposals.end(), areas.begin(), [](const simpler_nms_proposals_t>& v)
153                    {
154                        return v.roi.area();
155                    });
156 #endif
157
158     std::vector<simpler_nms_roi_t<data_t>> res;
159     res.reserve(top_n);
160 #ifdef DONT_PRECALC_AREA
161     for (const auto & prop : proposals) {
162         const auto bbox = prop.roi;
163         const data_t area = bbox.area();
164 #else
165         size_t proposal_count = proposals.size();
166         for (size_t proposalIndex = 0; proposalIndex < proposal_count; ++proposalIndex) {
167             const auto & bbox = proposals[proposalIndex].roi;
168 #endif
169
170         // For any realistic WL, this condition is true for all top_n values anyway
171         if (prop.confidence > 0) {
172             bool overlaps = std::any_of(res.begin(), res.end(), [&](const simpler_nms_roi_t<data_t>& res_bbox)
173             {
174                 data_t interArea = bbox.intersect(res_bbox).area();
175 #ifdef DONT_PRECALC_AREA
176                 data_t unionArea = res_bbox.area() + area - interArea;
177 #else
178                 data_t unionArea = res_bbox.area() + areas[proposalIndex] - interArea;
179 #endif
180                 return interArea > iou_threshold * unionArea;
181             });
182
183             if (! overlaps) {
184                 res.push_back(bbox);
185                 if (res.size() == top_n) break;
186             }
187         }
188     }
189
190     return res;
191 }
192
193 template <typename data_t>
194 void ref_simplernms(const InferenceEngine::TBlob<data_t> &src_cls, const InferenceEngine::TBlob<data_t> &src_delta, const InferenceEngine::TBlob<data_t> &src_info, InferenceEngine::TBlob<data_t> &dst_blob, simplernms_test_params prm) {
195     int anchors_num = 3 * 3;
196     data_t *anchors_ = new data_t[anchors_num * sizeof(anchor) / sizeof(float)];
197     const anchor* anchors = (anchor*)anchors_;
198
199     int H = src_cls.dims()[2];
200     int W = src_cls.dims()[3];
201
202     int SZ = H * W;
203
204     data_t* dst = dst_blob.data();
205
206     const data_t* cls_scores = src_cls.readOnly();
207     const data_t* delta_pred = src_delta.readOnly();
208     const data_t* im_info = src_info.readOnly();
209
210     int IW = im_info[0];
211     int IH = im_info[1];
212     int IS = im_info[2];
213
214     int scaled_min_bbox_size = prm.minBoxSize * IS;
215
216     std::vector<simpler_nms_proposal_t<data_t>> sorted_proposals_confidence;
217
218     for (auto y = 0; y < H; ++y)
219     {
220         int anchor_shift_y = y * prm.featStride;
221
222         for (auto x = 0; x < W; ++x) {
223             int anchor_shift_x = x * prm.featStride;
224             int location_index = y * W + x;
225
226             // we assume proposals are grouped by window location
227             for (int anchor_index = 0; anchor_index < anchors_num ; anchor_index++) {
228                 data_t dx0 = delta_pred[location_index + SZ * (anchor_index * 4 + 0)];
229                 data_t dy0 = delta_pred[location_index + SZ * (anchor_index * 4 + 1)];
230                 data_t dx1 = delta_pred[location_index + SZ * (anchor_index * 4 + 2)];
231                 data_t dy1 = delta_pred[location_index + SZ * (anchor_index * 4 + 3)];
232
233                 simpler_nms_delta_t<data_t> bbox_delta { dx0, dy0, dx1, dy1 };
234
235                 data_t proposal_confidence = cls_scores[location_index + SZ * (anchor_index + anchors_num * 1)];
236
237                 simpler_nms_roi_t<data_t> tmp_roi = simpler_nms_gen_bbox(anchors[anchor_index], bbox_delta, anchor_shift_x, anchor_shift_y);
238                 simpler_nms_roi_t<data_t> roi = tmp_roi.clamp({ 0, 0, data_t(IW - 1), data_t(IH - 1) });
239
240                 int bbox_w = roi.x1 - roi.x0 + 1;
241                 int bbox_h = roi.y1 - roi.y0 + 1;
242
243                 if (bbox_w >= scaled_min_bbox_size && bbox_h >= scaled_min_bbox_size) {
244                     simpler_nms_proposal_t<data_t> proposal { roi, proposal_confidence, sorted_proposals_confidence.size() };
245                     sorted_proposals_confidence.push_back(proposal);
246                 }
247             }
248         }
249     }
250
251     sort_and_keep_at_most_top_n(sorted_proposals_confidence, prm.preNmsTopn);
252     auto res = simpler_nms_perform_nms(sorted_proposals_confidence, prm.iouThreshold, prm.postNmsTopn);
253
254     size_t res_num_rois = res.size();
255
256     for (size_t i = 0; i < res_num_rois; ++i) {
257         dst[5 * i + 0] = 0;    // roi_batch_ind, always zero on test time
258         dst[5 * i + 1] = res[i].x0;
259         dst[5 * i + 2] = res[i].y0;
260         dst[5 * i + 3] = res[i].x1;
261         dst[5 * i + 4] = res[i].y1;
262     }
263
264     delete[] anchors_;
265 }
266
267 class MKLDNNGraphSimplerNMSTests: public TestsCommon,
268                                      public WithParamInterface<simplernms_test_params> {
269     std::string model_t = R"V0G0N(
270 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
271     <layers>
272         <layer name="in1" type="Input" precision="FP32" id="0">
273             <output>
274                 <port id="0">
275                     <dim>_INC_</dim>
276                     <dim>_ICC_</dim>
277                     <dim>_IHC_</dim>
278                     <dim>_IWC_</dim>
279                 </port>
280             </output>
281         </layer>
282         <layer name="in2" type="Input" precision="FP32" id="1">
283             <output>
284                 <port id="1">
285                     <dim>_IND_</dim>
286                     <dim>_ICD_</dim>
287                     <dim>_IHD_</dim>
288                     <dim>_IWD_</dim>
289                 </port>
290             </output>
291         </layer>
292         <layer name="in3" type="Input" precision="FP32" id="2">
293             <output>
294                 <port id="2">
295                     <dim>_INI_</dim>
296                     <dim>_ICI_</dim>
297                 </port>
298             </output>
299         </layer>
300         <layer name="proposal" id="3" type="SimplerNMS" precision="FP32">
301             <data cls_threshold="0.500000" max_num_proposals="300" iou_threshold="_IOU_THRESHOLD_"
302             min_bbox_size="_MIN_BOX_SIZE_" feat_stride="_FSRD_" pre_nms_topn="_PRENT_" post_nms_topn="_POSTNT_"
303             scale="8.000000,16.000000,32.000000"/>
304
305             <input>
306                 <port id="3">
307                     <dim>_INC_</dim>
308                     <dim>_ICC_</dim>
309                     <dim>_IHC_</dim>
310                     <dim>_IWC_</dim>
311                 </port>
312                 <port id="4">
313                     <dim>_IND_</dim>
314                     <dim>_ICD_</dim>
315                     <dim>_IHD_</dim>
316                     <dim>_IWD_</dim>
317                 </port>
318                 <port id="5">
319                     <dim>_INI_</dim>
320                     <dim>_ICI_</dim>
321                 </port>
322             </input>
323             <output>
324                 <port id="6">
325                     <dim>_ON_</dim>
326                     <dim>_OC_</dim>
327                 </port>
328             </output>
329         </layer>
330     </layers>
331     <edges>
332         <edge from-layer="0" from-port="0" to-layer="3" to-port="3"/>
333         <edge from-layer="1" from-port="1" to-layer="3" to-port="4"/>
334         <edge from-layer="2" from-port="2" to-layer="3" to-port="5"/>
335     </edges>
336 </Net>
337 )V0G0N";
338
339     std::string getModel(simplernms_test_params p) {
340         std::string model = model_t;
341
342         REPLACE_WITH_NUM(model, "_IWC_", p.in_cls.w);
343         REPLACE_WITH_NUM(model, "_IHC_", p.in_cls.h);
344         REPLACE_WITH_NUM(model, "_ICC_", p.in_cls.c);
345         REPLACE_WITH_NUM(model, "_INC_", p.in_cls.n);
346
347         REPLACE_WITH_NUM(model, "_IWD_", p.in_delta.w);
348         REPLACE_WITH_NUM(model, "_IHD_", p.in_delta.h);
349         REPLACE_WITH_NUM(model, "_ICD_", p.in_delta.c);
350         REPLACE_WITH_NUM(model, "_IND_", p.in_delta.n);
351
352         REPLACE_WITH_NUM(model, "_ICI_", p.in_info.c);
353         REPLACE_WITH_NUM(model, "_INI_", p.in_info.n);
354
355         REPLACE_WITH_NUM(model, "_OC_", p.out.c);
356         REPLACE_WITH_NUM(model, "_ON_", p.out.n);
357
358         REPLACE_WITH_NUM(model, "_MIN_BOX_SIZE_", p.minBoxSize);
359         REPLACE_WITH_NUM(model, "_FSRD_", p.featStride);
360         REPLACE_WITH_NUM(model, "_PRENT_", p.preNmsTopn);
361         REPLACE_WITH_NUM(model, "_POSTNT_", p.postNmsTopn);
362         REPLACE_WITH_NUM(model, "_IOU_THRESHOLD_", p.iouThreshold);
363
364         return model;
365     }
366
367 protected:
368     virtual void TearDown() {
369     }
370
371     virtual void SetUp() {
372         try {
373             TestsCommon::SetUp();
374             simplernms_test_params p = ::testing::WithParamInterface<simplernms_test_params>::GetParam();
375             std::string model = getModel(p);
376
377             InferenceEngine::CNNNetReader net_reader;
378             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
379
380             MKLDNNGraphTestClass graph;
381             graph.CreateGraph(net_reader.getNetwork());
382             auto& nodes = graph.getNodes();
383             for (int i = 0; i < nodes.size(); i++) {
384                 if (nodes[i]->getType() == MKLDNNPlugin::SimplerNMS) {
385                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
386                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
387                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
388                     }
389                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
390                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
391                 }
392             }
393             InferenceEngine::SizeVector dims_src_cls = {p.in_cls.n, p.in_cls.c, p.in_cls.h, p.in_cls.w};
394
395             InferenceEngine::Blob::Ptr src_cls = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src_cls);
396             src_cls->allocate();
397             fill_data(src_cls->buffer(), src_cls->size());
398
399             InferenceEngine::TBlob<float>* srcClsPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_cls.get());
400
401             if (srcClsPtr == nullptr)
402                 FAIL() << "Cannot cast blob to TBlob<float>.";
403
404             InferenceEngine::SizeVector dims_delta = {p.in_delta.n, p.in_delta.c, p.in_delta.h, p.in_delta.w};
405
406             InferenceEngine::Blob::Ptr src_delta = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_delta);
407             src_delta->allocate();
408             fill_data(src_delta->buffer(), src_delta->size());
409
410             InferenceEngine::TBlob<float>* srcDeltaPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_delta.get());
411
412             if (srcDeltaPtr == nullptr)
413                 FAIL() << "Cannot cast blob to TBlob<float>.";
414
415             InferenceEngine::SizeVector dims_info = {p.in_info.n, p.in_info.c};
416
417             InferenceEngine::Blob::Ptr src_info = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NC, dims_info);
418             src_info->allocate();
419             fill_data(src_info->buffer(), src_info->size());
420             float * data_info = src_info->buffer();
421             data_info[0] = 20;
422             data_info[1] = 20;
423             data_info[2] = 3;
424
425             InferenceEngine::TBlob<float>* srcInfoPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_info.get());
426
427             if (srcInfoPtr == nullptr)
428                 FAIL() << "Cannot cast blob to TBlob<float>.";
429
430             InferenceEngine::BlobMap srcs;
431             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src_cls));
432             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src_delta));
433             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src_info));
434
435             InferenceEngine::OutputsDataMap out;
436             out = net_reader.getNetwork().getOutputsInfo();
437             InferenceEngine::BlobMap outputBlobs;
438
439             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
440
441             InferenceEngine::TBlob<float>::Ptr output;
442             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
443             output->allocate();
444             outputBlobs[item.first] = output;
445
446             graph.Infer(srcs, outputBlobs);
447
448             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
449             dst_ref.allocate();
450
451             ref_simplernms(*srcClsPtr, *srcDeltaPtr, *srcInfoPtr, dst_ref, p);
452
453             compare(*output, dst_ref);
454         } catch (const InferenceEngine::details::InferenceEngineException &e) {
455             FAIL() << e.what();
456         }
457     }
458 };
459
460 TEST_P(MKLDNNGraphSimplerNMSTests, TestsSimplerNMS) {}
461
462
463 INSTANTIATE_TEST_CASE_P(
464         DISABLED_TestsSimplerNMS, MKLDNNGraphSimplerNMSTests,
465         ::testing::Values(
466                 simplernms_test_params{{1, 18, 39, 64}, {1, 36, 39, 64}, {1, 3}, {150, 5}, 16, 16, 6000, 150, 0.7f, 1,
467                                        MKLDNNPlugin::impl_desc_type::ref, {
468                                          [](MKLDNNPlugin::PrimitiveDescInfo impl) {
469                                              ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
470                                              ASSERT_EQ(3, impl.getConfig().inConfs.size());
471                                              ASSERT_EQ(1, impl.getConfig().outConfs.size());
472                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
473                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(1).desc.getLayout());
474                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(2).desc.getLayout());
475                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
476                                          }
477                                  }}));