Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_roi_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "jit_uni_roi_pooling.hpp"
21 #include "type_helpers.hpp"
22 #include "utils.hpp"
23 #include "nstl.hpp"
24 #include "mkldnn_thread.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 template <cpu_isa_t isa>
31 void jit_uni_roi_pooling_fwd_t<isa>::execute_forward() const {
32     auto src_data = reinterpret_cast<const data_t *>(this->input_memory(0));
33     auto src_roi = reinterpret_cast<const data_t*>(this->input_memory(1));
34     auto dst = reinterpret_cast<data_t*>(this->memory(0));
35
36     const memory_desc_wrapper src_d(pd()->src_pd(0));
37     const memory_desc_wrapper src_roi_d(pd()->src_pd(1));
38     const memory_desc_wrapper dst_d(pd()->dst_pd());
39
40     const auto &jpp = pd()->jpp_;
41
42     int cb_work = utils::div_up(jpp.nb_c, jpp.nb_c_blocking);
43     int MB = jpp.mb;
44
45     int real_rois = 0;
46     for (; real_rois < MB; real_rois++) {
47         int roi_off;
48         if (src_roi_d.ndims() == 4) {
49             roi_off = src_roi_d.off(real_rois, 0, 0, 0);
50         } else {
51             roi_off = src_roi_d.off(real_rois, 0);
52         }
53
54         const data_t *src_roi_ptr = &src_roi[roi_off];
55         int roi_batch_ind = src_roi_ptr[0];
56         if (roi_batch_ind == -1) {
57             break;
58         }
59     }
60
61     const int work_amount = MB * cb_work * jpp.oh * jpp.ow;
62
63     auto ker = [&](const int ithr, const int nthr) {
64         int start{0}, end{0};
65         balance211(work_amount, nthr, ithr, start, end);
66
67         int n{0}, cbb{0}, oh{0}, ow{0};
68         utils::nd_iterator_init(start, n, MB, cbb, cb_work, oh, jpp.oh, ow, jpp.ow);
69
70         for (int iwork = start; iwork < end; iwork++) {
71             auto arg = jit_roi_pool_call_s();
72
73             int cb = cbb * jpp.nb_c_blocking;
74             int cb_num = jpp.nb_c_blocking;
75
76             arg.c_blocks = nstl::min(cb + cb_num, jpp.nb_c) - cb;
77
78             if (n >= real_rois) {
79                 arg.dst = &dst[dst_d.blk_off(n, cb, oh, ow)];
80                 arg.bin_area = 0;
81
82                 (*kernel_)(&arg);
83             } else {
84                 int roi_off;
85                 if(src_roi_d.ndims() == 4) {
86                     roi_off = src_roi_d.off((int)n, 0, 0, 0);
87                 }
88                 else {
89                     roi_off = src_roi_d.off((int)n, 0);
90                 }
91                 const data_t* src_roi_ptr = &src_roi[roi_off];
92
93                 int roi_batch_ind = src_roi_ptr[0];
94
95                 if (jpp.alg == mkldnn_roi_pooling_max) {
96                     int roi_start_w = round(src_roi_ptr[1] * jpp.spatial_scale);
97                     int roi_start_h = round(src_roi_ptr[2] * jpp.spatial_scale);
98                     int roi_end_w = round(src_roi_ptr[3] * jpp.spatial_scale);
99                     int roi_end_h = round(src_roi_ptr[4] * jpp.spatial_scale);
100
101                     int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
102                     int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
103
104
105                     int hstart = (oh * roi_height) / jpp.pooled_h;
106                     if ((hstart * jpp.pooled_h) > (oh * roi_height)) {
107                         --hstart;
108                     }
109
110                     int wstart = (ow * roi_width) / jpp.pooled_w;
111                     if ((wstart * jpp.pooled_w) > (ow * roi_width)) {
112                         --wstart;
113                     }
114
115                     int hend = ((oh + 1) * roi_height) / jpp.pooled_h;
116                     if ((hend * jpp.pooled_h) < ((oh + 1) * roi_height)) {
117                         ++hend;
118                     }
119
120                     int wend = ((ow + 1) * roi_width) / jpp.pooled_w;
121                     if ((wend * jpp.pooled_w) < ((ow + 1) * roi_width)) {
122                         ++wend;
123                     }
124
125                     hstart = std::min(std::max(hstart + roi_start_h, 0), jpp.ih);
126                     hend = std::min(std::max(hend + roi_start_h, 0), jpp.ih);
127                     wstart = std::min(std::max(wstart + roi_start_w, 0), jpp.iw);
128                     wend = std::min(std::max(wend + roi_start_w, 0), jpp.iw);
129
130                     arg.src = &src_data[src_d.blk_off(roi_batch_ind, cb, hstart, wstart)];
131                     arg.dst = &dst[dst_d.blk_off(n, cb, oh, ow)];
132
133                     arg.bin_area = (hend - hstart) * (wend - wstart);
134                     arg.kh = hend - hstart;
135                     arg.kw = wend - wstart;
136                 } else {
137                     float roi_start_w_ = src_roi_ptr[1];
138                     float roi_start_h_ = src_roi_ptr[2];
139                     float roi_end_w_   = src_roi_ptr[3];
140                     float roi_end_h_   = src_roi_ptr[4];
141
142                     float height_scale = ((roi_end_h_ - roi_start_h_) * (jpp.ih - 1)) / (jpp.pooled_h - 1);
143                     float width_scale  = ((roi_end_w_ - roi_start_w_) * (jpp.iw - 1)) / (jpp.pooled_w - 1);
144
145                     float in_y = (oh * height_scale + roi_start_h_ * (jpp.ih - 1));
146                     float in_x = (ow * width_scale  + roi_start_w_ * (jpp.iw - 1));
147
148                     arg.dst = &dst[dst_d.blk_off(n, cb, oh, ow)];
149                     if (in_y < 0 || in_y > jpp.ih - 1 || in_x < 0 || in_x > jpp.iw - 1) {
150                         arg.bin_area = 0;
151                     } else {
152                         int top_y_index    = static_cast<int>(floorf(in_y));
153                         int bottom_y_index = static_cast<int>(ceilf(in_y));
154                         int left_x_index   = static_cast<int>(floorf(in_x));
155                         int right_x_index  = static_cast<int>(ceilf(in_x));
156
157                         arg.xf = in_x - left_x_index;
158                         arg.yf = in_y - top_y_index;
159
160                         if (right_x_index > jpp.iw - 1)
161                             right_x_index = jpp.iw - 1;
162
163                         if (bottom_y_index > jpp.ih - 1)
164                             bottom_y_index = jpp.ih - 1;
165
166                         arg.xoff = (size_t)((right_x_index - left_x_index) * jpp.c_block * sizeof(float));
167                         arg.yoff = (size_t)((bottom_y_index - top_y_index) * jpp.iw * jpp.c_block * sizeof(float));
168
169                         arg.src = &src_data[src_d.blk_off(roi_batch_ind, cb, top_y_index, left_x_index)];
170                         arg.bin_area = 1;
171                     }
172                 }
173
174
175                 (*kernel_)(&arg);
176             }
177
178             utils::nd_iterator_step(n, MB, cbb, cb_work, oh, jpp.oh, ow, jpp.ow);
179         }
180     };
181
182     parallel(0, ker);
183 }
184
185 template struct jit_uni_roi_pooling_fwd_t<sse42>;
186 template struct jit_uni_roi_pooling_fwd_t<avx2>;
187 template struct jit_uni_roi_pooling_fwd_t<avx512_common>;
188
189 }
190 }
191 }