1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
15 using namespace ::testing;
17 using namespace mkldnn;
// Parameter set for one SimplerNMS test case.
// NOTE(review): only two members are visible in this excerpt. Code further down also reads
// in_cls, in_delta, in_info, out, minBoxSize, featStride, preNmsTopn, postNmsTopn,
// iouThreshold and num_prim_desc from this struct; their declarations are not shown here.
20 struct simplernms_test_params {
// Implementation type that the selected primitive descriptor of the SimplerNMS node must report.
53 MKLDNNPlugin::impl_desc_type selectedType;
// Per-descriptor checks: the fixture calls comp[j] with the j-th supported primitive descriptor.
55 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
/**
 * @brief Anchor box described by a start corner and an end corner.
 *
 * simpler_nms_gen_bbox derives the anchor size as `end - start + 1` on each axis.
 */
struct anchor {
    float start_x;
    float start_y;
    float end_x;
    float end_y;
};
/**
 * @brief Axis-aligned region of interest stored as two corners.
 *
 * Extents are measured as `x1 - x0 + 1` and `y1 - y0 + 1`, so a box whose two corners
 * coincide has area 1.
 *
 * @tparam data_t arithmetic type of the coordinates
 */
template <typename data_t>
struct simpler_nms_roi_t
{
    data_t x0, y0, x1, y1;

    /**
     * @brief Limits @p v to the interval [v_min, v_max].
     * @return v_min if v is below it, v_max if v is above it, otherwise v.
     *         When v_min > v_max the result is v_min.
     */
    constexpr static data_t clamp_v(const data_t v, const data_t v_min, const data_t v_max)
    {
        // The parenthesised names stop function-like min/max macros from expanding here.
        return (std::max)(v_min, (std::min)(v, v_max));
    }

    /** @brief Area of the box; a negative extent along either axis is counted as zero. */
    data_t area() const { return std::max<data_t>(0, y1 - y0 + 1) * std::max<data_t>(0, x1 - x0 + 1); }

    /**
     * @brief Overlap of this box with @p other.
     * @return Box built from the larger of the two start corners and the smaller of the two
     *         end corners. For disjoint boxes the result has x1 < x0 or y1 < y0, which
     *         area() reports as zero.
     */
    simpler_nms_roi_t intersect(const simpler_nms_roi_t& other) const
    {
        return
        {
            (std::max)(x0, other.x0),
            (std::max)(y0, other.y0),
            (std::min)(x1, other.x1),
            (std::min)(y1, other.y1)
        };
    }

    /**
     * @brief Clamps every coordinate of this box into the range spanned by @p other.
     * @return Box whose x coordinates lie in [other.x0, other.x1] and whose y coordinates
     *         lie in [other.y0, other.y1].
     */
    simpler_nms_roi_t clamp(const simpler_nms_roi_t& other) const
    {
        return
        {
            clamp_v(x0, other.x0, other.x1),
            clamp_v(y0, other.y0, other.y1),
            clamp_v(x1, other.x0, other.x1),
            clamp_v(y1, other.y0, other.y1)
        };
    }
};
94 template <typename data_t>
95 struct simpler_nms_proposal_t { simpler_nms_roi_t<data_t> roi; data_t confidence; size_t ord; };
/**
 * @brief Box regression values for one anchor.
 *
 * simpler_nms_gen_bbox scales the shifts by the anchor size and passes the log values
 * through exp() to obtain the predicted width and height.
 *
 * @tparam data_t arithmetic type of the values
 */
template <typename data_t>
struct simpler_nms_delta_t {
    data_t shift_x;
    data_t shift_y;
    data_t log_w;
    data_t log_h;
};
99 template <typename data_t>
100 inline simpler_nms_roi_t<data_t> simpler_nms_gen_bbox(
102 const simpler_nms_delta_t<data_t>& delta,
106 auto anchor_w = box.end_x - box.start_x + 1;
107 auto anchor_h = box.end_y - box.start_y + 1;
108 auto center_x = box.start_x + anchor_w * .5f;
109 auto center_y = box.start_y + anchor_h *.5f;
111 data_t pred_center_x = delta.shift_x * anchor_w + center_x + anchor_shift_x;
112 data_t pred_center_y = delta.shift_y * anchor_h + center_y + anchor_shift_y;
113 data_t half_pred_w = exp(delta.log_w) * anchor_w * .5f;
114 data_t half_pred_h = exp(delta.log_h) * anchor_h * .5f;
116 return { pred_center_x - half_pred_w,
117 pred_center_y - half_pred_h,
118 pred_center_x + half_pred_w,
119 pred_center_y + half_pred_h };
121 template <typename data_t>
122 inline void sort_and_keep_at_most_top_n(std::vector<simpler_nms_proposal_t<data_t>>& proposals, size_t top_n)
124 const auto cmp_fn = [](const simpler_nms_proposal_t<data_t>& a,
125 const simpler_nms_proposal_t<data_t>& b)
127 return a.confidence > b.confidence || (a.confidence == b.confidence && a.ord > b.ord);
130 if (proposals.size() > top_n) {
131 std::partial_sort(proposals.begin(), proposals.begin() + top_n, proposals.end(), cmp_fn);
132 proposals.resize(top_n);
135 std::sort(proposals.begin(), proposals.end(), cmp_fn);
139 template <typename data_t>
140 std::vector<simpler_nms_roi_t<data_t>> simpler_nms_perform_nms(const std::vector<simpler_nms_proposal_t<data_t>>& proposals,
141 float iou_threshold, size_t top_n) {
142 //TODO(ruv): can I mark the 1st arg, proposals as const? ifndef DONT_PRECALC_AREA, i can
143 //TODO(ruv): is it better to do the precalc or not? since we need to fetch the floats from memory anyway for -
144 // intersect calc, it's only a question of whether it's faster to do (f-f)*(f-f) or fetch another val
145 #define DONT_PRECALC_AREA
147 #ifndef DONT_PRECALC_AREA
148 std::vector<Dtype> areas;
149 areas.reserve(proposals.size());
150 std::transform(proposals.begin(), proposals.end(), areas.begin(), [](const simpler_nms_proposals_t>& v)
156 std::vector<simpler_nms_roi_t<data_t>> res;
158 #ifdef DONT_PRECALC_AREA
159 for (const auto & prop : proposals) {
160 const auto bbox = prop.roi;
161 const data_t area = bbox.area();
163 size_t proposal_count = proposals.size();
164 for (size_t proposalIndex = 0; proposalIndex < proposal_count; ++proposalIndex) {
165 const auto & bbox = proposals[proposalIndex].roi;
168 // For any realistic WL, this condition is true for all top_n values anyway
169 if (prop.confidence > 0) {
170 bool overlaps = std::any_of(res.begin(), res.end(), [&](const simpler_nms_roi_t<data_t>& res_bbox)
172 data_t interArea = bbox.intersect(res_bbox).area();
173 #ifdef DONT_PRECALC_AREA
174 data_t unionArea = res_bbox.area() + area - interArea;
176 data_t unionArea = res_bbox.area() + areas[proposalIndex] - interArea;
178 return interArea > iou_threshold * unionArea;
183 if (res.size() == top_n) break;
// Reference implementation whose result the test compares with the plugin output.
// For every spatial location and each of the anchors_num anchors it decodes a box from
// src_delta, clamps it to [0, IW - 1] x [0, IH - 1], drops boxes smaller than the scaled
// minimum size, ranks the rest with sort_and_keep_at_most_top_n (prm.preNmsTopn), runs
// simpler_nms_perform_nms (prm.iouThreshold, prm.postNmsTopn) and writes every kept box to
// dst_blob as five values: 0, x0, y0, x1, y1.
// NOTE(review): SZ, IW, IH and IS are used below but their definitions are not visible in
// this excerpt - presumably SZ = H * W and IW/IH/IS are read from im_info; confirm.
191 template <typename data_t>
192 void ref_simplernms(const InferenceEngine::TBlob<data_t> &src_cls, const InferenceEngine::TBlob<data_t> &src_delta, const InferenceEngine::TBlob<data_t> &src_info, InferenceEngine::TBlob<data_t> &dst_blob, simplernms_test_params prm) {
193 int anchors_num = 3 * 3;
// NOTE(review): this buffer comes from new[] and no visible line writes to it before
// anchors[anchor_index] is read below, nor releases it - confirm against the full file.
194 data_t *anchors_ = new data_t[anchors_num * sizeof(anchor) / sizeof(float)];
195 const anchor* anchors = (anchor*)anchors_;
// Spatial size is taken from dimensions 2 and 3 of the class-score blob.
197 int H = src_cls.dims()[2];
198 int W = src_cls.dims()[3];
202 data_t* dst = dst_blob.data();
204 const data_t* cls_scores = src_cls.readOnly();
205 const data_t* delta_pred = src_delta.readOnly();
206 const data_t* im_info = src_info.readOnly();
// Minimum box side after scaling by IS; the product is stored as int.
212 int scaled_min_bbox_size = prm.minBoxSize * IS;
214 std::vector<simpler_nms_proposal_t<data_t>> sorted_proposals_confidence;
216 for (auto y = 0; y < H; ++y)
218 int anchor_shift_y = y * prm.featStride;
220 for (auto x = 0; x < W; ++x) {
221 int anchor_shift_x = x * prm.featStride;
222 int location_index = y * W + x;
224 // we assume proposals are grouped by window location
225 for (int anchor_index = 0; anchor_index < anchors_num ; anchor_index++) {
// Deltas are indexed as location_index + SZ * channel, four consecutive channels per anchor.
226 data_t dx0 = delta_pred[location_index + SZ * (anchor_index * 4 + 0)];
227 data_t dy0 = delta_pred[location_index + SZ * (anchor_index * 4 + 1)];
228 data_t dx1 = delta_pred[location_index + SZ * (anchor_index * 4 + 2)];
229 data_t dy1 = delta_pred[location_index + SZ * (anchor_index * 4 + 3)];
231 simpler_nms_delta_t<data_t> bbox_delta { dx0, dy0, dx1, dy1 };
// Score is read from channel anchor_index + anchors_num, i.e. the second group of anchors_num channels.
233 data_t proposal_confidence = cls_scores[location_index + SZ * (anchor_index + anchors_num * 1)];
235 simpler_nms_roi_t<data_t> tmp_roi = simpler_nms_gen_bbox(anchors[anchor_index], bbox_delta, anchor_shift_x, anchor_shift_y);
236 simpler_nms_roi_t<data_t> roi = tmp_roi.clamp({ 0, 0, data_t(IW - 1), data_t(IH - 1) });
// Box extent with both corners counted, converted to int.
238 int bbox_w = roi.x1 - roi.x0 + 1;
239 int bbox_h = roi.y1 - roi.y0 + 1;
241 if (bbox_w >= scaled_min_bbox_size && bbox_h >= scaled_min_bbox_size) {
// ord is the number of proposals stored so far, i.e. the insertion index.
242 simpler_nms_proposal_t<data_t> proposal { roi, proposal_confidence, sorted_proposals_confidence.size() };
243 sorted_proposals_confidence.push_back(proposal);
249 sort_and_keep_at_most_top_n(sorted_proposals_confidence, prm.preNmsTopn);
250 auto res = simpler_nms_perform_nms(sorted_proposals_confidence, prm.iouThreshold, prm.postNmsTopn);
252 size_t res_num_rois = res.size();
// Output layout: five consecutive values per kept ROI.
254 for (size_t i = 0; i < res_num_rois; ++i) {
255 dst[5 * i + 0] = 0; // roi_batch_ind, always zero on test time
256 dst[5 * i + 1] = res[i].x0;
257 dst[5 * i + 2] = res[i].y0;
258 dst[5 * i + 3] = res[i].x1;
259 dst[5 * i + 4] = res[i].y1;
// Parameterized fixture for the SimplerNMS layer: builds a network with three Input layers
// (in1, in2, in3) feeding one SimplerNMS layer, runs it through MKLDNNGraphTestClass and
// compares the result with ref_simplernms.
// NOTE(review): parts of the model text, the access specifiers and the opening of the try
// block are not visible in this excerpt.
265 class MKLDNNGraphSimplerNMSTests: public TestsCommon,
266 public WithParamInterface<simplernms_test_params> {
267 std::string model_t = R"V0G0N(
268 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
270 <layer name="in1" type="Input" precision="FP32" id="0">
280 <layer name="in2" type="Input" precision="FP32" id="1">
290 <layer name="in3" type="Input" precision="FP32" id="2">
298 <layer name="proposal" id="3" type="SimplerNMS" precision="FP32">
299 <data cls_threshold="0.500000" max_num_proposals="300" iou_threshold="_IOU_THRESHOLD_"
300 min_bbox_size="_MIN_BOX_SIZE_" feat_stride="_FSRD_" pre_nms_topn="_PRENT_" post_nms_topn="_POSTNT_"
301 scale="8.000000,16.000000,32.000000"/>
330 <edge from-layer="0" from-port="0" to-layer="3" to-port="3"/>
331 <edge from-layer="1" from-port="1" to-layer="3" to-port="4"/>
332 <edge from-layer="2" from-port="2" to-layer="3" to-port="5"/>
// Builds the model text for one parameter set: starts from model_t and applies
// REPLACE_WITH_NUM to each placeholder token (presumably substituting the number for the
// token - confirm against the macro) for the input/output dimensions and layer attributes.
337 std::string getModel(simplernms_test_params p) {
338 std::string model = model_t;
340 REPLACE_WITH_NUM(model, "_IWC_", p.in_cls.w);
341 REPLACE_WITH_NUM(model, "_IHC_", p.in_cls.h);
342 REPLACE_WITH_NUM(model, "_ICC_", p.in_cls.c);
343 REPLACE_WITH_NUM(model, "_INC_", p.in_cls.n);
345 REPLACE_WITH_NUM(model, "_IWD_", p.in_delta.w);
346 REPLACE_WITH_NUM(model, "_IHD_", p.in_delta.h);
347 REPLACE_WITH_NUM(model, "_ICD_", p.in_delta.c);
348 REPLACE_WITH_NUM(model, "_IND_", p.in_delta.n);
350 REPLACE_WITH_NUM(model, "_ICI_", p.in_info.c);
351 REPLACE_WITH_NUM(model, "_INI_", p.in_info.n);
353 REPLACE_WITH_NUM(model, "_OC_", p.out.c);
354 REPLACE_WITH_NUM(model, "_ON_", p.out.n);
356 REPLACE_WITH_NUM(model, "_MIN_BOX_SIZE_", p.minBoxSize);
357 REPLACE_WITH_NUM(model, "_FSRD_", p.featStride);
358 REPLACE_WITH_NUM(model, "_PRENT_", p.preNmsTopn);
359 REPLACE_WITH_NUM(model, "_POSTNT_", p.postNmsTopn);
360 REPLACE_WITH_NUM(model, "_IOU_THRESHOLD_", p.iouThreshold);
366 virtual void TearDown() {
// Runs the whole test case: reads the model, creates the graph, checks the primitive
// descriptors of the SimplerNMS node, runs inference and compares with the reference.
369 virtual void SetUp() {
371 TestsCommon::SetUp();
372 simplernms_test_params p = ::testing::WithParamInterface<simplernms_test_params>::GetParam();
373 std::string model = getModel(p);
375 InferenceEngine::CNNNetReader net_reader;
376 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
378 MKLDNNGraphTestClass graph;
379 graph.CreateGraph(net_reader.getNetwork());
380 auto& nodes = graph.getNodes();
// For every SimplerNMS node: the number of supported descriptors must equal
// p.num_prim_desc, each p.comp check is applied to the descriptor with the same index,
// and the selected descriptor must exist and have type p.selectedType.
381 for (int i = 0; i < nodes.size(); i++) {
382 if (nodes[i]->getType() == MKLDNNPlugin::SimplerNMS) {
383 ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
384 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
385 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
387 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
388 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
// Class-score input: FP32, NCHW, filled by fill_data.
// NOTE(review): unlike src_delta and src_info, no allocate() call for src_cls is visible
// in this excerpt before fill_data - confirm against the full file.
391 InferenceEngine::SizeVector dims_src_cls = {p.in_cls.n, p.in_cls.c, p.in_cls.h, p.in_cls.w};
393 InferenceEngine::Blob::Ptr src_cls = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src_cls);
395 fill_data(src_cls->buffer(), src_cls->size());
397 InferenceEngine::TBlob<float>* srcClsPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_cls.get());
399 if (srcClsPtr == nullptr)
400 FAIL() << "Cannot cast blob to TBlob<float>.";
// Box-delta input: FP32, NCHW, allocated and filled by fill_data.
402 InferenceEngine::SizeVector dims_delta = {p.in_delta.n, p.in_delta.c, p.in_delta.h, p.in_delta.w};
404 InferenceEngine::Blob::Ptr src_delta = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_delta);
405 src_delta->allocate();
406 fill_data(src_delta->buffer(), src_delta->size());
408 InferenceEngine::TBlob<float>* srcDeltaPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_delta.get());
410 if (srcDeltaPtr == nullptr)
411 FAIL() << "Cannot cast blob to TBlob<float>.";
// Image-info input: FP32, NC, allocated and filled by fill_data.
413 InferenceEngine::SizeVector dims_info = {p.in_info.n, p.in_info.c};
415 InferenceEngine::Blob::Ptr src_info = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NC, dims_info);
416 src_info->allocate();
417 fill_data(src_info->buffer(), src_info->size());
// NOTE(review): data_info is obtained here, but the lines that use it are not visible in this excerpt.
418 float * data_info = src_info->buffer();
423 InferenceEngine::TBlob<float>* srcInfoPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src_info.get());
425 if (srcInfoPtr == nullptr)
426 FAIL() << "Cannot cast blob to TBlob<float>.";
// Inputs are bound to the graph by layer name.
428 InferenceEngine::BlobMap srcs;
429 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src_cls));
430 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src_delta));
431 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src_info));
433 InferenceEngine::OutputsDataMap out;
434 out = net_reader.getNetwork().getOutputsInfo();
435 InferenceEngine::BlobMap outputBlobs;
// Only the first entry of the outputs map is used.
437 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
439 InferenceEngine::TBlob<float>::Ptr output;
440 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
442 outputBlobs[item.first] = output;
444 graph.Infer(srcs, outputBlobs);
// The reference output blob uses the same tensor descriptor as the graph output.
446 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
449 ref_simplernms(*srcClsPtr, *srcDeltaPtr, *srcInfoPtr, dst_ref, p);
451 compare(*output, dst_ref);
452 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is empty: model creation, inference and the comparison with the reference
// all run inside the fixture's SetUp().
458 TEST_P(MKLDNNGraphSimplerNMSTests, TestsSimplerNMS) {}
461 INSTANTIATE_TEST_CASE_P(
462 DISABLED_TestsSimplerNMS, MKLDNNGraphSimplerNMSTests,
464 simplernms_test_params{{1, 18, 39, 64}, {1, 36, 39, 64}, {1, 3}, {150, 5}, 16, 16, 6000, 150, 0.7f, 1,
465 MKLDNNPlugin::impl_desc_type::ref, {
466 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
467 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
468 ASSERT_EQ(3, impl.getConfig().inConfs.size());
469 ASSERT_EQ(1, impl.getConfig().outConfs.size());
470 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
471 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(1).desc.getLayout());
472 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(2).desc.getLayout());
473 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());