Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_simplernms_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
14
15 using namespace ::testing;
16 using namespace std;
17 using namespace mkldnn;
18
19
20 struct simplernms_test_params {
21     struct {
22         size_t n;
23         size_t c;
24         size_t h;
25         size_t w;
26     } in_cls;
27
28     struct {
29         size_t n;
30         size_t c;
31         size_t h;
32         size_t w;
33     } in_delta;
34
35     struct {
36         size_t n;
37         size_t c;
38     } in_info;
39
40     struct {
41         size_t n;
42         size_t c;
43     } out;
44
45     size_t minBoxSize;
46     size_t featStride;
47     size_t preNmsTopn;
48     size_t postNmsTopn;
49     float iouThreshold;
50
51     size_t num_prim_desc;
52
53     MKLDNNPlugin::impl_desc_type selectedType;
54
55     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
56 };
57
// Axis-aligned anchor box described by its start and end corner coordinates.
struct anchor {
    float start_x;
    float start_y;
    float end_x;
    float end_y;
};
59
/**
 * Axis-aligned region of interest with inclusive corner coordinates
 * (x0, y0) - (x1, y1).
 */
template <typename data_t>
struct simpler_nms_roi_t
{
    data_t x0, y0, x1, y1;

    // Limits v to the closed interval [v_min, v_max].
    constexpr static inline const data_t clamp_v(const data_t v, const data_t v_min, const data_t v_max)
    {
        return (std::max)(v_min, (std::min)(v, v_max));
    }

    // Area with the "+1" inclusive-coordinate convention; a negative extent
    // along either axis counts as zero.
    data_t area() const
    {
        const data_t extent_y = std::max<data_t>(0, y1 - y0 + 1);
        const data_t extent_x = std::max<data_t>(0, x1 - x0 + 1);
        return extent_y * extent_x;
    }

    // Overlap of this box with `other`. The result may be "inverted"
    // (x1 < x0 or y1 < y0) when the boxes are disjoint; area() then yields 0.
    simpler_nms_roi_t intersect (simpler_nms_roi_t other) const
    {
        simpler_nms_roi_t overlap {
                (std::max)(x0, other.x0),
                (std::max)(y0, other.y0),
                (std::min)(x1, other.x1),
                (std::min)(y1, other.y1)
        };
        return overlap;
    }

    // This box with every coordinate limited to the bounds given by `other`.
    simpler_nms_roi_t clamp (simpler_nms_roi_t other) const
    {
        simpler_nms_roi_t bounded {
                clamp_v(x0, other.x0, other.x1),
                clamp_v(y0, other.y0, other.y1),
                clamp_v(x1, other.x0, other.x1),
                clamp_v(y1, other.y0, other.y1)
        };
        return bounded;
    }
};
93
94 template <typename data_t>
95 struct simpler_nms_proposal_t { simpler_nms_roi_t<data_t> roi; data_t confidence; size_t ord; };
// Box regression output for one anchor: centre shifts (scaled by the anchor
// size) and log-scale factors for width and height.
template <typename data_t>
struct simpler_nms_delta_t {
    data_t shift_x;
    data_t shift_y;
    data_t log_w;
    data_t log_h;
};
98
99 template <typename data_t>
100 inline simpler_nms_roi_t<data_t> simpler_nms_gen_bbox(
101         const anchor& box,
102         const simpler_nms_delta_t<data_t>& delta,
103         int anchor_shift_x,
104         int anchor_shift_y)
105 {
106     auto anchor_w = box.end_x - box.start_x + 1;
107     auto anchor_h = box.end_y - box.start_y + 1;
108     auto center_x = box.start_x + anchor_w * .5f;
109     auto center_y = box.start_y + anchor_h *.5f;
110
111     data_t pred_center_x = delta.shift_x * anchor_w + center_x + anchor_shift_x;
112     data_t pred_center_y = delta.shift_y * anchor_h + center_y + anchor_shift_y;
113     data_t half_pred_w = exp(delta.log_w) * anchor_w * .5f;
114     data_t half_pred_h = exp(delta.log_h) * anchor_h * .5f;
115
116     return { pred_center_x - half_pred_w,
117              pred_center_y - half_pred_h,
118              pred_center_x + half_pred_w,
119              pred_center_y + half_pred_h };
120 }
121 template <typename data_t>
122 inline void sort_and_keep_at_most_top_n(std::vector<simpler_nms_proposal_t<data_t>>& proposals, size_t top_n)
123 {
124     const auto cmp_fn = [](const simpler_nms_proposal_t<data_t>& a,
125                            const simpler_nms_proposal_t<data_t>& b)
126     {
127         return a.confidence > b.confidence || (a.confidence == b.confidence && a.ord > b.ord);
128     };
129
130     if (proposals.size() > top_n) {
131         std::partial_sort(proposals.begin(), proposals.begin() + top_n, proposals.end(), cmp_fn);
132         proposals.resize(top_n);
133     }
134     else {
135         std::sort(proposals.begin(), proposals.end(), cmp_fn);
136     }
137 }
138
139 template <typename data_t>
140 std::vector<simpler_nms_roi_t<data_t>> simpler_nms_perform_nms(const std::vector<simpler_nms_proposal_t<data_t>>& proposals,
141                                                        float iou_threshold, size_t top_n) {
142     //TODO(ruv): can I mark the 1st arg, proposals as const? ifndef DONT_PRECALC_AREA, i can
143     //TODO(ruv): is it better to do the precalc or not? since we need to fetch the floats from memory anyway for -
144     //           intersect calc, it's only a question of whether it's faster to do (f-f)*(f-f) or fetch another val
145 #define DONT_PRECALC_AREA
146
147 #ifndef DONT_PRECALC_AREA
148     std::vector<Dtype> areas;
149     areas.reserve(proposals.size());
150     std::transform(proposals.begin(), proposals.end(), areas.begin(), [](const simpler_nms_proposals_t>& v)
151                    {
152                        return v.roi.area();
153                    });
154 #endif
155
156     std::vector<simpler_nms_roi_t<data_t>> res;
157     res.reserve(top_n);
158 #ifdef DONT_PRECALC_AREA
159     for (const auto & prop : proposals) {
160         const auto bbox = prop.roi;
161         const data_t area = bbox.area();
162 #else
163         size_t proposal_count = proposals.size();
164         for (size_t proposalIndex = 0; proposalIndex < proposal_count; ++proposalIndex) {
165             const auto & bbox = proposals[proposalIndex].roi;
166 #endif
167
168         // For any realistic WL, this condition is true for all top_n values anyway
169         if (prop.confidence > 0) {
170             bool overlaps = std::any_of(res.begin(), res.end(), [&](const simpler_nms_roi_t<data_t>& res_bbox)
171             {
172                 data_t interArea = bbox.intersect(res_bbox).area();
173 #ifdef DONT_PRECALC_AREA
174                 data_t unionArea = res_bbox.area() + area - interArea;
175 #else
176                 data_t unionArea = res_bbox.area() + areas[proposalIndex] - interArea;
177 #endif
178                 return interArea > iou_threshold * unionArea;
179             });
180
181             if (! overlaps) {
182                 res.push_back(bbox);
183                 if (res.size() == top_n) break;
184             }
185         }
186     }
187
188     return res;
189 }
190
191 template <typename data_t>
192 void ref_simplernms(const InferenceEngine::TBlob<data_t> &src_cls, const InferenceEngine::TBlob<data_t> &src_delta, const InferenceEngine::TBlob<data_t> &src_info, InferenceEngine::TBlob<data_t> &dst_blob, simplernms_test_params prm) {
193     int anchors_num = 3 * 3;
194     data_t *anchors_ = new data_t[anchors_num * sizeof(anchor) / sizeof(float)];
195     const anchor* anchors = (anchor*)anchors_;
196
197     int H = src_cls.dims()[2];
198     int W = src_cls.dims()[3];
199
200     int SZ = H * W;
201
202     data_t* dst = dst_blob.data();
203
204     const data_t* cls_scores = src_cls.readOnly();
205     const data_t* delta_pred = src_delta.readOnly();
206     const data_t* im_info = src_info.readOnly();
207
208     int IW = im_info[0];
209     int IH = im_info[1];
210     int IS = im_info[2];
211
212     int scaled_min_bbox_size = prm.minBoxSize * IS;
213
214     std::vector<simpler_nms_proposal_t<data_t>> sorted_proposals_confidence;
215
216     for (auto y = 0; y < H; ++y)
217     {
218         int anchor_shift_y = y * prm.featStride;
219
220         for (auto x = 0; x < W; ++x) {
221             int anchor_shift_x = x * prm.featStride;
222             int location_index = y * W + x;
223
224             // we assume proposals are grouped by window location
225             for (int anchor_index = 0; anchor_index < anchors_num ; anchor_index++) {
226                 data_t dx0 = delta_pred[location_index + SZ * (anchor_index * 4 + 0)];
227                 data_t dy0 = delta_pred[location_index + SZ * (anchor_index * 4 + 1)];
228                 data_t dx1 = delta_pred[location_index + SZ * (anchor_index * 4 + 2)];
229                 data_t dy1 = delta_pred[location_index + SZ * (anchor_index * 4 + 3)];
230
231                 simpler_nms_delta_t<data_t> bbox_delta { dx0, dy0, dx1, dy1 };
232
233                 data_t proposal_confidence = cls_scores[location_index + SZ * (anchor_index + anchors_num * 1)];
234
235                 simpler_nms_roi_t<data_t> tmp_roi = simpler_nms_gen_bbox(anchors[anchor_index], bbox_delta, anchor_shift_x, anchor_shift_y);
236                 simpler_nms_roi_t<data_t> roi = tmp_roi.clamp({ 0, 0, data_t(IW - 1), data_t(IH - 1) });
237
238                 int bbox_w = roi.x1 - roi.x0 + 1;
239                 int bbox_h = roi.y1 - roi.y0 + 1;
240
241                 if (bbox_w >= scaled_min_bbox_size && bbox_h >= scaled_min_bbox_size) {
242                     simpler_nms_proposal_t<data_t> proposal { roi, proposal_confidence, sorted_proposals_confidence.size() };
243                     sorted_proposals_confidence.push_back(proposal);
244                 }
245             }
246         }
247     }
248
249     sort_and_keep_at_most_top_n(sorted_proposals_confidence, prm.preNmsTopn);
250     auto res = simpler_nms_perform_nms(sorted_proposals_confidence, prm.iouThreshold, prm.postNmsTopn);
251
252     size_t res_num_rois = res.size();
253
254     for (size_t i = 0; i < res_num_rois; ++i) {
255         dst[5 * i + 0] = 0;    // roi_batch_ind, always zero on test time
256         dst[5 * i + 1] = res[i].x0;
257         dst[5 * i + 2] = res[i].y0;
258         dst[5 * i + 3] = res[i].x1;
259         dst[5 * i + 4] = res[i].y1;
260     }
261
262     delete[] anchors_;
263 }
264
265 class MKLDNNGraphSimplerNMSTests: public TestsCommon,
266                                      public WithParamInterface<simplernms_test_params> {
267     std::string model_t = R"V0G0N(
268 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
269     <layers>
270         <layer name="in1" type="Input" precision="FP32" id="0">
271             <output>
272                 <port id="0">
273                     <dim>_INC_</dim>
274                     <dim>_ICC_</dim>
275                     <dim>_IHC_</dim>
276                     <dim>_IWC_</dim>
277                 </port>
278             </output>
279         </layer>
280         <layer name="in2" type="Input" precision="FP32" id="1">
281             <output>
282                 <port id="1">
283                     <dim>_IND_</dim>
284                     <dim>_ICD_</dim>
285                     <dim>_IHD_</dim>
286                     <dim>_IWD_</dim>
287                 </port>
288             </output>
289         </layer>
290         <layer name="in3" type="Input" precision="FP32" id="2">
291             <output>
292                 <port id="2">
293                     <dim>_INI_</dim>
294                     <dim>_ICI_</dim>
295                 </port>
296             </output>
297         </layer>
298         <layer name="proposal" id="3" type="SimplerNMS" precision="FP32">
299             <data cls_threshold="0.500000" max_num_proposals="300" iou_threshold="_IOU_THRESHOLD_"
300             min_bbox_size="_MIN_BOX_SIZE_" feat_stride="_FSRD_" pre_nms_topn="_PRENT_" post_nms_topn="_POSTNT_"
301             scale="8.000000,16.000000,32.000000"/>
302
303             <input>
304                 <port id="3">
305                     <dim>_INC_</dim>
306                     <dim>_ICC_</dim>
307                     <dim>_IHC_</dim>
308                     <dim>_IWC_</dim>
309                 </port>
310                 <port id="4">
311                     <dim>_IND_</dim>
312                     <dim>_ICD_</dim>
313                     <dim>_IHD_</dim>
314                     <dim>_IWD_</dim>
315                 </port>
316                 <port id="5">
317                     <dim>_INI_</dim>
318                     <dim>_ICI_</dim>
319                 </port>
320             </input>
321             <output>
322                 <port id="6">
323                     <dim>_ON_</dim>
324                     <dim>_OC_</dim>
325                 </port>
326             </output>
327         </layer>
328     </layers>
329     <edges>
330         <edge from-layer="0" from-port="0" to-layer="3" to-port="3"/>
331         <edge from-layer="1" from-port="1" to-layer="3" to-port="4"/>
332         <edge from-layer="2" from-port="2" to-layer="3" to-port="5"/>
333     </edges>
334 </Net>
335 )V0G0N";
336
337     std::string getModel(simplernms_test_params p) {
338         std::string model = model_t;
339
340         REPLACE_WITH_NUM(model, "_IWC_", p.in_cls.w);
341         REPLACE_WITH_NUM(model, "_IHC_", p.in_cls.h);
342         REPLACE_WITH_NUM(model, "_ICC_", p.in_cls.c);
343         REPLACE_WITH_NUM(model, "_INC_", p.in_cls.n);
344
345         REPLACE_WITH_NUM(model, "_IWD_", p.in_delta.w);
346         REPLACE_WITH_NUM(model, "_IHD_", p.in_delta.h);
347         REPLACE_WITH_NUM(model, "_ICD_", p.in_delta.c);
348         REPLACE_WITH_NUM(model, "_IND_", p.in_delta.n);
349
350         REPLACE_WITH_NUM(model, "_ICI_", p.in_info.c);
351         REPLACE_WITH_NUM(model, "_INI_", p.in_info.n);
352
353         REPLACE_WITH_NUM(model, "_OC_", p.out.c);
354         REPLACE_WITH_NUM(model, "_ON_", p.out.n);
355
356         REPLACE_WITH_NUM(model, "_MIN_BOX_SIZE_", p.minBoxSize);
357         REPLACE_WITH_NUM(model, "_FSRD_", p.featStride);
358         REPLACE_WITH_NUM(model, "_PRENT_", p.preNmsTopn);
359         REPLACE_WITH_NUM(model, "_POSTNT_", p.postNmsTopn);
360         REPLACE_WITH_NUM(model, "_IOU_THRESHOLD_", p.iouThreshold);
361
362         return model;
363     }
364
365 protected:
366     virtual void TearDown() {
367     }
368
369     virtual void SetUp() {
370         try {
371             TestsCommon::SetUp();
372             simplernms_test_params p = ::testing::WithParamInterface<simplernms_test_params>::GetParam();
373             std::string model = getModel(p);
374
375             InferenceEngine::CNNNetReader net_reader;
376             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
377
378             MKLDNNGraphTestClass graph;
379             graph.CreateGraph(net_reader.getNetwork());
380             auto& nodes = graph.getNodes();
381             for (int i = 0; i < nodes.size(); i++) {
382                 if (nodes[i]->getType() == MKLDNNPlugin::SimplerNMS) {
383                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
384                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
385                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
386                     }
387                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
388                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
389                 }
390             }
391             InferenceEngine::SizeVector dims_src_cls = {p.in_cls.n, p.in_cls.c, p.in_cls.h, p.in_cls.w};
392
393             InferenceEngine::Blob::Ptr src_cls = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src_cls);
394             src_cls->allocate();
395             fill_data(src_cls->buffer(), src_cls->size());
396
397             InferenceEngine::TBlob<float>* srcClsPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_cls.get());
398
399             if (srcClsPtr == nullptr)
400                 FAIL() << "Cannot cast blob to TBlob<float>.";
401
402             InferenceEngine::SizeVector dims_delta = {p.in_delta.n, p.in_delta.c, p.in_delta.h, p.in_delta.w};
403
404             InferenceEngine::Blob::Ptr src_delta = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_delta);
405             src_delta->allocate();
406             fill_data(src_delta->buffer(), src_delta->size());
407
408             InferenceEngine::TBlob<float>* srcDeltaPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_delta.get());
409
410             if (srcDeltaPtr == nullptr)
411                 FAIL() << "Cannot cast blob to TBlob<float>.";
412
413             InferenceEngine::SizeVector dims_info = {p.in_info.n, p.in_info.c};
414
415             InferenceEngine::Blob::Ptr src_info = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NC, dims_info);
416             src_info->allocate();
417             fill_data(src_info->buffer(), src_info->size());
418             float * data_info = src_info->buffer();
419             data_info[0] = 20;
420             data_info[1] = 20;
421             data_info[2] = 3;
422
423             InferenceEngine::TBlob<float>* srcInfoPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_info.get());
424
425             if (srcInfoPtr == nullptr)
426                 FAIL() << "Cannot cast blob to TBlob<float>.";
427
428             InferenceEngine::BlobMap srcs;
429             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src_cls));
430             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src_delta));
431             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src_info));
432
433             InferenceEngine::OutputsDataMap out;
434             out = net_reader.getNetwork().getOutputsInfo();
435             InferenceEngine::BlobMap outputBlobs;
436
437             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
438
439             InferenceEngine::TBlob<float>::Ptr output;
440             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
441             output->allocate();
442             outputBlobs[item.first] = output;
443
444             graph.Infer(srcs, outputBlobs);
445
446             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
447             dst_ref.allocate();
448
449             ref_simplernms(*srcClsPtr, *srcDeltaPtr, *srcInfoPtr, dst_ref, p);
450
451             compare(*output, dst_ref);
452         } catch (const InferenceEngine::details::InferenceEngineException &e) {
453             FAIL() << e.what();
454         }
455     }
456 };
457
458 TEST_P(MKLDNNGraphSimplerNMSTests, TestsSimplerNMS) {}
459
460
461 INSTANTIATE_TEST_CASE_P(
462         DISABLED_TestsSimplerNMS, MKLDNNGraphSimplerNMSTests,
463         ::testing::Values(
464                 simplernms_test_params{{1, 18, 39, 64}, {1, 36, 39, 64}, {1, 3}, {150, 5}, 16, 16, 6000, 150, 0.7f, 1,
465                                        MKLDNNPlugin::impl_desc_type::ref, {
466                                          [](MKLDNNPlugin::PrimitiveDescInfo impl) {
467                                              ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
468                                              ASSERT_EQ(3, impl.getConfig().inConfs.size());
469                                              ASSERT_EQ(1, impl.getConfig().outConfs.size());
470                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
471                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(1).desc.getLayout());
472                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(2).desc.getLayout());
473                                              ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
474                                          }
475                                  }}));