Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_roi_pooling_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include "mkldnn_plugin/mkldnn_graph.h"

#include "test_graph.hpp"

#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include "tests_common.hpp"

#include <algorithm>
#include <cfloat>
14
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20
21 struct roi_pooling_test_params {
22     struct {
23         size_t n;
24         size_t c;
25         size_t h;
26         size_t w;
27     } in1;
28
29     struct {
30         size_t n;
31         size_t c;
32     } in2;
33
34     size_t pooled_h;
35     size_t pooled_w;
36     float spatial_scale;
37
38     size_t num_prim_desc;
39
40     int selectedType;
41
42     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
43 };
44
45 template <typename data_t>
46 void ref_roipooling(const InferenceEngine::TBlob<data_t> &src, const InferenceEngine::TBlob<data_t> &roi,
47                     InferenceEngine::TBlob<data_t> &dst_blob, roi_pooling_test_params& params) {
48     data_t* dst = dst_blob.data();
49     const data_t* src_data = src.readOnly();
50     const data_t* src_roi = roi.readOnly();
51
52     int C = src.dims()[1];
53     int H = src.dims()[2];
54     int W = src.dims()[3];
55
56     int ROIS = roi.dims()[0];
57
58     double spatial_scale = params.spatial_scale;
59     int pooled_h = params.pooled_h;
60     int pooled_w = params.pooled_w;
61
62     data_t *arg_max_ = new data_t[dst_blob.size()];
63
64     for (size_t i = 0; i < dst_blob.size(); i++) {
65         arg_max_[i] = -1;
66         dst[i] = -FLT_MAX;
67     }
68
69     int roi_off;
70
71     for (int n = 0; n < ROIS; ++n) {
72         if(roi.dims().size() == 4) {
73             roi_off = n*roi.dims()[1]*roi.dims()[2]*roi.dims()[3];
74         }
75         else {
76             roi_off = n*roi.dims()[1];
77         }
78
79         const data_t* src_roi_ptr = &src_roi[roi_off];
80
81         int roi_batch_ind = src_roi_ptr[0];
82         int roi_start_w = round(src_roi_ptr[1] * spatial_scale);
83         int roi_start_h = round(src_roi_ptr[2] * spatial_scale);
84         int roi_end_w = round(src_roi_ptr[3] * spatial_scale);
85         int roi_end_h = round(src_roi_ptr[4] * spatial_scale);
86
87         int roi_height = (std::max)(roi_end_h - roi_start_h + 1, 1);
88         int roi_width = (std::max)(roi_end_w - roi_start_w + 1, 1);
89
90         for (int c = 0; c < C; ++c) {
91
92             for (int ph = 0; ph < pooled_h; ++ph) {
93                 for (int pw = 0; pw < pooled_w; ++pw) {
94                     int hstart = (ph * roi_height) / pooled_h;
95                     if ( (hstart * pooled_h) > (ph * roi_height) ) {
96                         --hstart;
97                     }
98
99                     int wstart = (pw * roi_width) / pooled_w;
100                     if ( (wstart * pooled_w) > (pw * roi_width) ) {
101                         --wstart;
102                     }
103
104                     int hend = ((ph + 1) * roi_height) / pooled_h;
105                     if ( (hend * pooled_h) < ((ph + 1) * roi_height) ) {
106                         ++hend;
107                     }
108
109                     int wend = ((pw + 1) * roi_width) / pooled_w;
110                     if ( (wend * pooled_w) < ((pw + 1) * roi_width) ) {
111                         ++wend;
112                     }
113
114                     hstart = (std::min)((std::max)(hstart + roi_start_h, 0), H);
115                     hend = (std::min)((std::max)(hend + roi_start_h, 0), H);
116                     wstart = (std::min)((std::max)(wstart + roi_start_w, 0), W);
117                     wend = (std::min)((std::max)(wend + roi_start_w, 0), W);
118
119                     bool is_empty = (hend <= hstart) || (wend <= wstart);
120
121                     const int pool_index = n*dst_blob.dims()[2]*dst_blob.dims()[1]*dst_blob.dims()[0] +
122                             c*dst_blob.dims()[1]*dst_blob.dims()[0] + ph*dst_blob.dims()[0] + pw;
123
124                     if (is_empty) {
125                         dst[pool_index] = 0;
126                         arg_max_[pool_index] = -1;
127                     }
128
129                     for (int h = hstart; h < hend; ++h) {
130                         for (int w = wstart; w < wend; ++w) {
131                             int src_index_data = roi_batch_ind*src.dims()[1]*src.dims()[2]*src.dims()[3] +
132                                                  c*src.dims()[2]*src.dims()[3] + h*src.dims()[3] + w;
133                             data_t batch_data = src_data[src_index_data];
134
135                             if (batch_data > dst[pool_index]) {
136                                 dst[pool_index] = batch_data;
137                                 arg_max_[pool_index] = batch_data;
138                             }
139                         }
140                     }
141                 }
142             }
143         }
144     }
145     delete[] arg_max_;
146 }
147
148 class MKLDNNGraphRoiPoolingTests: public TestsCommon,
149                                      public WithParamInterface<roi_pooling_test_params> {
150     std::string model_t = R"V0G0N(
151 <Net Name="ROIPooling_Only" version="2" precision="FP32" batch="1">
152     <layers>
153         <layer name="in1" type="Input" precision="FP32" id="0">
154             <output>
155                 <port id="0">
156                     <dim>_IN1_</dim>
157                     <dim>_IC1_</dim>
158                     <dim>_IH1_</dim>
159                     <dim>_IW1_</dim>
160                 </port>
161             </output>
162         </layer>
163         <layer name="in2" type="Input" precision="FP32" id="1">
164             <output>
165                 <port id="1">
166                     <dim>_IN2_</dim>
167                     <dim>_IC2_</dim>
168                 </port>
169             </output>
170         </layer>
171         <layer name="roi_pool" id="2" type="ROIPooling" precision="FP32">
172             <data pooled_h="_PH_" pooled_w="_PW_" spatial_scale="_SS_"/>
173             <input>
174                 <port id="2">
175                     <dim>_IN1_</dim>
176                     <dim>_IC1_</dim>
177                     <dim>_IH1_</dim>
178                     <dim>_IW1_</dim>
179                 </port>
180                 <port id="3">
181                     <dim>_IN2_</dim>
182                     <dim>_IC2_</dim>
183                 </port>
184             </input>
185             <output>
186                 <port id="4">
187                     <dim>_ON_</dim>
188                     <dim>_OC_</dim>
189                     <dim>_OH_</dim>
190                     <dim>_OW_</dim>
191                 </port>
192             </output>
193         </layer>
194     </layers>
195     <edges>
196         <edge from-layer="0" from-port="0" to-layer="2" to-port="2"/>
197         <edge from-layer="1" from-port="1" to-layer="2" to-port="3"/>
198     </edges>
199 </Net>
200 )V0G0N";
201
202     std::string getModel(roi_pooling_test_params p) {
203         std::string model = model_t;
204
205         REPLACE_WITH_NUM(model, "_IW1_", p.in1.w);
206         REPLACE_WITH_NUM(model, "_IH1_", p.in1.h);
207         REPLACE_WITH_NUM(model, "_IC1_", p.in1.c);
208         REPLACE_WITH_NUM(model, "_IN1_", p.in1.n);
209
210         REPLACE_WITH_NUM(model, "_IC2_", p.in2.c);
211         REPLACE_WITH_NUM(model, "_IN2_", p.in2.n);
212
213         REPLACE_WITH_NUM(model, "_OW_", p.pooled_w);
214         REPLACE_WITH_NUM(model, "_OH_", p.pooled_h);
215         REPLACE_WITH_NUM(model, "_OC_", (std::max)(p.in1.c, p.in2.c));
216         REPLACE_WITH_NUM(model, "_ON_", (std::max)(p.in1.n, p.in2.n));
217
218         REPLACE_WITH_NUM(model, "_PH_", p.pooled_h);
219         REPLACE_WITH_NUM(model, "_PW_", p.pooled_w);
220         REPLACE_WITH_NUM(model, "_SS_", p.spatial_scale);
221
222         return model;
223     }
224
225 protected:
226     virtual void TearDown() {
227     }
228
229     virtual void SetUp() {
230         try {
231             TestsCommon::SetUp();
232             roi_pooling_test_params p = ::testing::WithParamInterface<roi_pooling_test_params>::GetParam();
233             std::string model = getModel(p);
234
235             InferenceEngine::CNNNetReader net_reader;
236             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
237
238             MKLDNNGraphTestClass graph;
239             graph.CreateGraph(net_reader.getNetwork());
240             auto& nodes = graph.getNodes();
241             for (int i = 0; i < nodes.size(); i++) {
242                 if (nodes[i]->getType() == MKLDNNPlugin::ROIPooling) {
243                     ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
244                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
245                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
246                     }
247                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
248                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
249                 }
250             }
251             InferenceEngine::SizeVector dims_src = {p.in1.n, p.in1.c, p.in1.h, p.in1.w};
252
253             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
254             src->allocate();
255             fill_data(src->buffer(), src->size());
256
257             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
258
259             if (srcPtr == nullptr)
260                 FAIL() << "Cannot cast blob to TBlob<float>.";
261
262             InferenceEngine::SizeVector dims_roi = {p.in2.n, p.in2.c};
263
264             InferenceEngine::Blob::Ptr roi = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NC, dims_roi);
265             roi->allocate();
266             fill_data(roi->buffer(), roi->size());
267
268             InferenceEngine::TBlob<float>* roiPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(roi.get());
269
270             if (roiPtr == nullptr)
271                 FAIL() << "Cannot cast blob to TBlob<float>.";
272
273             InferenceEngine::BlobMap srcs;
274             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
275             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", roi));
276
277             InferenceEngine::OutputsDataMap out;
278             out = net_reader.getNetwork().getOutputsInfo();
279             InferenceEngine::BlobMap outputBlobs;
280
281             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
282
283             InferenceEngine::TBlob<float>::Ptr output;
284             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
285             output->allocate();
286             outputBlobs[item.first] = output;
287
288             graph.Infer(srcs, outputBlobs);
289
290             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
291             dst_ref.allocate();
292
293             ref_roipooling(*srcPtr, *roiPtr, dst_ref, p);
294
295             compare(*output, dst_ref);
296         } catch (const InferenceEngine::details::InferenceEngineException &e) {
297             FAIL() << e.what();
298         }
299     }
300 };
301
302 TEST_P(MKLDNNGraphRoiPoolingTests, TestsRoiPooling) {}
303
304
305 INSTANTIATE_TEST_CASE_P(
306         TestsRoiPooling, MKLDNNGraphRoiPoolingTests,
307         ::testing::Values(
308                 roi_pooling_test_params{
309                         {1, 256, 39, 64}, {150, 5}, 6, 6, 0.0625f, 5, MKLDNNPlugin::impl_desc_type::jit}));