1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
19 #include "c_types_map.hpp"
20 #include "jit_uni_roi_pooling.hpp"
21 #include "type_helpers.hpp"
24 #include "mkldnn_thread.hpp"
// Forward pass driver for ROI pooling.
//
// The work is the 4-D iteration space (ROI index n, channel-block group cbb,
// output row oh, output column ow). For every point of that space a
// jit_roi_pool_call_s argument structure is filled with the destination
// address and, depending on jpp.alg, either a max-pooling source window or
// the data for a 4-neighbour interpolation sample.
// NOTE(review): the statements that hand `arg` to the kernel and that
// dispatch `ker` across threads are not visible in this excerpt - confirm.
30 template <cpu_isa_t isa>
31 void jit_uni_roi_pooling_fwd_t<isa>::execute_forward() const {
// Input 0 is the feature map, input 1 the ROI table, output 0 the result.
32 auto src_data = reinterpret_cast<const data_t *>(this->input_memory(0));
33 auto src_roi = reinterpret_cast<const data_t*>(this->input_memory(1));
34 auto dst = reinterpret_cast<data_t*>(this->memory(0));
// Descriptor wrappers used below to turn logical indices into element offsets.
36 const memory_desc_wrapper src_d(pd()->src_pd(0));
37 const memory_desc_wrapper src_roi_d(pd()->src_pd(1));
38 const memory_desc_wrapper dst_d(pd()->dst_pd());
40 const auto &jpp = pd()->jpp_;
// Number of channel-block groups; each group holds up to jpp.nb_c_blocking
// channel blocks (utils::div_up is presumably a ceiling division - confirm).
42 int cb_work = utils::div_up(jpp.nb_c, jpp.nb_c_blocking);
// Scan the ROI table row by row, reading the first element of each row as the
// batch index and testing it against the sentinel value -1.
// NOTE(review): the action taken on -1 is not visible here; presumably the
// scan stops so that real_rois counts the valid ROIs - confirm.
46 for (; real_rois < MB; real_rois++) {
// The ROI descriptor may be 4-D or lower-rank; pick the matching off() overload.
48 if (src_roi_d.ndims() == 4) {
49 roi_off = src_roi_d.off(real_rois, 0, 0, 0);
51 roi_off = src_roi_d.off(real_rois, 0);
54 const data_t *src_roi_ptr = &src_roi[roi_off];
// Implicit data_t -> int conversion of the batch index.
55 int roi_batch_ind = src_roi_ptr[0];
56 if (roi_batch_ind == -1) {
// Total number of work items: one per (ROI, channel-block group, oh, ow).
61 const int work_amount = MB * cb_work * jpp.oh * jpp.ow;
// Per-thread body: ithr is this thread's index, nthr the thread count.
63 auto ker = [&](const int ithr, const int nthr) {
// balance211 presumably yields this thread's [start, end) slice of the
// flattened work range - confirm against its definition.
65 balance211(work_amount, nthr, ithr, start, end);
// Decompose the flat start index into (n, cbb, oh, ow).
67 int n{0}, cbb{0}, oh{0}, ow{0};
68 utils::nd_iterator_init(start, n, MB, cbb, cb_work, oh, jpp.oh, ow, jpp.ow);
70 for (int iwork = start; iwork < end; iwork++) {
71 auto arg = jit_roi_pool_call_s();
// First channel block of this group and the nominal group size.
73 int cb = cbb * jpp.nb_c_blocking;
74 int cb_num = jpp.nb_c_blocking;
// Clip the group to the blocks that actually exist, so the last group may be
// smaller than cb_num.
76 arg.c_blocks = nstl::min(cb + cb_num, jpp.nb_c) - cb;
// NOTE(review): arg.dst is assigned this same expression again in both
// algorithm paths below; the repeats look redundant.
79 arg.dst = &dst[dst_d.blk_off(n, cb, oh, ow)];
// Locate the ROI row for ROI n (same 4-D / lower-rank distinction as above).
85 if(src_roi_d.ndims() == 4) {
86 roi_off = src_roi_d.off((int)n, 0, 0, 0);
89 roi_off = src_roi_d.off((int)n, 0);
91 const data_t* src_roi_ptr = &src_roi[roi_off];
// ROI row layout as read here: [0] batch index, [1] start_w, [2] start_h,
// [3] end_w, [4] end_h.
// NOTE(review): roi_batch_ind is used as a source index below without a
// visible range check - confirm it is validated elsewhere.
93 int roi_batch_ind = src_roi_ptr[0];
// ---- Max pooling path ----
95 if (jpp.alg == mkldnn_roi_pooling_max) {
// Scale the ROI corners by spatial_scale and round to integer coordinates.
96 int roi_start_w = round(src_roi_ptr[1] * jpp.spatial_scale);
97 int roi_start_h = round(src_roi_ptr[2] * jpp.spatial_scale);
98 int roi_end_w = round(src_roi_ptr[3] * jpp.spatial_scale);
99 int roi_end_h = round(src_roi_ptr[4] * jpp.spatial_scale);
// Inclusive extent, forced to at least 1 so malformed ROIs still give one cell.
101 int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
102 int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
// Integer bin boundaries for output cell (oh, ow), relative to the ROI origin.
// The comparisons after each division detect when truncating integer division
// did not give the intended rounding (start: result too large; end: result
// too small).
// NOTE(review): the correcting statements inside these ifs are not visible
// here; presumably they decrement the start and increment the end - confirm.
105 int hstart = (oh * roi_height) / jpp.pooled_h;
106 if ((hstart * jpp.pooled_h) > (oh * roi_height)) {
110 int wstart = (ow * roi_width) / jpp.pooled_w;
111 if ((wstart * jpp.pooled_w) > (ow * roi_width)) {
115 int hend = ((oh + 1) * roi_height) / jpp.pooled_h;
116 if ((hend * jpp.pooled_h) < ((oh + 1) * roi_height)) {
120 int wend = ((ow + 1) * roi_width) / jpp.pooled_w;
121 if ((wend * jpp.pooled_w) < ((ow + 1) * roi_width)) {
// Shift the bin by the ROI origin and clamp it to [0, ih] x [0, iw].
125 hstart = std::min(std::max(hstart + roi_start_h, 0), jpp.ih);
126 hend = std::min(std::max(hend + roi_start_h, 0), jpp.ih);
127 wstart = std::min(std::max(wstart + roi_start_w, 0), jpp.iw);
128 wend = std::min(std::max(wend + roi_start_w, 0), jpp.iw);
// Source pointer at the bin's top-left element in the ROI's batch image.
130 arg.src = &src_data[src_d.blk_off(roi_batch_ind, cb, hstart, wstart)];
131 arg.dst = &dst[dst_d.blk_off(n, cb, oh, ow)];
// Bin height, width and area; all are zero when clamping emptied the bin.
133 arg.bin_area = (hend - hstart) * (wend - wstart);
134 arg.kh = hend - hstart;
135 arg.kw = wend - wstart;
// ---- Interpolation path (the non-max algorithm) ----
// NOTE(review): the `else` introducing this path is not visible here.
// ROI corners are used without spatial_scale and multiplied by (ih - 1) /
// (iw - 1) below, so they are presumably normalized to [0, 1] - confirm.
137 float roi_start_w_ = src_roi_ptr[1];
138 float roi_start_h_ = src_roi_ptr[2];
139 float roi_end_w_ = src_roi_ptr[3];
140 float roi_end_h_ = src_roi_ptr[4];
// Input-space step between consecutive output rows / columns.
// NOTE(review): the divisor is zero when jpp.pooled_h == 1 (resp. pooled_w == 1).
// The float division then gives inf or NaN, and multiplying by oh == 0 / ow == 0
// below gives NaN. A NaN fails every comparison in the range check, so it would
// reach the floorf/ceilf int casts - confirm whether pooled size 1 is excluded
// upstream.
142 float height_scale = ((roi_end_h_ - roi_start_h_) * (jpp.ih - 1)) / (jpp.pooled_h - 1);
143 float width_scale = ((roi_end_w_ - roi_start_w_) * (jpp.iw - 1)) / (jpp.pooled_w - 1);
// Sampling position in input coordinates for output cell (oh, ow).
145 float in_y = (oh * height_scale + roi_start_h_ * (jpp.ih - 1));
146 float in_x = (ow * width_scale + roi_start_w_ * (jpp.iw - 1));
148 arg.dst = &dst[dst_d.blk_off(n, cb, oh, ow)];
// Sample point falls outside the input image.
// NOTE(review): the handling inside this branch is not visible here.
149 if (in_y < 0 || in_y > jpp.ih - 1 || in_x < 0 || in_x > jpp.iw - 1) {
// Integer coordinates of the four neighbours surrounding the sample point.
152 int top_y_index = static_cast<int>(floorf(in_y));
153 int bottom_y_index = static_cast<int>(ceilf(in_y));
154 int left_x_index = static_cast<int>(floorf(in_x));
155 int right_x_index = static_cast<int>(ceilf(in_x));
// Fractional distance of the sample from its left / top neighbour.
157 arg.xf = in_x - left_x_index;
158 arg.yf = in_y - top_y_index;
// Keep the right / bottom neighbours inside the image.
160 if (right_x_index > jpp.iw - 1)
161 right_x_index = jpp.iw - 1;
163 if (bottom_y_index > jpp.ih - 1)
164 bottom_y_index = jpp.ih - 1;
// Byte offsets from the top-left neighbour to the right and bottom neighbours.
// The strides assume c_block channels innermost and rows of iw * c_block
// elements; sizeof(float) is used rather than sizeof(data_t) - presumably
// data_t is float, confirm.
166 arg.xoff = (size_t)((right_x_index - left_x_index) * jpp.c_block * sizeof(float));
167 arg.yoff = (size_t)((bottom_y_index - top_y_index) * jpp.iw * jpp.c_block * sizeof(float));
// Source pointer at the top-left neighbour.
169 arg.src = &src_data[src_d.blk_off(roi_batch_ind, cb, top_y_index, left_x_index)];
// Advance (n, cbb, oh, ow) to the next work item, ow varying fastest.
178 utils::nd_iterator_step(n, MB, cbb, cb_work, oh, jpp.oh, ow, jpp.ow);
// Explicit instantiations of the primitive for the sse42, avx2 and
// avx512_common ISA tags, so the member definitions in this translation unit
// are emitted for each of them.
185 template struct jit_uni_roi_pooling_fwd_t<sse42>;
186 template struct jit_uni_roi_pooling_fwd_t<avx2>;
187 template struct jit_uni_roi_pooling_fwd_t<avx512_common>;