Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_roi_pooling_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include "tests_common.hpp"
16
17
18 using namespace ::testing;
19 using namespace std;
20 using namespace mkldnn;
21
22
23 struct roi_pooling_test_params {
24     struct {
25         size_t n;
26         size_t c;
27         size_t h;
28         size_t w;
29     } in1;
30
31     struct {
32         size_t n;
33         size_t c;
34     } in2;
35
36     size_t pooled_h;
37     size_t pooled_w;
38     float spatial_scale;
39
40     size_t num_prim_desc;
41
42     int selectedType;
43
44     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
45 };
46
47 template <typename data_t>
48 void ref_roipooling(const InferenceEngine::TBlob<data_t> &src, const InferenceEngine::TBlob<data_t> &roi,
49                     InferenceEngine::TBlob<data_t> &dst_blob, roi_pooling_test_params& params) {
50     data_t* dst = dst_blob.data();
51     const data_t* src_data = src.readOnly();
52     const data_t* src_roi = roi.readOnly();
53
54     int C = src.dims()[1];
55     int H = src.dims()[2];
56     int W = src.dims()[3];
57
58     int ROIS = roi.dims()[0];
59
60     double spatial_scale = params.spatial_scale;
61     int pooled_h = params.pooled_h;
62     int pooled_w = params.pooled_w;
63
64     data_t *arg_max_ = new data_t[dst_blob.size()];
65
66     for (size_t i = 0; i < dst_blob.size(); i++) {
67         arg_max_[i] = -1;
68         dst[i] = -FLT_MAX;
69     }
70
71     int roi_off;
72
73     for (int n = 0; n < ROIS; ++n) {
74         if(roi.dims().size() == 4) {
75             roi_off = n*roi.dims()[1]*roi.dims()[2]*roi.dims()[3];
76         }
77         else {
78             roi_off = n*roi.dims()[1];
79         }
80
81         const data_t* src_roi_ptr = &src_roi[roi_off];
82
83         int roi_batch_ind = src_roi_ptr[0];
84         int roi_start_w = round(src_roi_ptr[1] * spatial_scale);
85         int roi_start_h = round(src_roi_ptr[2] * spatial_scale);
86         int roi_end_w = round(src_roi_ptr[3] * spatial_scale);
87         int roi_end_h = round(src_roi_ptr[4] * spatial_scale);
88
89         int roi_height = (std::max)(roi_end_h - roi_start_h + 1, 1);
90         int roi_width = (std::max)(roi_end_w - roi_start_w + 1, 1);
91
92         for (int c = 0; c < C; ++c) {
93
94             for (int ph = 0; ph < pooled_h; ++ph) {
95                 for (int pw = 0; pw < pooled_w; ++pw) {
96                     int hstart = (ph * roi_height) / pooled_h;
97                     if ( (hstart * pooled_h) > (ph * roi_height) ) {
98                         --hstart;
99                     }
100
101                     int wstart = (pw * roi_width) / pooled_w;
102                     if ( (wstart * pooled_w) > (pw * roi_width) ) {
103                         --wstart;
104                     }
105
106                     int hend = ((ph + 1) * roi_height) / pooled_h;
107                     if ( (hend * pooled_h) < ((ph + 1) * roi_height) ) {
108                         ++hend;
109                     }
110
111                     int wend = ((pw + 1) * roi_width) / pooled_w;
112                     if ( (wend * pooled_w) < ((pw + 1) * roi_width) ) {
113                         ++wend;
114                     }
115
116                     hstart = (std::min)((std::max)(hstart + roi_start_h, 0), H);
117                     hend = (std::min)((std::max)(hend + roi_start_h, 0), H);
118                     wstart = (std::min)((std::max)(wstart + roi_start_w, 0), W);
119                     wend = (std::min)((std::max)(wend + roi_start_w, 0), W);
120
121                     bool is_empty = (hend <= hstart) || (wend <= wstart);
122
123                     const int pool_index = n*dst_blob.dims()[2]*dst_blob.dims()[1]*dst_blob.dims()[0] +
124                             c*dst_blob.dims()[1]*dst_blob.dims()[0] + ph*dst_blob.dims()[0] + pw;
125
126                     if (is_empty) {
127                         dst[pool_index] = 0;
128                         arg_max_[pool_index] = -1;
129                     }
130
131                     for (int h = hstart; h < hend; ++h) {
132                         for (int w = wstart; w < wend; ++w) {
133                             int src_index_data = roi_batch_ind*src.dims()[1]*src.dims()[2]*src.dims()[3] +
134                                                  c*src.dims()[2]*src.dims()[3] + h*src.dims()[3] + w;
135                             data_t batch_data = src_data[src_index_data];
136
137                             if (batch_data > dst[pool_index]) {
138                                 dst[pool_index] = batch_data;
139                                 arg_max_[pool_index] = batch_data;
140                             }
141                         }
142                     }
143                 }
144             }
145         }
146     }
147     delete[] arg_max_;
148 }
149
150 class MKLDNNGraphRoiPoolingTests: public TestsCommon,
151                                      public WithParamInterface<roi_pooling_test_params> {
152     std::string model_t = R"V0G0N(
153 <Net Name="ROIPooling_Only" version="2" precision="FP32" batch="1">
154     <layers>
155         <layer name="in1" type="Input" precision="FP32" id="0">
156             <output>
157                 <port id="0">
158                     <dim>_IN1_</dim>
159                     <dim>_IC1_</dim>
160                     <dim>_IH1_</dim>
161                     <dim>_IW1_</dim>
162                 </port>
163             </output>
164         </layer>
165         <layer name="in2" type="Input" precision="FP32" id="1">
166             <output>
167                 <port id="1">
168                     <dim>_IN2_</dim>
169                     <dim>_IC2_</dim>
170                 </port>
171             </output>
172         </layer>
173         <layer name="roi_pool" id="2" type="ROIPooling" precision="FP32">
174             <data pooled_h="_PH_" pooled_w="_PW_" spatial_scale="_SS_"/>
175             <input>
176                 <port id="2">
177                     <dim>_IN1_</dim>
178                     <dim>_IC1_</dim>
179                     <dim>_IH1_</dim>
180                     <dim>_IW1_</dim>
181                 </port>
182                 <port id="3">
183                     <dim>_IN2_</dim>
184                     <dim>_IC2_</dim>
185                 </port>
186             </input>
187             <output>
188                 <port id="4">
189                     <dim>_ON_</dim>
190                     <dim>_OC_</dim>
191                     <dim>_OH_</dim>
192                     <dim>_OW_</dim>
193                 </port>
194             </output>
195         </layer>
196     </layers>
197     <edges>
198         <edge from-layer="0" from-port="0" to-layer="2" to-port="2"/>
199         <edge from-layer="1" from-port="1" to-layer="2" to-port="3"/>
200     </edges>
201 </Net>
202 )V0G0N";
203
204     std::string getModel(roi_pooling_test_params p) {
205         std::string model = model_t;
206
207         REPLACE_WITH_NUM(model, "_IW1_", p.in1.w);
208         REPLACE_WITH_NUM(model, "_IH1_", p.in1.h);
209         REPLACE_WITH_NUM(model, "_IC1_", p.in1.c);
210         REPLACE_WITH_NUM(model, "_IN1_", p.in1.n);
211
212         REPLACE_WITH_NUM(model, "_IC2_", p.in2.c);
213         REPLACE_WITH_NUM(model, "_IN2_", p.in2.n);
214
215         REPLACE_WITH_NUM(model, "_OW_", p.pooled_w);
216         REPLACE_WITH_NUM(model, "_OH_", p.pooled_h);
217         REPLACE_WITH_NUM(model, "_OC_", (std::max)(p.in1.c, p.in2.c));
218         REPLACE_WITH_NUM(model, "_ON_", (std::max)(p.in1.n, p.in2.n));
219
220         REPLACE_WITH_NUM(model, "_PH_", p.pooled_h);
221         REPLACE_WITH_NUM(model, "_PW_", p.pooled_w);
222         REPLACE_WITH_NUM(model, "_SS_", p.spatial_scale);
223
224         return model;
225     }
226
227 protected:
228     virtual void TearDown() {
229     }
230
231     virtual void SetUp() {
232         try {
233             TestsCommon::SetUp();
234             roi_pooling_test_params p = ::testing::WithParamInterface<roi_pooling_test_params>::GetParam();
235             std::string model = getModel(p);
236
237             InferenceEngine::CNNNetReader net_reader;
238             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
239
240             MKLDNNGraphTestClass graph;
241             graph.CreateGraph(net_reader.getNetwork());
242             auto& nodes = graph.getNodes();
243             for (int i = 0; i < nodes.size(); i++) {
244                 if (nodes[i]->getType() == MKLDNNPlugin::ROIPooling) {
245                     ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
246                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
247                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
248                     }
249                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
250                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
251                 }
252             }
253             InferenceEngine::SizeVector dims_src = {p.in1.n, p.in1.c, p.in1.h, p.in1.w};
254
255             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
256             src->allocate();
257             fill_data(src->buffer(), src->size());
258
259             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
260
261             if (srcPtr == nullptr)
262                 FAIL() << "Cannot cast blob to TBlob<float>.";
263
264             InferenceEngine::SizeVector dims_roi = {p.in2.n, p.in2.c};
265
266             InferenceEngine::Blob::Ptr roi = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NC, dims_roi);
267             roi->allocate();
268             fill_data(roi->buffer(), roi->size());
269
270             InferenceEngine::TBlob<float>* roiPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(roi.get());
271
272             if (roiPtr == nullptr)
273                 FAIL() << "Cannot cast blob to TBlob<float>.";
274
275             InferenceEngine::BlobMap srcs;
276             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
277             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", roi));
278
279             InferenceEngine::OutputsDataMap out;
280             out = net_reader.getNetwork().getOutputsInfo();
281             InferenceEngine::BlobMap outputBlobs;
282
283             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
284
285             InferenceEngine::TBlob<float>::Ptr output;
286             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
287             output->allocate();
288             outputBlobs[item.first] = output;
289
290             graph.Infer(srcs, outputBlobs);
291
292             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
293             dst_ref.allocate();
294
295             ref_roipooling(*srcPtr, *roiPtr, dst_ref, p);
296
297             compare(*output, dst_ref);
298         } catch (const InferenceEngine::details::InferenceEngineException &e) {
299             FAIL() << e.what();
300         }
301     }
302 };
303
304 TEST_P(MKLDNNGraphRoiPoolingTests, TestsRoiPooling) {}
305
306
307 INSTANTIATE_TEST_CASE_P(
308         TestsRoiPooling, MKLDNNGraphRoiPoolingTests,
309         ::testing::Values(
310                 roi_pooling_test_params{
311                         {1, 256, 39, 64}, {150, 5}, 6, 6, 0.0625f, 5, MKLDNNPlugin::impl_desc_type::jit}));