Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_roi_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19 #include <float.h>
20 #include <algorithm>
21 #include <mkldnn_types.h>
22 #include <iostream>
23
24 #include "c_types_map.hpp"
25 #include "type_helpers.hpp"
26
27 #include "ref_roi_pooling.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 template <impl::data_type_t data_type>
34 void ref_roi_pooling_fwd_t<data_type>::execute_forward_generic() const {
35     int roi_idx = 1;
36     int data_idx = 0;
37
38     const memory_desc_wrapper dst_d(pd()->dst_pd());
39     memory_desc_wrapper src_data_d = pd()->src_pd(data_idx);
40     memory_desc_wrapper src_roi_d = pd()->src_pd(roi_idx);
41
42     if (src_roi_d.dims()[0] < src_data_d.dims()[0]) {
43         roi_idx = 0;
44         data_idx = 1;
45
46         src_data_d = pd()->src_pd(data_idx);
47         src_roi_d = pd()->src_pd(roi_idx);
48     }
49
50     auto dst = reinterpret_cast<data_t*>(this->memory(0));
51     const data_t* src_data = reinterpret_cast<const data_t*>(this->input_memory(data_idx));
52     const data_t* src_roi = reinterpret_cast<const data_t*>(this->input_memory(roi_idx));
53
54     int C = src_data_d.dims()[1];
55     int H = src_data_d.dims()[2];
56     int W = src_data_d.dims()[3];
57
58     int ROIS = src_roi_d.dims()[0];
59
60     double spatial_scale = pd()->spatialScale();
61     int pooled_h = pd()->pooledH();
62     int pooled_w = pd()->pooledW();
63
64     for (size_t i = 0; i < dst_d.size() / sizeof(data_t); i++) {
65         dst[i] = -FLT_MAX;
66     }
67
68     int real_rois = 0;
69     for (; real_rois < ROIS; real_rois++) {
70         int roi_off;
71         if(src_roi_d.ndims() == 4) {
72             roi_off = src_roi_d.off(real_rois, 0, 0, 0);
73         }
74         else {
75             roi_off = src_roi_d.off(real_rois, 0);
76         }
77
78         const data_t* src_roi_ptr = &src_roi[roi_off];
79         int roi_batch_ind = src_roi_ptr[0];
80         if (roi_batch_ind == -1) {
81             break;
82         }
83     }
84     int n = 0;
85     for (; n < real_rois; ++n) {
86         int roi_off;
87         if(src_roi_d.ndims() == 4) {
88             roi_off = src_roi_d.off(n, 0, 0, 0);
89         }
90         else {
91             roi_off = src_roi_d.off(n, 0);
92         }
93
94         const data_t* src_roi_ptr = &src_roi[roi_off];
95         int roi_batch_ind = src_roi_ptr[0];
96
97         if (pd()->desc()->alg_kind == mkldnn_roi_pooling_max) {
98             int roi_start_w = round(src_roi_ptr[1] * spatial_scale);
99             int roi_start_h = round(src_roi_ptr[2] * spatial_scale);
100             int roi_end_w = round(src_roi_ptr[3] * spatial_scale);
101             int roi_end_h = round(src_roi_ptr[4] * spatial_scale);
102
103             int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
104             int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
105
106
107             for (int c = 0; c < C; ++c) {
108                 for (int ph = 0; ph < pooled_h; ++ph) {
109                     for (int pw = 0; pw < pooled_w; ++pw) {
110                         int hstart = (ph * roi_height) / pooled_h;
111                         if ((hstart * pooled_h) > (ph * roi_height)) {
112                             --hstart;
113                         }
114
115                         int wstart = (pw * roi_width) / pooled_w;
116                         if ((wstart * pooled_w) > (pw * roi_width)) {
117                             --wstart;
118                         }
119
120                         int hend = ((ph + 1) * roi_height) / pooled_h;
121                         if ((hend * pooled_h) < ((ph + 1) * roi_height)) {
122                             ++hend;
123                         }
124
125                         int wend = ((pw + 1) * roi_width) / pooled_w;
126                         if ((wend * pooled_w) < ((pw + 1) * roi_width)) {
127                             ++wend;
128                         }
129
130                         hstart = std::min(std::max(hstart + roi_start_h, 0), H);
131                         hend = std::min(std::max(hend + roi_start_h, 0), H);
132                         wstart = std::min(std::max(wstart + roi_start_w, 0), W);
133                         wend = std::min(std::max(wend + roi_start_w, 0), W);
134
135                         bool is_empty = (hend <= hstart) || (wend <= wstart);
136
137                         const int pool_index = dst_d.off(n, c, ph, pw);
138
139                         if (is_empty) {
140                             dst[pool_index] = 0;
141                         }
142
143                         for (int h = hstart; h < hend; ++h) {
144                             for (int w = wstart; w < wend; ++w) {
145                                 data_t batch_data = src_data[src_data_d.off(roi_batch_ind, c, h, w)];
146
147                                 if (batch_data > dst[pool_index]) {
148                                     dst[pool_index] = batch_data;
149                                 }
150                             }
151                         }
152                     }
153                 }
154             }
155         } else if (pd()->desc()->alg_kind == mkldnn_roi_pooling_bilinear) {
156             float roi_start_w_ = src_roi_ptr[1];
157             float roi_start_h_ = src_roi_ptr[2];
158             float roi_end_w_   = src_roi_ptr[3];
159             float roi_end_h_   = src_roi_ptr[4];
160
161             float height_scale = (roi_end_h_ - roi_start_h_) * (H - 1) / (pooled_h - 1);
162             float width_scale  = (roi_end_w_ - roi_start_w_) * (W - 1) / (pooled_w - 1);
163
164             for (int c = 0; c < C; ++c) {
165                 for (int ph = 0; ph < pooled_h; ++ph) {
166                     for (int pw = 0; pw < pooled_w; ++pw) {
167                         float in_y = (ph * height_scale + roi_start_h_ * (H - 1));
168                         float in_x = (pw * width_scale  + roi_start_w_ * (W - 1));
169
170                         if (in_y < 0 || in_y > H - 1 || in_x < 0 || in_x > W - 1) {
171                             dst[dst_d.off(n, c, ph, pw)] = 0;
172                         } else {
173                             int top_y_index    = static_cast<int>(floorf(in_y));
174                             int bottom_y_index = static_cast<int>(ceilf(in_y));
175                             int left_x_index   = static_cast<int>(floorf(in_x));
176                             int right_x_index  = static_cast<int>(ceilf(in_x));
177
178                             if (right_x_index > W - 1)
179                                 right_x_index = W - 1;
180
181                             if (bottom_y_index > H - 1)
182                                 bottom_y_index = H - 1;
183
184                             const float top_left     = src_data[src_data_d.off(roi_batch_ind, c, top_y_index, left_x_index)];
185                             const float top_right    = src_data[src_data_d.off(roi_batch_ind, c, top_y_index, right_x_index)];
186                             const float bottom_left  = src_data[src_data_d.off(roi_batch_ind, c, bottom_y_index, left_x_index)];
187                             const float bottom_right = src_data[src_data_d.off(roi_batch_ind, c, bottom_y_index, right_x_index)];
188
189                             const float top    = top_left + (top_right - top_left) * (in_x - left_x_index);
190                             const float bottom = bottom_left + (bottom_right - bottom_left) * (in_x - left_x_index);
191
192                             dst[dst_d.off(n, c, ph, pw)] = top + (bottom - top) * (in_y - top_y_index);
193                         }
194                     }
195                 }
196             }
197         }
198     }
199
200     for (; n < ROIS; ++n) {
201         for (int c = 0; c < C; ++c) {
202             for (int ph = 0; ph < pooled_h; ++ph) {
203                 for (int pw = 0; pw < pooled_w; ++pw) {
204                     dst[dst_d.off(n, c, ph, pw)] = 0;
205                 }
206             }
207         }
208     }
209 }
210
211 template struct ref_roi_pooling_fwd_t<data_type::f32>;
212
213 }
214 }
215 }
216
217 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s